#!/usr/bin/env python

"""
ACSNI-derive
Author: Faraz Khan, PhD

"""
##############################################################################
import argparse
import pandas as pd
import numpy as np
from tensorflow.random import set_seed
from ACSNI import dat, utility, sim
import os, sys, gc

pd.options.mode.chained_assignment = None

##############################################################################
# Command-line interface of the script. Options are registered in the same
# order in which they appear in the generated --help output.
ap = argparse.ArgumentParser(description="De-Novo generation of gene sets")

# Required inputs.
ap.add_argument(
    "-i", "--input",
    type=str,
    required=True,
    help="Input expression data (.csv)",
)
ap.add_argument(
    "-g", "--gene",
    required=True,
    help="Gene ID/symbol to analyse",
)

# Optional biotype table and biotype of interest.
ap.add_argument(
    "-f", "--bio_file",
    default=None,
    required=False,
    help="Gene Bio_type table (.csv)",
)
ap.add_argument(
    "-b", "--bio_type",
    default=None,
    required=False,
    help="Gene Bio_type of interest",
)

# Numeric tuning parameters.
ap.add_argument(
    "-m", "--mad",
    type=float,
    default=1.2,
    required=False,
    help="Minimum median absolute deviation",
)
ap.add_argument(
    "-p", "--lp",
    type=int,
    default=16,
    required=False,
    help="Percentage of gene_set for model layers",
)
ap.add_argument(
    "-c", "--alpha",
    type=float,
    default=0.01,
    required=False,
    help="Alpha threshold to make prediction calls",
)
ap.add_argument(
    "-ex", "--exclude",
    type=str,
    default=None,
    required=False,
    help="Name of bio_types to exclude in csv format",
)
ap.add_argument(
    "-t", "--ct",
    type=float,
    default=0.60,
    required=False,
    help="Threshold to use for correlation",
)
ap.add_argument(
    "-z", "--pc",
    type=int,
    default=5,
    required=False,
    help="Number of prior matrix columns",
)

# Optional pre-computed correlation table.
ap.add_argument(
    "-u", "--corr_file",
    default=None,
    required=False,
    help="Pre-computed top correlated genes (.csv)",
)

# Parsed options as a plain dict keyed by the long option names.
args = vars(ap.parse_args())


##############################################################################
# Setting the arguments
# File paths and the gene of interest.
i_dat, i_gene = args["input"], args["gene"]
i_bio_file, i_bio_type = args["bio_file"], args["bio_type"]
i_ex, i_corr = args["exclude"], args["corr_file"]

# Numeric parameters.
i_mad, i_alpha = args["mad"], args["alpha"]
i_ct, i_pc = args["ct"], args["pc"]
p_layer = args["lp"]


# Argument store
# Snapshot of the command-line values as supplied; "pp" holds the layer
# percentage exactly as given on the command line.
run_info = dict(
    i=i_dat, g=i_gene, f=i_bio_file,
    b=i_bio_type, m=i_mad, pp=p_layer,
    c=i_alpha, ex=i_ex, t=i_ct,
    z=i_pc, u=i_corr,
)


##############################################################################
# Make biotype
# Biotype filtering is switched on only when BOTH the biotype table (-f) and
# the biotype of interest (-b) are supplied; giving just one is a usage error.
biotype_file = None
if i_bio_file is None and i_bio_type is None:
    bf = False
elif i_bio_file is not None and i_bio_type is not None:
    bf = True

    # Optional exclusion list: first column of a header-less csv file.
    if i_ex is not None:
        i_ex = pd.read_csv(i_ex, header=None)[0]

    # read_csv already returns a DataFrame, so no re-wrapping is required.
    biotype_file = pd.read_csv(i_bio_file)

    # Comparing a slice (rather than indexing columns[1]) also rejects tables
    # with fewer than two columns instead of crashing with an IndexError.
    if list(biotype_file.columns[:2]) != ['gene', 'biotype']:
        sys.exit("ERROR: Make sure the column names and the order of the biotype "
                 " file is in the right format: 'gene','biotype'")

    # Exact comparison: a regex/substring test would accept e.g. "TP53" when
    # only "TP53BP1" is present, and would misinterpret regex metacharacters
    # in the gene name.
    if not (biotype_file['gene'] == i_gene).any():
        sys.exit("ERROR: Gene not found in Biotype file")
else:
    sys.exit("ERROR: Biotype file and Biotype filter threshold must be provided if you "
             "intend to filter. Else leave the -f and -b parameters as default")


##############################################################################
# Make correlation
# A pre-computed correlation table is optional; bc records whether one was
# supplied. (The former trailing `else` could never run, because
# `is None` / `is not None` already cover every case.)
corr_file = None
bc = i_corr is not None
if bc:
    # read_csv already returns a DataFrame, so no re-wrapping is required.
    corr_file = pd.read_csv(i_corr)

    # Comparing a slice (rather than indexing columns[1]) also rejects tables
    # with fewer than two columns instead of crashing with an IndexError.
    if list(corr_file.columns[:2]) != ['gene', 'cor']:
        sys.exit("ERROR: Make sure the column names and the order of the correlation file "
                 "is in the right format: 'gene','cor'")

    # Exact comparison: a regex/substring test would accept e.g. "TP53" when
    # only "TP53BP1" is present, and would misinterpret regex metacharacters
    # in the gene name.
    if not (corr_file['gene'] == i_gene).any():
        sys.exit("Gene not found in Correlation file")


