#!/usr/bin/python

import optparse
import sys
import time

import boto.ec2
import boto.ec2.autoscale
import boto.ec2.elb
from retrying import retry




class Business(object):
    """
    Rolling termination of the instances in one Auto Scaling group (ASG).

    Every instance that is in the group when ``terminate`` starts is tagged
    with ``TAG`` and then terminated one at a time.  Before each termination
    the number of in-service instances must have caught up with the group's
    desired capacity, so replacements come up before more capacity is removed.
    """

    TAG = "marked_for_termination"

    # Seconds slept between capacity polls and after each termination.
    POLL_INTERVAL = 10

    # Keyword arguments for ``retrying.retry``.  A bare ``@retry`` retries
    # forever with no delay, which turns any permanent failure into a hang
    # that hammers the AWS API.  Back off exponentially (1s base, 30s cap),
    # give up after five minutes, and never retry errors that cannot heal on
    # their own (missing ASG, programming errors).
    _RETRY_POLICY = {
        "stop_max_delay": 300000,
        "wait_exponential_multiplier": 1000,
        "wait_exponential_max": 30000,
        "retry_on_exception": lambda exc: not isinstance(
            exc, (LookupError, TypeError, AttributeError)),
    }

    def __init__(self, region, profile, asg):
        """
        Create connections to the AWS services required

        :param region: name of the AWS region the ASG lives in
        :param profile: name of the boto credentials profile to use
        :param asg: name of the Auto Scaling group to roll
        :raises ValueError: if a connection for ``region`` cannot be created
        """
        self.region = region
        self.profile = profile
        self.asg = asg
        self.autoscale = boto.ec2.autoscale.connect_to_region(self.region, profile_name=self.profile)
        self.ec2 = boto.ec2.connect_to_region(self.region, profile_name=self.profile)
        self.elb = boto.ec2.elb.connect_to_region(self.region, profile_name=self.profile)
        # boto returns None (instead of raising) for a region it does not
        # know; fail here rather than with an AttributeError much later.
        if any(conn is None for conn in (self.autoscale, self.ec2, self.elb)):
            raise ValueError("Unable to connect to region: {}".format(self.region))

    def _print(self, msg):
        """
        Print to screen
        """
        print("<--- {} --->".format(msg))

    # ------------------------------------------------------------------
    # Undecorated helpers.  The retried methods further down only call
    # these, so retries are never nested inside other retries.
    # ------------------------------------------------------------------

    def _find_group(self):
        """
        Look up the ASG named ``self.asg``.

        :return: the boto AutoScalingGroup object
        :raises LookupError: if no group with that name exists
        """
        groups = self.autoscale.get_all_groups([self.asg])
        if not groups:
            raise LookupError("Error locating ASG group: {}".format(self.asg))
        return groups[0]

    @staticmethod
    def _member_ids(group):
        """
        Instance ids the ASG itself reports as members.

        :return: array of instance ids (empty for an empty group)
        """
        return [member.instance_id for member in (group.instances or [])]

    def _reservations_for(self, group):
        """
        Fetch the EC2 reservations holding the members of ``group``.

        :return: array of reservations (empty for an empty group)
        """
        instance_ids = self._member_ids(group)
        if not instance_ids:
            # An empty id list means "no filter" to EC2 and would return
            # every instance in the region -- which this tool would then
            # tag and terminate.  Never issue that query.
            return []
        return self.ec2.get_all_instances(instance_ids)

    def _elb_states(self, group):
        """
        Health of the ASG members on every ELB attached to ``group``.

        The ELB is asked for all of its registered instances and the result
        is filtered locally; naming instances that are not (yet) registered
        makes the API call fail.

        :return: array of (load balancer name, array of instance states)
        """
        member_ids = set(self._member_ids(group))
        result = []
        for lb_name in (group.load_balancers or []):
            states = [state
                      for state in self.elb.describe_instance_health(lb_name)
                      if state.instance_id in member_ids]
            result.append((lb_name, states))
        return result

    def _is_marked(self, instance):
        """
        Tell whether an EC2 instance carries the termination tag.
        """
        return instance.tags.get(self.TAG) == 'true'

    def _wait_for_capacity(self, notice=None):
        """
        Block until in-service capacity reaches the desired capacity.

        :param notice: optional message printed once if waiting is needed
        """
        announced = False
        while self._get_current_capacity() < self._get_desired_capacity():
            if notice and not announced:
                self._print(notice)
                announced = True
            time.sleep(self.POLL_INTERVAL)

    # ------------------------------------------------------------------
    # Retried AWS queries
    # ------------------------------------------------------------------

    @retry(**_RETRY_POLICY)
    def _list_instances_in_asg(self):
        """
        Query AWS using the autoscaling group name and then return all
        instances in that group.

        :return: array of reservations; each has an ``instances`` list
        :raises LookupError: if the ASG does not exist
        """
        return self._reservations_for(self._find_group())

    @retry(**_RETRY_POLICY)
    def _list_instance_ids_in_asg(self):
        """
        List all instances currently in the ASG

        :return: array of instance ids
        """
        reservations = self._reservations_for(self._find_group())
        # A reservation can hold more than one instance; include them all.
        return [instance.id
                for reservation in reservations
                for instance in reservation.instances]

    @retry(**_RETRY_POLICY)
    def _get_desired_capacity(self):
        """
        Query the ASG for the desired instances capacity

        :return: int no of instances
        """
        return int(self._find_group().desired_capacity)

    @retry(**_RETRY_POLICY)
    def _get_current_capacity(self):
        """
        Count the ASG members that are currently in service.

        With ELBs attached, a member counts when the ELB reports it as
        'InService'; with several ELBs the lowest count is returned so that
        every ELB has to be back at capacity.  Without any ELB the ASG's own
        lifecycle state is used instead.

        :return: int no of in-service instances (never None)
        """
        group = self._find_group()
        per_elb = self._elb_states(group)
        if not per_elb:
            return sum(1 for member in (group.instances or [])
                       if member.lifecycle_state == 'InService')
        return min(sum(1 for state in states if state.state == 'InService')
                   for _lb_name, states in per_elb)

    def _mark_instances_for_termination(self):
        """
        Tag instances that are to be terminated and removed
        from the ELB
        """
        self._print("Marking instances for Termination")
        for reservation in self._list_instances_in_asg():
            for instance in reservation.instances:
                self._print("Tagging " + instance.id)
                try:
                    instance.add_tag(self.TAG, "true")
                except Exception as exc:
                    # Best effort: an untagged instance is simply left alone
                    # by terminate().  Report why instead of hiding it, and
                    # let KeyboardInterrupt/SystemExit propagate.
                    self._print("Tagging failed for {}: {}".format(instance.id, exc))

    @retry(**_RETRY_POLICY)
    def _check_instance_health(self):
        """
        Print the ELB state of every ASG member (diagnostic helper).
        """
        for _lb_name, states in self._elb_states(self._find_group()):
            for state in states:
                self._print(state.state)

    @retry(**_RETRY_POLICY)
    def _get_remaining_instances(self):
        """
        Find the ASG members that still carry the termination tag.

        :return: array of reservations containing at least one tagged instance
        """
        reservations = self._reservations_for(self._find_group())
        return [reservation for reservation in reservations
                if any(self._is_marked(instance)
                       for instance in reservation.instances)]

    def terminate(self):
        """
        This is the where the interesting stuff happens

        Tags the current members, then terminates the tagged instances one
        by one, waiting for capacity to recover before each termination and
        once more at the end.
        """
        self._mark_instances_for_termination()
        reservations = self._get_remaining_instances()

        self._print("Current capacity [ " + str(self._get_current_capacity()) + " ]  --  Desired capacity [ " + str(self._get_desired_capacity()) + " ]")

        ## Should check if ASG health_type == "elb"
        for reservation in reservations:
            for instance in reservation.instances:
                if not self._is_marked(instance):
                    continue
                self._wait_for_capacity("Waiting on additional capacity...")
                self._print("Terminating instance: " + instance.id)
                instance.terminate()
                # Give AWS a moment to register the loss before the next
                # capacity check.
                time.sleep(self.POLL_INTERVAL)

        self._print("Waiting for Final Instances")
        self._wait_for_capacity()

        self._print("Finished")


