#!/usr/bin/env python3

"""
Author: Shadi Zabad
Date: May 2022

This script provides functionalities to simulate complex traits on top of
existing genotype data in the form of plink's `.bed` files.

"""

import os.path as osp
from magenpy.simulation.GWASimulator import GWASimulator
from magenpy.utils.system_utils import makedir, get_filenames
import argparse


# Start-up banner.
# The literal is a *raw* string on purpose: the ASCII art contains backslashes
# followed by characters such as ' ', '|' and '_', which are not valid escape
# sequences. In a regular string literal these trigger an invalid-escape
# warning at compile time (a SyntaxWarning on recent Python versions).
# The printed text is byte-for-byte the same as before.
print(r"""
**********************************************                            
 _ __ ___   __ _  __ _  ___ _ __  _ __  _   _ 
| '_ ` _ \ / _` |/ _` |/ _ \ '_ \| '_ \| | | |
| | | | | | (_| | (_| |  __/ | | | |_) | |_| |
|_| |_| |_|\__,_|\__, |\___|_| |_| .__/ \__, |
                 |___/           |_|    |___/
Modeling and Analysis of Genetics data in python
Version: 0.0.2 | Release date: May 2022
Author: Shadi Zabad, McGill University
**********************************************
< Simulate complex quantitative or case-control traits >
""")

# --------- Options ---------

def _proportion(value):
    """
    argparse `type` converter for options documented as lying between 0 and 1.

    :param value: The raw command-line string.
    :return: The parsed value as a float.
    :raises argparse.ArgumentTypeError: If `value` is not a number or falls
    outside the closed interval [0, 1]. NaN is rejected as well, because every
    comparison against it is False.
    """
    try:
        number = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {value!r}") from None

    if not 0. <= number <= 1.:
        raise argparse.ArgumentTypeError(f"{value!r} is not in the range [0, 1]")

    return number


parser = argparse.ArgumentParser(description="""
Commandline arguments for the complex trait simulator
""")

# Input / output options:
parser.add_argument('--bed-files', dest='bed_files', type=str, required=True,
                    help='The BED files containing the genotype data. '
                         'You may use a wildcard here (e.g. "data/chr_*.bed")')
parser.add_argument('--keep', dest='keep_file', type=str,
                    help='A plink-style keep file to select a subset of individuals for simulation.')
parser.add_argument('--extract', dest='extract_file', type=str,
                    help='A plink-style extract file to select a subset of SNPs for simulation.')
# NOTE: `choices` are given as lists (not sets) so that they are displayed in a
# stable order in the usage, help and error messages.
parser.add_argument('--backend', dest='backend', type=str, default='xarray',
                    choices=['xarray', 'plink'],
                    help='The backend software used for the computation.')
parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                    help='The temporary directory where we store intermediate files.')
parser.add_argument('--output-file', dest='output_file', type=str, required=True,
                    help='The path where the simulated phenotype will be stored '
                         '(no extension needed).')
parser.add_argument('--output-simulated-effects', dest='output_true', action='store_true', default=False,
                    help='Output a table with the true simulated effect size for each variant.')

# Simulation parameters:
# `--h2` and `--prevalence` are range-checked at parse time, so an out-of-range
# value is reported as a usage error before any data is loaded.
parser.add_argument('--h2', dest='h2', type=_proportion, required=True,
                    help='Trait heritability. Ranges between 0. and 1., inclusive.')
parser.add_argument('--mix-prop', '-p', dest='mix_prop', type=str,
                    help='Mixing proportions for the mixture density (comma separated). For example, '
                         'for the spike-and-slab mixture density, with the proportion of causal variants '
                         'set to 0.1, you can specify: "--mix-prop 0.9,0.1 --var-mult 0,1".')
parser.add_argument('--var-mult', '-d', dest='var_mult', type=str,
                    help='Multipliers on the variance for each mixture component.')
parser.add_argument('--likelihood', dest='likelihood', type=str, default='gaussian',
                    choices=['gaussian', 'binomial'],
                    help='The likelihood for the simulated trait. '
                         'Gaussian (e.g. quantitative) or binomial (e.g. case-control).')
