#!/usr/bin/env python3
# pylint: disable=logging-fstring-interpolation

import argparse
import subprocess
import json
import sys
import logging
import ast
import itertools
from copy import deepcopy

from rich.logging import RichHandler
from rich.table import Table
from rich.progress import Progress
from rich import print as rprint

from gevo import __version__
from gevo.evolve import evolution
from gevo.irind import edits_as_key

logging.basicConfig(format="%(message)s" ,level="NOTSET" ,handlers=[RichHandler()])
log=logging.getLogger("main")

class program(evolution):
    '''
    Interactive analyzer for a GEVO edit file.

    On construction the edit file is parsed, the program with *all* edits
    applied is compiled and evaluated, and the result is kept in
    ``self.fullEditsInd``. The public methods then analyze a (sub)list of those
    edits: per-edit testing, weak-edit removal, independent/epistasis grouping
    and exhaustive combination testing.
    '''

    def __init__(self, editf, kernel, bin, profile, timeout=30, fitness='time',
                 llvm_src_filename='cuda-device-only-kernel.ll', err_rate=0.01):
        '''
        :param editf: path of the edit file (a Python literal, parsed with ast.literal_eval)
        :param kernel: target kernel names, forwarded to ``evolution``
        :param bin: target binary, forwarded to ``evolution``
        :param profile: parsed profile dict, forwarded to ``evolution``
        :param timeout: evaluation timeout, forwarded to ``evolution``
        :param fitness: fitness function name, forwarded to ``evolution``
        :param llvm_src_filename: accepted for interface compatibility; not used here
        :param err_rate: allowed error rate, forwarded to ``evolution``
        :raises Exception: if the full edit set cannot be compiled or fails verification
        '''
        super().__init__(
            kernel=kernel,
            bin=bin,
            profile=profile,
            timeout=timeout,
            fitness=fitness,
            err_rate=err_rate,
            mutop='',
            use_fitness_map=False )

        try:
            with open(editf, 'r') as f:
                self.edits = ast.literal_eval(f.read())
        except FileNotFoundError:
            log.error(f"Edit File:{editf} cannot be found")
            sys.exit(1)

        rprint("Evaluate edit file ", end="", flush=True)
        self.fullEditsInd = self.toolbox.individual()
        self.fullEditsInd.edits = self.edits
        if self.fullEditsInd.update_from_edits() is False:
            raise Exception("Edit file cannot be compiled")
        measured = self._measure(self.fullEditsInd)
        if measured is None:
            raise Exception("Edit file fails the verification")
        self.fullEditsInd.fitness.values = measured
        rprint("")
        log.info(f"Fitness of the program with all edits: {self.fullEditsInd.fitness}")

    def _measure(self, ind, err_reduce=min, runs=3):
        '''
        Evaluate an already-compiled individual several times.

        :param ind: individual whose edits were applied via update_from_edits()
        :param err_reduce: how to combine the per-run error values (min or max)
        :param runs: number of evaluations
        :returns: ``(best_fitness, reduced_error)``, or None when any run reported
                  a missing fitness or error value (execution/verification failure)
        '''
        results = [self.evaluate(ind) for _ in range(runs)]
        fits = [value[0] for value in results]
        errs = [value[1] for value in results]
        if None in fits or None in errs:
            return None
        return min(fits), err_reduce(errs)

    def _baseline_individual(self, edits):
        '''
        Return an evaluated individual that holds exactly ``edits``.

        The already-evaluated ``self.fullEditsInd`` is reused when ``edits`` is the
        full edit list, so the common case costs no extra evaluations.

        :returns: the individual, or None (after logging) if ``edits`` cannot be
                  compiled or fails the verification
        '''
        if list(edits) == list(self.fullEditsInd.edits):
            return self.fullEditsInd
        ind = self.toolbox.individual()
        ind.edits = deepcopy(list(edits))
        if ind.update_from_edits() is False:
            log.error("The current edits cannot be compiled")
            return None
        # Baseline error uses min, matching how fullEditsInd is measured.
        measured = self._measure(ind)
        if measured is None:
            log.error("The current edits fail the verification")
            return None
        ind.fitness.values = measured
        return ind

    def _try_remove_edit(self, indPrior, edit, threshold):
        '''
        Evaluate ``indPrior`` with ``edit`` dropped.

        :returns: the evaluated candidate individual if dropping ``edit`` keeps at
                  least ``1-threshold`` of the performance and leaves the error
                  essentially unchanged; otherwise None (the reason is logged)
        '''
        indPriorwoEdit = self.toolbox.individual()
        indPriorwoEdit.edits = deepcopy(indPrior.edits)
        indPriorwoEdit.edits.remove(edit)

        if indPriorwoEdit.update_from_edits() is False:
            log.info(f"{edit[0]} not removed: cannot compile")
            return None

        # NOTE(review): candidates use the max error over runs while the initial
        # baseline uses the min; kept as in the original implementation.
        measured = self._measure(indPriorwoEdit, err_reduce=max)
        if measured is None:
            log.info(f"{edit[0]} not removed: execution failed")
            return None

        fit, err = measured
        indPriorwoEdit.fitness.values = (fit, err)
        improvement = indPrior.fitness.values[0]/fit
        if improvement > 1-threshold and abs(err - indPrior.fitness.values[1]) < 0.01*threshold:
            log.info(f"{edit[0]} removed: {fit:.2f}. Improvement: {improvement:.2f}. Error:{err}")
            return indPriorwoEdit
        log.info(f"{edit[0]} not removed: {fit:.2f}. Improvement: {improvement:.2f}. Error:{err}")
        return None

    def remove_weak_edits(self, edits, threshold=0.01):
        '''
        Greedily drop edits that do not contribute to performance.

        Starting from the program with exactly ``edits`` applied, each edit is
        removed in turn. The removal is kept when performance stays within
        ``threshold`` and the error is unchanged. The result is written to
        ``reduced.edit`` and ``reduced.ll``.

        :param edits: removing weak edit from edits
        :param threshold: tolerated relative performance loss per removed edit
        :returns: The useful edits (``edits`` unchanged if it cannot be evaluated)
        '''
        log.info("Start removing weak edits ...")
        baseline = self._baseline_individual(edits)
        if baseline is None:
            return edits

        indPrior = deepcopy(baseline)
        with Progress(auto_refresh=False) as pbar:
            task1 = pbar.add_task("", total=len(edits))
            pbar.update(task1, completed=0, refresh=True,
                        description=f"(0/{len(edits)})")

            for cnt, edit in enumerate(edits, start=1):
                candidate = self._try_remove_edit(indPrior, edit, threshold)
                if candidate is not None:
                    indPrior = candidate
                # Advance the bar for every edit, including skipped ones.
                pbar.update(task1, completed=cnt, refresh=True,
                            description=f"({cnt}/{len(edits)})")

        with open("reduced.edit", "w") as f:
            rprint(indPrior.edits, file=f)
        with open("reduced.ll", "w") as f:
            f.write(indPrior.srcEnc.decode())
        log.info("Done writing reduced.edit and reduced.ll")
        log.info(f"Fitness of the edit-reduced program: {indPrior.fitness}")
        return indPrior.edits

    def edittest(self, edits):
        '''
        Apply each edit alone to the original program and print its fitness,
        its speedup over the original and its error.

        :param edits: edits to test one by one
        '''
        pop = self.toolbox.population(n=len(edits))
        for ind, edit in zip(pop, edits):
            ind.edits = [edit]
            if ind.update_from_edits() is False:
                rprint(f"{edit}: cannot compile")
                continue
            measured = self._measure(ind, err_reduce=max)
            if measured is None:
                rprint(f"{edit}: execution failed")
                continue
            fit, err = measured
            improvement = self.origin.fitness.values[0]/fit
            rprint(f"{edit}: {fit:.2f}. Improvement: {improvement:.2f}. Error:{err:.2f}")

    def search_indepedent_edits(self, edits):
        '''
        Split ``edits`` into independent and epistatic edits.

        An edit is counted as independent when its runtime gain measured bottom-up
        (added to the independent edits found so far, versus the original program)
        matches its gain measured top-down (removed from ``edits``), within 1% of
        the original runtime. Results are written to ``reduced_independent.edit``
        and ``reduced_no_independent.edit``.

        :param edits: input edits that will be divided into independent or epistasis group
        :returns: independent edits and epistasis edits
        '''
        log.info("Start searching for indepedent/epistasis edits ...")
        baseline = self._baseline_individual(edits)
        if baseline is None:
            return [], list(edits)

        independentEdits = []
        for edit in edits:
            editOnlyInd = self.toolbox.individual()
            editOnlyInd.edits = deepcopy(independentEdits)
            editOnlyInd.edits.append(edit)
            if editOnlyInd.update_from_edits() is False:
                continue
            measured = self._measure(editOnlyInd)
            if measured is None:
                continue
            editOnlyInd.fitness.values = measured
            runtimeDiffDown = self.origin.fitness.values[0] - editOnlyInd.fitness.values[0]

            fullExceptEditInd = self.toolbox.individual()
            fullExceptEditInd.edits = [e for e in edits if e not in editOnlyInd.edits]
            if fullExceptEditInd.update_from_edits() is False:
                continue
            measured = self._measure(fullExceptEditInd)
            if measured is None:
                continue
            fullExceptEditInd.fitness.values = measured
            # Compare against the program with exactly `edits`, not the full edit file.
            runtimeDiffTop = fullExceptEditInd.fitness.values[0] - baseline.fitness.values[0]

            if abs(runtimeDiffDown - runtimeDiffTop) < self.origin.fitness.values[0]*0.01:
                log.info(f"{edit} can be independently applied: {editOnlyInd.fitness.values[0]:.2f}.")
                independentEdits.append(edit)

        epistasis = [edit for edit in edits if edit not in independentEdits]
        with open("reduced_no_independent.edit", 'w') as f:
            rprint(epistasis, file=f)
        with open("reduced_independent.edit", 'w') as f:
            rprint(independentEdits, file=f)
        return independentEdits, epistasis

    def group_test(self, edits):
        '''
        Evaluate every combination of 2 .. len(edits)-1 edits.

        Each result is written as ``<edit indices>,<fit>,<err>,<improvement>``
        (``c`` = cannot compile, ``x`` = execution failed) to ``group_test.txt``.
        Combinations that improve by at least 1% over the original are also printed.

        :param edits: edits to combine
        '''
        log.info("Start evaluating all edit combinations iteratively ...")
        editIdxMap = {edits_as_key([edit]): cnt for cnt, edit in enumerate(edits)}

        grid = Table.grid()
        grid.add_column(justify="right", style="bold blue")
        grid.add_column()
        for cnt, edit in enumerate(edits):
            grid.add_row(str(cnt)+': ', str(edit))
        rprint(grid)

        # The context manager closes the report file even on error or Ctrl-C.
        with open("group_test.txt", 'w') as fcomb:
            rprint(grid, file=fcomb)

            # NOTE: sizes 1 (see edittest) and len(edits) (the full set) are not evaluated.
            for l in range(2, len(edits)):
                fcomb.write(f"{l} combinations\n")
                combos = list(itertools.combinations(edits, l))  # built once per size
                with Progress(auto_refresh=False) as pbar:
                    task1 = pbar.add_task("", total=len(combos))
                    for cnt, subEdits in enumerate(combos):
                        fcomb.flush()
                        pbar.update(task1, completed=cnt, refresh=True,
                                    description=f"{l} combination: ({cnt}/{len(combos)})")

                        subEditsInd = self.toolbox.individual()
                        subEditsInd.edits = list(subEdits)
                        edit_str = ' '.join([str(editIdxMap[edits_as_key([edit])]) for edit in subEdits])
                        if subEditsInd.update_from_edits() is False:
                            fcomb.write(f'{edit_str},c,c,c\n')
                            continue
                        measured = self._measure(subEditsInd)
                        if measured is None:
                            fcomb.write(f'{edit_str},x,x,x\n')
                            continue
                        fit, err = measured
                        subEditsInd.fitness.values = (fit, err)
                        improvement = self.origin.fitness.values[0] / fit
                        if improvement >= 1.01:
                            rprint(f'[ {edit_str}]: ({fit:.2f}, {err:.2f}), Imp:{improvement:.2f}')
                        fcomb.write(f'{edit_str},{fit:.2f},{err:.2f},{improvement:.2f}\n')