def main(argv=None):
    """
    Parse the command line and run the rolling termination.

    All three options (--asg, --region, --profile) are required.  When any
    is missing, the missing ones and the usage text are written to stderr
    and the process exits with status 1.

    :param argv: list of command line arguments without the program name;
                 defaults to ``sys.argv[1:]``
    """
    if argv is None:
        argv = sys.argv[1:]

    parser = optparse.OptionParser(
            formatter=optparse.TitledHelpFormatter(width=78),
            add_help_option=None)

    parser.add_option("-a", "--asg", dest="asg", help="Name of the Autoscaling Group")
    parser.add_option("-r", "--region", dest="region", help="Region of the Autoscaling Group")
    parser.add_option("-p", "--profile", dest="profile", help="IAM credentials profile")

    # Positional arguments are accepted but ignored.
    options, _args = parser.parse_args(argv)

    required = (("--asg", options.asg),
                ("--region", options.region),
                ("--profile", options.profile))
    missing = [flag for flag, value in required if not value]
    if missing:
        sys.stderr.write("Missing required option(s): {}\n".format(", ".join(missing)))
        parser.print_help(sys.stderr)
        sys.exit(1)

    Business(options.region, options.profile, options.asg).terminate()


# Run the command line interface only when executed as a script, so that
# importing this module has no side effects.
if __name__ == '__main__':
    main()


