#! /usr/bin/env python

from mushi.ksfs import kSFS
import mushi.composition as cmp

import matplotlib.pyplot as plt
import argparse
import pickle
import configparser
import numpy as np
import pandas as pd


def _parse_args():
    """Parse the command line and return the resulting namespace."""
    parser = argparse.ArgumentParser(
        description='infer η(t) and μ(t) from a k-SFS, then write a '
                    'diagnostic plot and a pickle of the fitted kSFS object')
    parser.add_argument('ksfs', type=str, help='path to k-SFS file')
    parser.add_argument('masked_genome_size_file', type=str,
                        help='path to file containing masked genome size in '
                             'nucleotides')
    parser.add_argument('config', type=str, help='path to config file')
    parser.add_argument('outbase', type=str,
                        help='base name for output files')
    return parser.parse_args()


def _load_ksfs_table(path):
    """Read the tab-separated k-SFS table at ``path`` into a DataFrame.

    The first column of the file is used as the index.

    Raises:
        ValueError: if any entry of the table is missing.
    """
    ksfs_df = pd.read_csv(path, sep='\t', index_col=0)
    # explicit check (not ``assert``) so it is not stripped by ``python -O``;
    # ``isna`` also copes with non-numeric columns, unlike ``np.isnan``
    if ksfs_df.isna().values.any():
        raise ValueError(f'k-SFS file {path} contains missing values')
    return ksfs_df