if __name__ == '__main__':
    # Command-line entry point: parse arguments, load the profile, build the
    # analyzer and run the interactive analysis menu.
    parser = argparse.ArgumentParser(description="Analyze the performance of mutation edits for CUDA kernel")
    parser.add_argument('-P', '--profile_file', type=str, required=True,
        help="Specify the profile file that contains all application execution and testing information")
    parser.add_argument('-e', '--edit', type=str, required=True,
        help="The edit file")
    parser.add_argument('-t', '--timeout', type=int, default=30,
        help="The timeout period to evaluate the CUDA application")
    parser.add_argument('-fitf', '--fitness_function', type=str, default='time',
        help="What is the target fitness for the evolution. Defaults to execution time. Can be changed to power")
    parser.add_argument('--err_rate', type=str, default='0.01',
        help="Allowed maximum relative error generate from mutant comparing to the origin")
    parser.add_argument('--version', action='version', version='gevo-' + __version__)
    args = parser.parse_args()

    try:
        with open(args.profile_file) as profile_fp:
            profile = json.load(profile_fp)
    except FileNotFoundError:
        log.error(f"The profile:'{args.profile_file}' cannot be found")
        # Without a profile nothing below can run; stop instead of hitting a NameError.
        sys.exit(1)
    except Exception:
        print(sys.exc_info())
        sys.exit(-1)

    alyz = program(
        editf=args.edit,
        kernel=profile['kernels'],
        bin=profile['binary'],
        profile=profile,
        timeout=args.timeout,
        fitness=args.fitness_function,
        err_rate=args.err_rate)

    table = Table.grid(expand=True)
    table.add_column(justify="right", style="bold blue")
    table.add_column()
    table.add_row("Target CUDA program: ", profile['binary'])
    tc_args = "".join(" ".join(tc.args) + '\n' for tc in alyz.testcase)
    table.add_row("Args for the CUDA program:: ", tc_args)
    table.add_row("Target kernels:: ", " ".join(profile['kernels']))
    table.add_row("Evaluation Timeout:: ", str(args.timeout))
    table.add_row("Fitness function:: ", args.fitness_function)
    table.add_row("Edit file:: ", args.edit)
    table.add_row("Tolerate Error Rate:: ", str(args.err_rate))
    rprint(table)

    try:
        curEdits = alyz.edits
        while True:
            rprint("0) list current edits")
            rprint("1) test each edit individually")
            rprint("2) remove weak edits")
            rprint("3) group independent/epistasis edits")
            rprint("4) test edit combinations exhaustively")
            op = input("Choose the analysis operation you want to go over: ")

            if op == '0':
                rprint(curEdits)
            elif op == '1':
                alyz.edittest(curEdits)
            elif op == '2':
                curEdits = alyz.remove_weak_edits(curEdits)
            elif op == '3':
                groupOp = input("return (1)independent or (2)epistasis as current edits after grouping?")
                indEdits, epistasisEdits = alyz.search_indepedent_edits(curEdits)
                # input() returns a str; compare to '1', not the int 1.
                curEdits = indEdits if groupOp.strip() == '1' else epistasisEdits
            elif op == '4':
                alyz.group_test(curEdits)
            else:
                log.warning(f"Invalid selection: {op}")
    except KeyboardInterrupt:
        # Kill any evaluation of the target binary still running.
        subprocess.run(['killall', profile['binary']])
