#!/usr/bin/env python3

"""
Author: Shadi Zabad
Date: May 2022

This is a commandline script that enables users to generate
LD matrices in Zarr format from plink's `.bed` files.
"""

import os.path as osp
import argparse
import magenpy as mgp
from magenpy.utils.system_utils import valid_url
from magenpy.GenotypeMatrix import xarrayGenotypeMatrix, plinkBEDGenotypeMatrix

# Print the program banner. The ASCII art contains literal backslashes, so the
# template is a *raw* f-string: this keeps the output unchanged while avoiding
# the "invalid escape sequence" warnings a non-raw literal triggers.
print(rf"""
**********************************************                            
 _ __ ___   __ _  __ _  ___ _ __  _ __  _   _ 
| '_ ` _ \ / _` |/ _` |/ _ \ '_ \| '_ \| | | |
| | | | | | (_| | (_| |  __/ | | | |_) | |_| |
|_| |_| |_|\__,_|\__, |\___|_| |_| .__/ \__, |
                 |___/           |_|    |___/
Modeling and Analysis of Genetics data in python
Version: {mgp.__version__} | Release date: May 2022
Author: Shadi Zabad, McGill University
**********************************************
< Compute LD matrix and output in Zarr format >
""")

# Build the command-line interface. `choices` are given as ordered tuples
# (not sets) so that the usage/help/error text lists them in a stable,
# documented order regardless of string hash randomisation.
parser = argparse.ArgumentParser(description="""
Commandline arguments for LD matrix computation
""")

# General options:
parser.add_argument('--estimator', dest='estimator', type=str, default='windowed',
                    choices=('windowed', 'shrinkage', 'block', 'sample'),
                    help='The LD estimator (windowed, shrinkage, block, sample)')
parser.add_argument('--bfile', dest='bed_file', type=str, required=True,
                    help='The path to a plink BED file')
parser.add_argument('--keep', dest='keep_file', type=str,
                    help='A plink-style keep file to select a subset of individuals to compute the LD matrices.')
parser.add_argument('--extract', dest='extract_file', type=str,
                    help='A plink-style extract file to select a subset of SNPs to compute the LD matrix for.')
parser.add_argument('--backend', dest='backend', type=str, default='xarray',
                    choices=('xarray', 'plink'),
                    help='The backend software used to compute the Linkage-Disequilibrium between variants.')
parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                    help='The temporary directory where we store intermediate files.')
parser.add_argument('--output-dir', dest='output_dir', type=str, required=True,
                    help='The output directory where the Zarr formatted LD matrices will be stored.')
parser.add_argument('--min-maf', dest='min_maf', type=float,
                    help='The minimum minor allele frequency for variants included in the LD matrix.')
parser.add_argument('--min-mac', dest='min_mac', type=float,
                    help='The minimum minor allele count for variants included in the LD matrix.')

# Options for the various LD estimators:

# For the windowed estimator (only --ld-window-cm has a default value):
parser.add_argument('--ld-window', dest='ld_window', type=int,
                    help='Maximum number of neighboring SNPs to consider when computing LD.')
parser.add_argument('--ld-window-kb', dest='ld_window_kb', type=float,
                    help='Maximum distance (in kilobases) between pairs of variants when computing LD.')
parser.add_argument('--ld-window-cm', dest='ld_window_cm', type=float, default=3.,
                    help='Maximum distance (in centi Morgan) between pairs of variants when computing LD.')

# For the block estimator:
parser.add_argument('--ld-blocks', dest='ld_blocks', type=str,
                    help='Path to the file with the LD block boundaries, '
                         'in LDetect format (e.g. chr start stop, tab-separated)')

# For the shrinkage estimator:
parser.add_argument('--genmap-Ne', dest='genmap_ne', type=int,
                    help="The effective population size for the population from which the genetic map was derived.")
parser.add_argument('--genmap-sample-size', dest='genmap_ss', type=int,
                    help="The sample size for the dataset used to infer the genetic map.")
parser.add_argument('--shrinkage-cutoff', dest='shrink_cutoff', type=float,
                    help="The cutoff value below which we assume that the correlation between variants is zero.")

args = parser.parse_args()

