#!/usr/bin/env python3
import argparse
import subprocess
import json
import sys
import os

from deap import creator
from rich.table import Table
from rich import print as rprint

from gevo.evolve import evolution

class program(evolution):
    """Evaluate one LLVM-IR file against the test cases of a profile.

    Unlike the evolutionary search in the parent class, this subclass measures
    a single, user-supplied LLVM-IR program and prints the per-test-case
    results next to the recorded fitness of the original program.
    """

    # Parameters
    # File name handed to ``individual.ptx`` when the tested program is built.
    cudaPTX = 'gevo.ptx'

    def __init__(self, kernel, bin, profile, timeout=30, fitness='time',
                 llvm_src_filename='cuda-device-only-kernel.ll', err_rate=0.01):
        """Forward the evaluation settings to ``evolution``.

        Args:
            kernel: Target kernels, forwarded unchanged to the parent.
            bin: Target binary, forwarded unchanged to the parent.
            profile: Loaded profile, forwarded unchanged to the parent.
            timeout: Evaluation timeout, forwarded unchanged to the parent.
            fitness: Name of the fitness function, forwarded unchanged.
            llvm_src_filename: Accepted for interface compatibility only;
                it is not used by this class.
            err_rate: Tolerated error rate, forwarded unchanged.
        """
        super().__init__(
            kernel=kernel,
            bin=bin,
            profile=profile,
            timeout=timeout,
            fitness=fitness,
            err_rate=err_rate,
            mutop='',
            use_fitness_map=False )

    def evaluate(self, llvm_file):
        """Run every test case on ``llvm_file`` and print its fitness.

        Each test case is executed three times and the smallest fitness and
        smallest error of the three runs are kept. The process terminates
        with a non-zero status if the file cannot be read, if any run fails
        to execute, or if there is no test case to evaluate.

        Args:
            llvm_file: Path of the LLVM-IR text file to evaluate.
        """
        try:
            with open(llvm_file, 'r') as f:
                llvmSrcEnc = f.read().encode()
        except IOError:
            print("File {} does not exist".format(llvm_file))
            sys.exit(1)

        individual = creator.Individual(llvmSrcEnc, self.mgpu)

        # The per-test-case results are needed here, which is why the
        # built-in evaluation function is not used.
        # Link (presumably emits the PTX file named by cudaPTX -- confirm
        # against creator.Individual).
        individual.ptx(self.cudaPTX)

        # Keep a copy of the evaluated source next to the generated files.
        with open('gevo.ll', 'w') as f:
            f.write(individual.srcEnc.decode())

        fits = []
        errs = []
        for tc in self.testcase:
            # Best of three runs. Every run is checked: a failure in a later
            # run must not reach min() as a None value.
            runs = []
            for _ in range(3):
                fitness, err = self.execNVprofRetrive(tc)
                if fitness is None:
                    print("Tested llvm program failed to execute")
                    # A failed evaluation is an error, so report it as one.
                    sys.exit(1)
                runs.append((fitness, err))

            # Remove the output files listed by the verifier so they do not
            # leak into the next test case.
            for res_file in self.verifier['output']:
                if os.path.exists(res_file):
                    os.remove(res_file)

            fits.append(min(run_fit for run_fit, _ in runs))
            errs.append(min(run_err for _, run_err in runs))

        if not fits:
            # Without this guard max() and the average below would raise.
            print("No test case was evaluated")
            sys.exit(1)

        max_err = max(errs)
        avg_fitness = sum(fits)/len(fits)
        # Third value: original fitness as a percentage of the measured one.
        rprint("Average Fitness of tested llvm program: {}, {}, {}".format(avg_fitness, max_err,
            self.origin.fitness.values[0]/avg_fitness*100))
        rprint("Individual test cases:")
        for fit, err, ofit in zip(fits, errs, self.ofits):
            rprint(f"({fit:.2f}, {err:.2f}), Improvement: {ofit/fit*100:.2f}")

def build_arg_parser():
    """Return the command-line parser of the analysis tool."""
    parser = argparse.ArgumentParser(description="Analyze the performance of mutation edits for CUDA kernel")
    parser.add_argument('-P', '--profile_file', type=str, required=True,
        help="Specify the profile file that contains all application execution and testing information")
    parser.add_argument('-l', '--llvm_file', type=str, required=True,
        help="The edit file")
    parser.add_argument('-t', '--timeout', type=int, default=30,
        help="The timeout period to evaluate the CUDA application")
    parser.add_argument('-fitf', '--fitness_function', type=str, default='time',
        help="What is the target fitness for the evolution. Default ot execution time. Can be changed to power")
    # The default is a real float instead of a string that argparse has to
    # convert on our behalf.
    parser.add_argument('--err_rate', type=float, default=0.01,
        help="Allowed maximum relative error generate from mutant comparing to the origin")
    return parser


def load_profile(path):
    """Read the JSON profile stored at ``path`` and return the decoded object.

    Raises:
        OSError: If the file cannot be opened or read.
        ValueError: If the content is not valid JSON.
    """
    # The context manager closes the file even when decoding fails.
    with open(path) as profile_file:
        return json.load(profile_file)


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    try:
        profile = load_profile(args.profile_file)
    except (OSError, ValueError):
        # Only I/O and decoding problems are handled here; Ctrl-C and
        # SystemExit are no longer swallowed by a bare except.
        print(sys.exc_info())
        sys.exit(-1)

    alyz = program(
        kernel=profile['kernels'],
        bin=profile['binary'],
        profile=profile,
        timeout=args.timeout,
        fitness=args.fitness_function,
        err_rate=args.err_rate)

    # Summary of the run configuration, printed before the evaluation starts.
    table = Table.grid()
    table.add_column(justify="right", style="bold blue")
    table.add_column()
    table.add_row("Target CUDA program: ", profile['binary'])
    # One line of arguments per test case, each terminated by a newline.
    tc_args = "".join(" ".join(tc.args) + '\n' for tc in alyz.testcase)
    table.add_row("Args for the CUDA program: ", tc_args)
    table.add_row("Target kernels: ", " ".join(profile['kernels']))
    table.add_row("Evaluation Timeout: ", str(args.timeout))
    table.add_row("Fitness function: ", args.fitness_function)
    table.add_row("Tolerate Error Rate: ", str(args.err_rate))
    table.add_row("Target LLVM-IR file: ", args.llvm_file)
    rprint(table)

    try:
        alyz.evaluate(args.llvm_file)
    except KeyboardInterrupt:
        # The binary is described by the profile; the parser defines no
        # ``binary`` option, so ``args.binary`` would raise AttributeError.
        subprocess.run(['killall', profile['binary']])
