#!/usr/bin/env python
import sys
import os
import numpy as np
from tqdm.auto import tqdm
import torch

import pyro
import pyro.infer
import pyro.optim
import torch
from pyro.infer.autoguide import AutoNormal

import pickle as pkl
import argparse
import beige.model as m
import beige.load_data
import beret as be

# NOTE(review): this path tweak runs *after* the `beige` / `beret` imports
# above, so it cannot influence how those modules are resolved. If it is meant
# to make the parent directory importable, it must be moved above them.
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# Seed Pyro's RNG at import time so every inference run starts from the same
# random state.
pyro.set_rng_seed(101)

def get_sd_quantile(quantiles):
    """Approximate a standard deviation from a 1%-step quantile grid.

    ``quantiles`` is indexed by quantile level, where index ``i`` holds the
    ``(i + 1)``% quantile (the grid ``0.01, 0.02, ..., 0.99`` built in
    ``run_inference``). For a normal distribution the 16% and 84% quantiles
    sit about one sigma below and above the centre, so the 16-84% spread is
    about two sigma and half of it estimates sigma.
    """
    q16 = quantiles[15]
    q84 = quantiles[83]
    spread = q84 - q16
    return spread / 2


def run_inference(model, data,
                  init_scale=0.01,
                  initial_lr=0.01,
                  gamma=0.1,
                  num_steps=2000):
    """Fit ``model`` to ``data`` by SVI with an ``AutoNormal`` guide.

    Args:
        model: Pyro model callable, invoked as ``model(data)`` by SVI.
        data: observed data passed to every ``svi.step`` call.
        init_scale: initial scale of the AutoNormal guide.
        initial_lr: starting learning rate of ``ClippedAdam``.
        gamma: overall learning-rate decay factor. The rate is multiplied by
            ``gamma ** (1 / num_steps)`` every step, so it ends near
            ``initial_lr * gamma``.
        num_steps: number of SVI steps.

    Returns:
        dict with keys:
            ``loss``: list of per-step ELBO losses.
            ``params``: Pyro's global param store.
            ``quantiles``: guide quantiles on the grid 0.01, 0.02, ..., 0.99,
                keyed by sample-site name.
            ``mu``: median of ``mu_alleles``, column 0, as a NumPy array.
            ``mu_sd``: half the 16-84% spread of ``mu_alleles``, column 0.
            ``z_score``: ``mu / mu_sd``.
            ``sd``: median of ``sd_alleles``, column 0.
    """
    guide = AutoNormal(model, init_scale=init_scale)
    lrd = gamma ** (1 / num_steps)
    svi = pyro.infer.SVI(
        model=model,
        guide=guide,
        optim=pyro.optim.ClippedAdam({"lr": initial_lr, 'lrd': lrd}),
        loss=pyro.infer.Trace_ELBO())

    losses = []
    for t in tqdm(range(num_steps)):
        loss = svi.step(data)
        if t % 100 == 0:
            print("loss {} @ iter {}".format(loss, t))
        losses.append(loss)

    # The grid 0.01, 0.02, ..., 0.99 has 99 levels, and index i holds the
    # (i + 1)% quantile. The median is therefore index 49; index 50 is the
    # 51% quantile.
    quantiles = guide.quantiles(torch.arange(0.01, 1, 0.01))
    median_idx = 49
    # Copy to host memory before calling .numpy(), which fails on CUDA
    # tensors. That happens when the guide parameters live on the GPU, e.g.
    # after `--cuda` sets a CUDA default tensor type.
    mu_quantiles = quantiles['mu_alleles'].detach().cpu()
    sd_quantiles = quantiles['sd_alleles'].detach().cpu()
    mu = mu_quantiles[median_idx][:, 0].numpy()
    mu_sd = get_sd_quantile(mu_quantiles)[:, 0].numpy()
    sd = sd_quantiles[median_idx][:, 0].numpy()

    param_history_dict = {
        "loss": losses,
        "params": pyro.get_param_store(),
        # Reuse the quantiles computed above instead of recomputing them.
        "quantiles": quantiles,
        "mu": mu,
        "mu_sd": mu_sd,
        "z_score": mu / mu_sd,
        "sd": sd,
    }
    return param_history_dict