# ------------------------------------------------------
# Validate the estimator-specific requirements before doing any work.
# Each condition below is a stand-alone guard; at most one of them fires.
_wants_blocks = args.estimator == 'block'
_wants_shrinkage = args.estimator == 'shrinkage'

# The block estimator requires an LD blocks file ...
if _wants_blocks and args.ld_blocks is None:
    raise Exception("If you select the [block] LD estimator, make sure that "
                    "you also provide the ld blocks file via the --ld-blocks flag!")

# ... which must be either an existing local file or a valid URL.
if _wants_blocks and not (osp.isfile(args.ld_blocks) or valid_url(args.ld_blocks)):
    raise FileNotFoundError("The LD blocks file does not exist!")

# The shrinkage estimator requires both genetic-map parameters.
if _wants_shrinkage and args.genmap_ne is None:
    raise Exception("If you select the [shrinkage] estimator, you need to specify the "
                    "effective population size via the --genmap-Ne flag!")

if _wants_shrinkage and args.genmap_ss is None:
    raise Exception("If you select the [shrinkage] estimator, you need to specify the "
                    "sample size for the genetic map via the --genmap-sample-size flag!")

# ------------------------------------------------------
# Collect the keyword arguments that apply to the selected estimator.

if args.estimator == 'block':
    # The blocks file is forwarded unconditionally for the block estimator.
    ld_kwargs = {'ld_blocks_file': args.ld_blocks}
else:
    # (keyword passed to compute_ld, attribute on the parsed args), listed in
    # the order they should appear in `ld_kwargs`. Estimators without an
    # entry (e.g. 'sample') take no extra keyword arguments.
    _optional_kwargs = {
        'windowed': (
            ('window_size', 'ld_window'),
            ('kb_window_size', 'ld_window_kb'),
            ('cm_window_size', 'ld_window_cm'),
        ),
        'shrinkage': (
            ('genetic_map_ne', 'genmap_ne'),
            ('genetic_map_sample_size', 'genmap_ss'),
            ('threshold', 'shrink_cutoff'),
        ),
    }.get(args.estimator, ())

    # Only options that were actually set (i.e. are not None) are forwarded.
    ld_kwargs = {
        _kwarg: getattr(args, _attr)
        for _kwarg, _attr in _optional_kwargs
        if getattr(args, _attr) is not None
    }

# ------------------------------------------------------

# Echo the parsed configuration back to the user:
print("> LD estimator:", args.estimator)

print(">>> Parsed estimator characteristics:\n", ld_kwargs)

print("\n\n> Source data:")
print(">>> BED file:", args.bed_file)

# Optional filters are reported only when they were supplied.
for _label, _value in (
    (">>> Keep samples:", args.keep_file),
    (">>> Keep variants:", args.extract_file),
    (">>> Minimum allele frequency:", args.min_maf),
    (">>> Minimum allele count:", args.min_mac),
):
    if _value is not None:
        print(_label, _value)

print("\n\n> Output:")
for _label, _value in (
    (">>> Temporary directory:", args.temp_dir),
    (">>> Output directory:", args.output_dir),
):
    print(_label, _value)

# ------------------------------------------------------
# Perform the computation:

# Load the genotype data with the class matching the requested backend
# ('xarray' uses xarrayGenotypeMatrix, anything else plinkBEDGenotypeMatrix).
if args.backend == 'xarray':
    g = xarrayGenotypeMatrix.from_file(args.bed_file, temp_dir=args.temp_dir)
else:
    g = plinkBEDGenotypeMatrix.from_file(args.bed_file, temp_dir=args.temp_dir)

try:
    # Apply the optional sample / variant filters before computing LD.
    if args.keep_file is not None:
        g.filter_samples(keep_file=args.keep_file)

    if args.extract_file is not None:
        g.filter_snps(extract_file=args.extract_file)

    if args.min_mac is not None or args.min_maf is not None:
        g.filter_by_allele_frequency(min_maf=args.min_maf, min_mac=args.min_mac)

    # Compute the LD matrices with the chosen estimator and write them
    # to the output directory.
    g.compute_ld(args.estimator, args.output_dir, **ld_kwargs)
finally:
    # Clean up all intermediate files and directories. Doing this in a
    # `finally` clause ensures they are also removed when filtering or the
    # LD computation raises, instead of being left behind.
    g.cleanup()

# Only reached when every step above succeeded.
print("Done!")
