#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author: Dongqing Sun
# @E-mail: Dongqingsun96@gmail.com
# @Date:   2021-06-07 21:30:43
# @Last Modified by:   Dongqing Sun
# @Last Modified time: 2021-06-07 21:30:45

import os
import sys
import logging
import argparse as ap

from STRIDE.version import __version__
from STRIDE.Deconvolution import DeconvolveParser, Deconvolve
from STRIDE.Plot import PlotParser, Plot
from STRIDE.Clustering import ClusterParser, Clustering
from STRIDE.Integration import IntegrateParser, Integrate
from STRIDE.Mapping import MappingParser, Mapping 


def main(argv = None):
    """
    Build the STRIDE command-line parser, parse the arguments and run the requested subcommand.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments to parse. When None (the default), argparse
        reads them from ``sys.argv[1:]``, which is what the installed
        command-line entry point relies on.

    Raises
    ------
    SystemExit
        Raised on every normal path through this function: status 0 after
        printing the version or after the selected subcommand function
        returns, status 1 when no subcommand was given (the help message is
        printed instead). Exceptions raised by a subcommand function are not
        caught here and propagate to the caller.
    """
    parser = ap.ArgumentParser(description = "STRIDE (Spatial TRanscrIptomics DEconvolution by topic modelling) is a cell-type deconvolution tool for spatial transcriptomics by using single-cell transcriptomics data. ")
    parser.add_argument("-v", "--version", action = "store_true", help = "Print version info.")

    # The chosen subcommand name is stored in args.subcommand (None when omitted).
    # NOTE(review): each *Parser helper presumably registers its own subcommand
    # and options on `subparsers` -- confirm against the STRIDE submodules.
    subparsers = parser.add_subparsers(dest = "subcommand")
    DeconvolveParser(subparsers)
    PlotParser(subparsers)
    ClusterParser(subparsers)
    IntegrateParser(subparsers)
    MappingParser(subparsers)

    logging.basicConfig(format="%(levelname)s: %(message)s", stream=sys.stderr)
    args = parser.parse_args(argv)

    # sys.exit() is used rather than the bare exit() builtin: exit() is added by
    # the `site` module for interactive sessions and is not guaranteed to exist
    # (e.g. under `python -S` or in frozen applications).
    if args.version:
        print(__version__)
        sys.exit(0)
    elif args.subcommand == "deconvolve":
        Deconvolve(sc_count_file = args.sc_count_file, sc_anno_file = args.sc_anno_file, st_count_file = args.st_count_file, model_dir = args.model_dir,
                   sc_scale_factor = args.sc_scale_factor, st_scale_factor = args.st_scale_factor,
                   out_dir = args.out_dir, out_prefix = args.out_prefix,
                   normalize = args.normalize, gene_use = args.gene_use, ntopics_list = args.ntopics_list)
    elif args.subcommand == "plot":
        Plot(deconv_res_file = args.deconv_res_file, st_loc_file = args.st_loc_file, sample_id = args.sample_id,
             plot_type = args.plot_type, pt_size = args.pt_size,
             out_dir = args.out_dir, out_prefix = args.out_prefix)
    elif args.subcommand == "cluster":
        Clustering(deconv_res_file = args.deconv_res_file, st_loc_file = args.st_loc_file, sample_id = args.sample_id,
                   weight = args.weight, n_clusters = args.n_clusters,
                   plot = args.plot, pt_size = args.pt_size,
                   out_dir = args.out_dir, out_prefix = args.out_prefix)
    elif args.subcommand == "integrate":
        Integrate(topic_spot_file_list = args.topic_spot_file_list, sample_id_list = args.sample_id_list,
                  deconv_res_file_list = args.deconv_res_file_list, st_loc_file_list = args.st_loc_file_list,
                  alpha = args.alpha, plot = args.plot, pt_size = args.pt_size,
                  out_dir = args.out_dir, out_prefix = args.out_prefix)
    elif args.subcommand == "map":
        Mapping(topic_st_file = args.topic_st_file, model_dir = args.model_dir, ntop = args.ntop, out_dir = args.out_dir, out_prefix = args.out_prefix)
    else:
        # No subcommand and no --version: show usage and signal failure.
        parser.print_help()
        sys.exit(1)
    sys.exit(0)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