##############################################################################
# Expression data
# Loaded before the working directory changes, so a relative input path is
# still resolved against the directory the script was started from.
exp_dat = pd.read_csv(i_dat)

# Working directory
# Build the results-folder name from the gene and the input file's base name
# without its extension. Slicing off the last four characters assumed a
# 4-character extension and kept any directory part of the path, which made
# os.mkdir fail for inputs such as "data/expr.csv".
dat_stem = os.path.splitext(os.path.basename(i_dat))[0]
# NOTE(review): name_generator(7) presumably returns a random 7-character
# suffix that keeps folder names unique -- confirm in ACSNI.dat.
pa_re = i_gene + dat_stem + "-" + dat.name_generator(7)

# Create the folder on first use, then make it the current directory.
if not os.path.isdir(pa_re):
    os.mkdir(pa_re)
    print("Results will be saved to {}".format(pa_re))
os.chdir(os.path.join(os.getcwd(), pa_re))


##############################################################################
# Process expression
# exp_dat is already a DataFrame (it comes straight from pd.read_csv), so it
# is validated as-is.
# Fail with a clear message rather than a KeyError when the column is absent.
if 'gene' not in exp_dat.columns:
    sys.exit("ERROR: Expression matrix must contain a 'gene' column")

# Exact comparison: a regex/substring test would accept e.g. "TP53" when only
# "TP53BP1" is present, and would misinterpret regex metacharacters in the
# gene name.
if not (exp_dat['gene'] == i_gene).any():
    sys.exit("ERROR: Gene not found in Expression matrix")


##############################################################################
# Correlation matrix
gcor = utility.get_cor(
    exp_m=exp_dat,
    goi=i_gene,
    madf=i_mad,
    cort=i_ct,
    cbtype=i_bio_type,
    biotypef=biotype_file,
    biotypefilter=bf,
    corf=corr_file,
    docor=bc,
)

# Turn the layer percentage into an absolute size relative to the number of
# rows in gcor (integer division), and record it in the argument store.
n_rows = gcor.shape[0]
p_layer = n_rows * p_layer // 100
run_info["p"] = p_layer

# Fix the NumPy seed before the first prior matrix is built.
np.random.seed(23)
prior1 = utility.make_prior1(npc=i_pc, gcor=gcor)


##############################################################################
# First iteration of ACSNI
# The same seed value is given to TensorFlow and passed on to get_ascni.
sede = 23
set_seed(sede)
apredic = utility.get_ascni(
    prior_m=dat.gene_sets(prior1),
    expression_m=exp_dat,
    mad=i_mad,
    p=p_layer,
    s=sede,
    a=i_alpha,
)

# bf is True exactly when both the biotype table and the biotype of interest
# were supplied; only then are the table and exclusion list forwarded.
if bf:
    predicf = utility.preprocess_run_one(
        pf=apredic,
        biotypef=biotype_file,
        biotypefilter=bf,
        exclude=i_ex,
    )
else:
    predicf = utility.preprocess_run_one(pf=apredic, biotypefilter=bf)

# The processed first-run result seeds the second prior matrix.
pdm = utility.process_run_one(x=predicf, y=i_pc)
prior2 = utility.make_prior2(hc=pdm, gt=i_gene)


##############################################################################
# Second iteration of ACSNI
# Re-seed TensorFlow with the same seed used for the first iteration, then
# rerun the model with gene sets built from the second prior matrix.
set_seed(sede)
apredic2 = utility.get_ascni(prior_m=dat.gene_sets(prior2), expression_m=exp_dat,
                             mad=i_mad, p=p_layer, s=sede, a=i_alpha)

# NOTE(review): predicf2 is never read afterwards, and process_run_two is
# given apredic2 rather than predicf2 (the first iteration passes its
# preprocessed result on instead). Presumably these helpers are called for
# side effects or mutate apredic2 in place -- confirm in ACSNI.utility.
predicf2 = utility.preprocess_run_two(pf2=apredic2)
pdm2 = utility.process_run_two(hc2=apredic2)


##############################################################################
# Clean and create database
# Gather the five merged outputs from the current directory, then hand them
# to the saver together with the input file name and the argument store.
merged = utility.merge_multi(os.getcwd())
nn, dd, ac, fd, co = merged
utility.save_merged_n_d_ac_fd_co(
    n=nn,
    d=dd,
    ac=ac,
    fd=fd,
    co=co,
    path=os.getcwd(),
    nfile=i_dat,
    run_info=run_info,
)

##############################################################################
# Null distribution
sim_mat = sim.shuffle_exp(exp_dat)
# Use only the file-name part of the input path: with an input such as
# "data/expr.csv" the old name "NULL_data/expr.csv" pointed into a directory
# that does not exist and the write failed. The file is written to the
# current working directory.
sim_mat.to_csv("NULL_{}".format(os.path.basename(i_dat)), index=True)

# Collect garbage
gc.collect()