def _load_config(path):
    """Parse the INI file at ``path`` and return the ConfigParser.

    Raises:
        FileNotFoundError: if the file could not be read.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() skips unreadable files silently and returns the
    # list of files it actually parsed, so an empty result means failure
    if not config.read(path):
        raise FileNotFoundError(f'could not read config file {path}')
    return config


def _site_mask(config, n):
    """Build the boolean mask over the ``n - 1`` k-SFS rows.

    Row ``i`` (0-based) is kept when ``clip_low <= i < n - clip_high - 1``,
    where ``clip_low`` and ``clip_high`` come from the ``[loss]`` section.
    Returns ``None`` when neither option is set to a nonzero value.

    Raises:
        ValueError: if only one of ``clip_low``/``clip_high`` is given while
            the other is nonzero.
    """
    clip_low = config.getint('loss', 'clip_low', fallback=None)
    clip_high = config.getint('loss', 'clip_high', fallback=None)
    if not (clip_high or clip_low):
        return None
    if clip_low is None or clip_high is None:
        raise ValueError("config section 'loss' must define both 'clip_low' "
                         "and 'clip_high' when either is used")
    row = np.arange(n - 1)
    return (clip_low <= row) & (row < n - clip_high - 1)


def _plot_fit(ksfs, t_gen, path):
    """Draw the six-panel summary figure for ``ksfs`` and save it to ``path``.

    ``t_gen`` is forwarded to the η and μ plots; when it is truthy the TMRCA
    panel's time axis is scaled by it and labelled in years, otherwise it is
    labelled in generations.
    """
    plt.figure(figsize=(7, 9))

    plt.subplot(321)
    ksfs.plot_total()

    plt.subplot(322)
    ksfs.η.plot(t_gen=t_gen)

    plt.subplot(323)
    ksfs.plot(clr=True)

    plt.subplot(324)
    ksfs.μ.plot(t_gen=t_gen, clr=True, alpha=0.5)

    # TMRCA CDF against the η change points
    plt.subplot(325)
    tmrca_cdf = ksfs.tmrca_cdf(ksfs.η)
    if t_gen:
        plt.plot(t_gen * ksfs.η.change_points, tmrca_cdf)
        plt.xlabel('$t$ (years ago)')
    else:
        plt.plot(ksfs.η.change_points, tmrca_cdf)
        plt.xlabel('$t$ (generations ago)')
    plt.ylabel('TMRCA CDF')
    plt.ylim([0, 1])
    plt.xscale('log')
    plt.tick_params(axis='x', which='minor')

    # singular value spectrum of the clr-transformed μ matrix
    plt.subplot(326)
    Z = cmp.clr(ksfs.μ.Z)
    plt.plot(range(1, 1 + min(Z.shape)),
             np.linalg.svd(Z, compute_uv=False), '.')
    plt.xlabel('singular value rank')
    plt.xscale('log')
    plt.ylabel('singular value')
    plt.yscale('log')

    plt.tight_layout()
    plt.savefig(path)


def main():
    """Command-line entry point: infer η(t) and μ(t) from a k-SFS file.

    usage: python mushi -h

    Positional command-line arguments: the k-SFS file (tab-separated, first
    column is the index), a file holding the masked genome size as a single
    integer, the config file, and the base name for output files.

    Config sections read:
        [change points]      first, last (float), npts (int): log-spaced grid
        [loss]               optional; clip_low, clip_high (int), loss (str)
        [population]         u (float), optional t_gen (float)
        [eta regularization] every key is read as a float
        [mu regularization]  keys starting with ``beta_`` are read as floats,
                             optional ``hard`` as a boolean
        [convergence]        keys ending with ``_iter`` are read as ints,
                             all others as floats

    Writes ``<outbase>.fit.pdf`` and ``<outbase>.pkl``.

    Raises:
        ValueError: if the k-SFS table has missing values or the clip options
            are incomplete.
        FileNotFoundError: if the config file cannot be read.
    """
    args = _parse_args()

    # load k-SFS
    ksfs_df = _load_ksfs_table(args.ksfs)
    # one more than the number of table rows
    # NOTE(review): presumably the number of sampled haplotypes — confirm
    n = ksfs_df.shape[0] + 1
    ksfs = kSFS(X=ksfs_df.values, mutation_types=ksfs_df.columns)

    config = _load_config(args.config)

    # change points for time grid
    first = config.getfloat('change points', 'first')
    last = config.getfloat('change points', 'last')
    npts = config.getint('change points', 'npts')
    change_points = np.logspace(np.log10(first), np.log10(last), npts)

    # mask sites
    mask = _site_mask(config, n)

    # mutation rate estimate: per-site rate scaled by the masked genome size
    with open(args.masked_genome_size_file) as f:
        masked_genome_size = int(f.read())
    μ0 = config.getfloat('population', 'u') * masked_genome_size

    # generation time
    t_gen = config.getfloat('population', 't_gen', fallback=None)

    # parameter dict for η regularization
    η_regularization = {key: config.getfloat('eta regularization', key)
                        for key in config['eta regularization']}

    # parameter dict for μ regularization
    μ_regularization = {key: config.getfloat('mu regularization', key)
                        for key in config['mu regularization']
                        if key.startswith('beta_')}
    if 'hard' in config['mu regularization']:
        μ_regularization['hard'] = config.getboolean('mu regularization',
                                                     'hard')

    # parameter dict for convergence parameters
    convergence = {key: (config.getint('convergence', key)
                         if key.endswith('_iter')
                         else config.getfloat('convergence', key))
                   for key in config['convergence']}

    # parameter dict for loss parameters; has_option() is False (rather than
    # raising) when the optional [loss] section is absent
    loss = dict(mask=mask)
    if config.has_option('loss', 'loss'):
        loss['loss'] = config.get('loss', 'loss')

    print('sequential inference of η(t) and μ(t)\n', flush=True)
    ksfs.infer_history(change_points, μ0, **loss, **η_regularization,
                       **μ_regularization, **convergence)

    _plot_fit(ksfs, t_gen, f'{args.outbase}.fit.pdf')

    # pickle the final ksfs (which contains all the inferred history info)
    with open(f'{args.outbase}.pkl', 'wb') as f:
        pickle.dump(ksfs, f)


# run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module
if __name__ == '__main__':
    main()
