#!/usr/bin/env python

"""
Name: ACSNI
Version: 0.1.11
Author: Chinedu A. Anene, PhD

"""

##############################################################################
import argparse
from pandas import read_csv
import sys, os, gc
from numpy.random import seed
from tensorflow.random import set_seed
from ACSNI.dat import gene_sets, name_generator
from ACSNI.sim import merge_multi, shuffle_exp
from ACSNI.corep import main_model_prediction, merge_minimal_out, check_name, check_duplicates

##############################################################################
# Command-line interface.
# Numeric options declare their type so that argparse rejects malformed values
# with a usage message at parse time, instead of a bare ValueError traceback
# when the value is converted later in the script.
parser = argparse.ArgumentParser(description="System biology information extraction for genomics.")
parser.add_argument('-m', '--mad', default=2, required=False, type=float,
                    help="Minimum median absolute deviance for geneSets")
parser.add_argument('-b', '--boot', default=5, required=False, type=int,
                    help="Number of ensemble models to run")
parser.add_argument('-c', '--alpha', default=0.01, required=False, type=float,
                    help="Alpha threshold to make prediction calls")
# The two help fragments are joined by implicit literal concatenation, so the
# separating space has to be part of one of them.
parser.add_argument('-p', '--lp', default=0, required=False, type=int,
                    help="Dimension of the pathway layer. It is also half of the subprocess, "
                         "set to 0 or default for automatic estimation")
parser.add_argument('-f', '--full', default=1, required=False, type=int,
                    help="Run tool in 1=full 0=sub (error only) mode")
parser.add_argument('-i', '--input', type=str, required=True,
                    help="Input expression data (.csv)")
parser.add_argument('-t', '--prior', type=str, required=True,
                    help='Prior matrix, binary')
parser.add_argument('-w', '--weight', default=None, required=False,
                    help='Use weights for the genes', type=str)
parser.add_argument('-s', '--seed', default=68, required=False, type=int,
                    help='Set seed for reproducibility')
args = parser.parse_args()

##############################################################################
# Normalise the parsed options into the short module-level names used by the
# rest of the script, refusing to continue when the MAD threshold is <= 1.
mad_f = float(args.mad)
if mad_f <= 1:
    sys.exit("This tool requires informative variation to work"
             ", please increase -m. Ideally, -m should be above 3")

# Integer-valued options (converted in the order boot, lp, full, seed).
bb, lp, fu, sd = int(args.boot), int(args.lp), int(args.full), int(args.seed)

# Options passed through unchanged.
inp, pr, tf = args.input, args.prior, args.weight
alpha_cut = args.alpha

# Record of the effective settings, keyed by the short option letters; it is
# handed to the modelling and merge helpers further down.
run_info = {
    "m": mad_f,
    "b": bb,
    "c": alpha_cut,
    "p": lp,
    "f": fu,
    "i": inp,
    "t": pr,
    "w": tf,
    "s": sd,
}

##############################################################################
# Read the expression table and run the package's name and duplicate checks
# on it, in that order.
expression_matrix = read_csv(inp)
for _validate in (check_name, check_duplicates):
    _validate(expression_matrix)

##############################################################################

# Build the gene-set collection from the prior matrix file; it is iterated
# and indexed by gene-set name in the main loop.
prior_matrix = gene_sets(read_csv(pr))
# Collects one result per gene set when running in sub (-f 0) mode.
re = []

##############################################################################
# Run the model(s) for every gene set in the prior matrix.
#
# Each gene set gets its own result directory, named from the input file name,
# the gene-set name and a generated suffix; the process changes into that
# directory before modelling.

# Validate the run mode once, before any result directory is created, so an
# invalid -f does not leave an empty directory behind.
if fu not in (0, 1):
    sys.exit("Invalid value set for parameter -f")

# Result directories are resolved against the launch directory. Without this,
# every chdir below would make the next gene set's directory a child of the
# previous one (and os.mkdir would fail outright when -i contains a directory
# component). The working directory is intentionally left in the last result
# directory after the loop, as before, for the outputs written afterwards.
base_dir = os.getcwd()

for z in prior_matrix:
    print("Running for {}".format(z))
    # inp[:-4] drops the last four characters of the input name
    # (the ".csv" extension, per the -i help text).
    pa_re = inp[:-4] + z + "-" + name_generator(7)
    out_dir = os.path.join(base_dir, pa_re)

    if not os.path.isdir(out_dir):
        os.mkdir(out_dir)
        print("Results will be saved to {}".format(pa_re))
    os.chdir(out_dir)

    # Pathway-layer dimension for this gene set; starts from the -p value.
    x_code = lp

    if fu == 1:
        for i in range(bb):
            # NOTE(review): sd * i is 0 for the first model whatever -s is,
            # and is 0 for every model when -s 0 is given. Left unchanged to
            # keep results reproducible against earlier runs -- confirm intent.
            sd_boot = sd * i
            seed(sd_boot)
            set_seed(sd_boot)

            xx_code = main_model_prediction(inp=expression_matrix,
                                            d_list=prior_matrix[z],
                                            gi=z, lp=x_code, s=i,
                                            run=run_info)
            gc.collect()
            # The value returned by the first model is reused as `lp` for the
            # remaining models (presumably the estimated layer size when -p
            # is 0 -- verify against main_model_prediction).
            if i == 0:
                x_code = xx_code

        merge_multi(os.getcwd(), info=run_info, file=inp)

    else:
        # fu == 0: a single model per gene set; results are merged after the loop.
        print("WARNING: -f 0 mode only runs one model. Use -f 1 for ensemble")
        seed(sd)
        set_seed(sd)
        results = main_model_prediction(inp=expression_matrix,
                                        d_list=prior_matrix[z],
                                        gi=z, lp=x_code, s=sd,
                                        run=run_info)
        re.append(results)

if fu == 0:
    # Output name combines the input and prior file names with their last
    # four characters removed.
    merge_minimal_out(re, inp[:-4] + pr[:-4])

##############################################################################
# Derive a matrix from the expression data with shuffle_exp (presumably a
# shuffled "null" version -- see ACSNI.sim) and save it in the current
# working directory.
sim_mat = shuffle_exp(expression_matrix)
# Prefix only the file name: with the full -i value, an input such as
# "data/x.csv" would target the non-existent directory "NULL_data/".
sim_mat.to_csv("NULL_{}".format(os.path.basename(inp)), index=False)

# Collect garbage
gc.collect()