parser.add_argument('--prevalence', dest='prevalence', type=_proportion,
                    help='The prevalence of cases (or proportion of positives) for binary traits. '
                         'Ranges between 0. and 1.')

args = parser.parse_args()

# ------------------------------------------------------
# Sanity checks on the inputs:

# Resolve the user-supplied path/pattern into a collection of `.bed` files
# and fail early if nothing matched.
bed_files = get_filenames(args.bed_files, extension=".bed")
if len(bed_files) < 1:
    raise FileNotFoundError(f"No BED files were identified at the specified location: {args.bed_files}")

# Parse the mixture specification. `pi` holds the mixing proportions and `d` the
# variance multipliers; both are given as comma-separated lists of floats.
if args.mix_prop:
    if not args.var_mult:
        raise ValueError("Specifying mixing proportions without variance multipliers is not permitted.")
    pi = [float(prop) for prop in args.mix_prop.split(",")]
    d = [float(mult) for mult in args.var_mult.split(",")]
else:
    if args.var_mult:
        # Previously the multipliers were dropped without any notice in this case.
        print("Warning: Variance multipliers were specified without mixing proportions "
              "and will be ignored!")
    print("Warning: Mixing proportions not specified. Assuming an infinitesimal architecture "
          "where all variants are causal!")
    pi = [0., 1.]
    d = [0., 1.]

# Each mixture component needs exactly one proportion and one multiplier.
if len(pi) != len(d):
    raise ValueError("The multipliers and mixing proportions must be of the same length!")

# ------------------------------------------------------
# Print out the parsed input commands:
# Echo the parsed configuration back to the user, one section at a time.
print(f"> Simulating complex trait with {args.likelihood} likelihood...")
for label, setting in ((">>> Heritability:", args.h2),
                       (">>> Mixing proportions:", pi),
                       (">>> Variance multipliers:", d)):
    print(label, setting)

print("\n\n> Source data:")
print(">>> BED files:", args.bed_files)
# The sample/variant filters are optional: report them only when supplied.
for label, filter_path in ((">>> Keep samples:", args.keep_file),
                           (">>> Keep variants:", args.extract_file)):
    if filter_path:
        print(label, filter_path)

print("\n\n> Output:")
for label, location in ((">>> Temporary directory:", args.temp_dir),
                        (">>> Output file:", args.output_file)):
    print(label, location)

# ------------------------------------------------------

# `--prevalence` used to be parsed but never reached the simulator. Forward it
# only when the user supplied it, so that the simulator's own default is kept
# otherwise.
# NOTE(review): assumes GWASimulator accepts a `prevalence` keyword -- confirm
# against the installed magenpy version.
simulator_options = {}
if args.prevalence is not None:
    simulator_options['prevalence'] = args.prevalence

gs = GWASimulator(bed_files,
                  keep_file=args.keep_file,
                  extract_file=args.extract_file,
                  phenotype_likelihood=args.likelihood,
                  h2=args.h2,
                  pi=pi,
                  d=d,
                  backend=args.backend,
                  temp_dir=args.temp_dir,
                  **simulator_options)

# Run everything that follows under try/finally so that `gs.cleanup()` is
# called even if the simulation or the file output raises.
try:
    print("> Simulating phenotype...")
    gs.simulate(reset_beta=True, perform_gwas=False)
    pheno_table = gs.to_phenotype_table()

    # Write the simulated phenotypes to file:
    # `osp.dirname` returns an empty string when the output path has no
    # directory component (e.g. `--output-file sim`); there is nothing to
    # create in that case, so skip `makedir` instead of passing it ''.
    output_dir = osp.dirname(args.output_file)
    if output_dir:
        makedir(output_dir)

    pheno_table.to_csv(args.output_file + '.SimPheno', sep="\t",
                       index=False, header=False)

    # Optionally write the table of true simulated effect sizes as well:
    if args.output_true:
        sim_effects = gs.to_true_beta_table()
        sim_effects.to_csv(args.output_file + ".SimEffect", sep="\t",
                           index=False)
finally:
    gs.cleanup()