def main(args):
    """Load a ReporterScreen, fit the selected model with SVI and pickle the result.

    Args:
        args: parsed command-line namespace (see the argparse setup at the
            bottom of this file). ``args.model_id`` must be one of the keys of
            ``model_dict`` below ("A", "B0", "B", "B2").

    Side effects:
        - Sets torch's default tensor type.
        - Creates the ``args.prefix`` directory if needed.
        - Rewrites ``args.prefix`` to ``<prefix>/<input basename without .h5ad>``.
        - Writes ``<prefix>.model<model_id>.result.pkl``.

    Raises:
        ValueError: if the loaded screen has no samples
            (``bdata.shape[1] == 0``), so there is no data to fit.
    """
    if args.cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
    if args.prefix == "":
        args.prefix = "."
    os.makedirs(args.prefix, exist_ok=True)
    args.prefix = os.path.join(
        args.prefix,
        os.path.basename(args.bdata_path).rsplit(".h5ad", 1)[0])
    bdata = be.read_h5ad(args.bdata_path)
    print("Done loading data. Preprocessing...")

    bdata.condit['rep'] = bdata.condit['rep'].astype('category')
    total_reps = bdata.condit['rep'].unique()
    # This script's parser defines no `--mcmc` option. Read the flag
    # defensively so a namespace without it falls back to the default device
    # instead of raising AttributeError.
    device = 'cuda:0' if getattr(args, "mcmc", False) else None

    # Without samples there is nothing to build `ndata` from. Fail clearly
    # here rather than with an UnboundLocalError further down.
    if bdata.shape[1] == 0:
        raise ValueError(
            "{} contains no samples (bdata.shape[1] == 0); nothing to fit.".format(
                args.bdata_path))
    ndata = beige.load_data.NData(
        bdata, 5, "topbot",
        device=device,
        sample_mask_column=args.sample_mask_column,
        fit_a0=True)
    ndata.rep_index = np.where(
        total_reps.isin(bdata.condit['rep'].unique()))[0]

    model_dict = {
        "A": lambda data: m.NormalModel(data, use_bcmatch = False),
        "B0": lambda data: m.MixtureNormalFittedPi(data, alpha_prior=10, use_bcmatch = False),
        "B": lambda data: m.MixtureNormal(data, alpha_prior=10, use_bcmatch = False),
        "B2": lambda data: m.MixtureNormalRepPi(data, alpha_prior=10, use_bcmatch = False),
    }
    model = model_dict[args.model_id]

    print("Running inference for model {}...".format(args.model_id))
    param_history_dict = run_inference(model, ndata)
    outfile_path = "{}.model{}.result.pkl".format(args.prefix, args.model_id)
    print("Done running inference. Writing result at {}...".format(outfile_path))
    with open(outfile_path, "wb") as handle:
        pkl.dump(param_history_dict, handle)
    print("Done!")

def check_args(args):
    """Validate mutually exclusive model flags and select ``args.model_id``.

    Model selection:
        ``--perfect-edit`` -> "A"
        ``--fit-pi``       -> "B0"
        ``--rep-pi``       -> "B2"
        none of the above  -> "B"

    Args:
        args: parsed argparse namespace with ``perfect_edit``, ``fit_pi``,
            ``rep_pi`` and ``guide_activity_column`` attributes.

    Returns:
        The same namespace, with ``model_id`` assigned.

    Raises:
        ValueError: if ``--perfect-edit`` is combined with ``--fit-pi``,
            ``--rep-pi`` or ``--guide_activity_column``.
        NotImplementedError: if both ``--fit-pi`` and ``--rep-pi`` are given.
    """
    # Assign with `=`. The previous `==` only compared the value, so
    # model_id was never set and the comparison itself raised AttributeError.
    if args.perfect_edit:
        if args.fit_pi or args.rep_pi:
            raise ValueError("Incompatible model specification: can't assume perfect edit and fit the edit rate or replicate scaling factor.")
        if args.guide_activity_column is not None:
            raise ValueError("Can't use the guide activity column while constraining perfect edit.")
        args.model_id = "A"
    elif args.fit_pi:
        if args.rep_pi:
            raise NotImplementedError("Fitting both pi and the replicate scaling factor is not supported.")
        args.model_id = "B0"
    elif args.rep_pi:
        args.model_id = "B2"
    else:
        args.model_id = "B"
    return args

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run model on data.')
    parser.add_argument('bdata_path', type=str,
                        help='Path of an ReporterScreen object')
    parser.add_argument('--fit-pi', '-f', action='store_true', default=False, help='Fit pi (editing rate) of target variant. Observed editing activity is ignored.')
    parser.add_argument('--rep-pi', '-r', action='store_true', default=False, help = "Fit replicate specific scaling factor. Recommended to set as True if you expect variable editing activity across biological replicates.")
    parser.add_argument('--perfect-edit', '-c', action='store_true', default=False, help= "Assume perfect editing rate for all guides.")
    parser.add_argument('--guide_activity_column', '-a', type=str, default=None, help="Column in ReporterScreen.guide DataFrame showing the editing rate estimated via external tools")
    parser.add_argument('--pi-prior-weight', '-w', type=float, default=1.0, help = "Prior weight for editing rate")
    parser.add_argument('--prefix', '-p', default='',
                        help='prefix to save the result')
    parser.add_argument('--sorting_bin_upper_quantile_column', '-uq', help = "Column name with upper quantile values of each sorting bin in [Reporter]Screen.condit (or AnnData.var)")
    parser.add_argument('--sorting_bin_lower_quantile_column', '-lq', help = "Column name with lower quantile values of each sorting bin in [Reporter]Screen.condit (or AnnData var)")
    parser.add_argument(
        "--cuda", action="store_true", default=False, help="run on GPU"
    )
    parser.add_argument(
        "--sample-mask-column", type=str, default=None, help="Name of the column indicating the sample mask in [Reporter]Screen.condit (or AnnData.var). Sample is ignored if the value in this column is 0. This can be used to mask out low-quality samples."
    )
    args = parser.parse_args()
    args = check_args(args)
    main(args)
