#!/usr/bin/python
# -*- coding: utf-8 -*-

import argparse
import logging
import os

# Set OMP_NUM_THREADS to 1 before numpy is imported (directly or through the
# packages imported below); the variable has to be in the environment before
# that import happens.
# NOTE(review): presumably this avoids thread oversubscription when the work
# is spread over several worker processes -- confirm against the
# multiprocessing code in bonndit.
os.environ["OMP_NUM_THREADS"] = "1"

import nibabel as nib
from dipy.core.gradients import gradient_table
from dipy.io import read_bvals_bvecs

import bonndit as bd
from bonndit.utils.io import fsl_vectors_to_worldspace, fsl_gtab_to_worldspace
from bonndit.utils import dwmri, fields
from bonndit.deconv.shoredeconv import fa_guided_mask


def _build_parser():
    """Create the argument parser of the deconvolution command line tool.

    Returns
    -------
    argparse.ArgumentParser
        Parser with one positional argument (``indir``) and the optional
        argument groups for input filenames, flags, sh options,
        deconvolution options, multiprocessing and logging.
    """
    # add_help=False because "-h/--help" is registered manually below so
    # that it shows up in the "flags" group.
    parser = argparse.ArgumentParser(
        description='Constrained Spherical Deconvolution', add_help=False)

    parser.add_argument('indir',
                        help='Folder containing all required input files')

    parser.add_argument('-o', '--outdir',
                        help='Folder in which the output will be saved (default: same as indir)')

    inputfiles = parser.add_argument_group(
        'Custom input filenames',
        'It is not recommended to specify custom names for input files.')
    inputfiles.add_argument('-d', '--data', default='data.nii.gz',
                            help='Diffusion weighted data (default: data.nii.gz)')
    inputfiles.add_argument('-e', '--dtivecs', default='dti_V1.nii.gz',
                            help='First eigenvectors of a DTI model (default: dti_V1.nii.gz)')
    inputfiles.add_argument('-a', '--dtifa', default='dti_FA.nii.gz',
                            help='Fractional anisotropy values from a DTI model (default: dti_FA.nii.gz)')
    inputfiles.add_argument('-m', '--brainmask', default='mask.nii.gz',
                            help='Brain mask (default: mask.nii.gz)')
    inputfiles.add_argument('-W', '--wmmask', default='fast_pve_2.nii.gz',
                            help='White matter mask (default: fast_pve_2.nii.gz)')

    flags = parser.add_argument_group('flags (optional)', '')
    flags.add_argument("-h", "--help", action="help",
                       help="Show this help message and exit")
    flags.add_argument('-v', '--verbose', action='store_true',
                       help='Activate progress bars and console logging')
    flags.add_argument('-R', '--responseonly', action='store_true',
                       help='Calculate and save only the response functions')
    flags.add_argument('-M', '--tissuemasks', action='store_true',
                       help='Output the DTI improved tissue masks (csf/gm/wm)')

    shopts = parser.add_argument_group(
        'sh options (optional)',
        'Optional arguments for the computation of '
        'the spherical harmonics white matter response function')
    shopts.add_argument('-k', '--kernel', choices=["rank1", "delta"],
                        default="rank1", type=str,
                        help='Kernel type (default: rank1)')
    shopts.add_argument('-r', '--order', default=4, type=int,
                        help='Order of the spherical harmonics basis (default: 4)')
    shopts.add_argument('-f', '--fawm', default=0.7, type=float,
                        help='White matter fractional anisotropy threshold (default: 0.7)')

    deconvopts = parser.add_argument_group('deconvolution options (optional)',
                                           '')
    deconvopts.add_argument('-C', '--constraint',
                            choices=['hpsd', 'nonneg', 'none'], default='hpsd',
                            help='Constraint for the fODFs (default: hpsd)')

    multiprocessing = parser.add_argument_group(
        'multiprocessing (optional)',
        'Configure the multiprocessing behaviour '
        '(only supported for Python 3)')
    multiprocessing.add_argument('-w', '--workers', default=None, type=int,
                                 help='Number of cpus (default: all available cpus)')

    log = parser.add_argument_group('logging (optional)',
                                    'Configure the logging behaviour')
    log.add_argument('-L', '--loglevel', choices=['INFO', 'WARNING', 'ERROR'],
                     default='INFO',
                     help='Specify the logging level for the console')
    return parser


def _configure_logging(outdir, loglevel, verbose):
    """Set up file logging in ``outdir`` and, if requested, console logging.

    Parameters
    ----------
    outdir : str
        Existing folder in which ``stdeconv.log`` is written.
    loglevel : str
        One of ``'INFO'``, ``'WARNING'`` or ``'ERROR'``.
    verbose : bool
        If true, log records are additionally written to the console.
    """
    levels = {'INFO': logging.INFO,
              'WARNING': logging.WARNING,
              'ERROR': logging.ERROR}

    # Logging setup for file; filemode 'w' starts a fresh log on every run
    logging.basicConfig(filename=os.path.join(outdir, 'stdeconv.log'),
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%y-%m-%d %H:%M',
                        level=levels[loglevel],
                        filemode='w')

    # Console logging if verbose flag is set
    if verbose:
        # StreamHandler without arguments writes to sys.stderr
        console = logging.StreamHandler()
        console.setLevel(levels[loglevel])
        # Simpler format (no timestamp) for console use
        console.setFormatter(logging.Formatter(
            '%(name)-12s: %(levelname)-8s %(message)s'))
        # Attach the handler to the root logger
        logging.getLogger('').addHandler(console)


def _strip_nifti_suffix(path):
    """Return ``path`` without a trailing ``.gz`` and/or ``.nii`` suffix.

    ``str.rstrip`` must not be used for this: it removes a *set of
    characters*, so ``'brain.nii.gz'.rstrip('.gz').rstrip('.nii')`` yields
    ``'bra'`` instead of ``'brain'``.
    """
    for suffix in ('.gz', '.nii'):
        if path.endswith(suffix):
            path = path[:-len(suffix)]
    return path


def main():
    """Run the single tissue deconvolution command line tool.

    The command line (``sys.argv``) is parsed, the input volumes and the
    gradient table are loaded from ``indir``, a white matter response
    function is estimated (or an existing ``response.npz`` in the output
    folder is reused) and, unless ``--responseonly`` is given, the signal is
    deconvolved and the results are written to the output folder.

    Raises
    ------
    FileNotFoundError
        If one of the required input files is missing in ``indir``.
    SystemExit
        Raised by argparse for ``--help`` or invalid arguments.
    """
    args = _build_parser().parse_args()

    # Output goes next to the input unless an output folder is given
    indir = args.indir
    outdir = args.outdir if args.outdir else indir

    # Create outdir if it does not exist (no check-then-create race)
    os.makedirs(outdir, exist_ok=True)

    _configure_logging(outdir, args.loglevel, args.verbose)

    logging.info('stdeconv has been called with:')
    param_string = 'Order: {}, FAWM: {}, Constraint: {}, Kernel {}'
    logging.info(param_string.format(args.order, args.fawm, args.constraint,
                                     args.kernel))

    # Check whether all required files exist; the white matter mask is
    # optional and therefore not part of this list
    for f in [args.brainmask, args.data, args.dtifa, args.dtivecs,
              'bvals', 'bvecs', ]:
        filepath = os.path.join(indir, f)
        if not os.path.isfile(filepath):
            msg = 'No such file or directory: "{}"'.format(filepath)
            logging.error(msg)
            raise FileNotFoundError(msg)

    # Load fractional anisotropy
    dti_fa = bd.load(os.path.join(indir, args.dtifa))

    # Load DTI mask
    dti_mask = bd.load(os.path.join(indir, args.brainmask))

    # Load the (optional) white matter segmentation mask. Failing to load it
    # is not fatal, but it is logged instead of being swallowed silently.
    wm_mask_path = os.path.join(indir, args.wmmask)
    try:
        wm_mask = bd.load(wm_mask_path)
    except Exception as err:
        logging.warning('White matter mask "%s" could not be loaded (%s); '
                        'continuing without it.', wm_mask_path, err)
        wm_mask = None

    # Load first eigenvectors of a precalculated diffusion tensor
    dti_vecs = bd.load(os.path.join(indir, args.dtivecs))

    # Load diffusion weighted data
    data = bd.load(os.path.join(indir, args.data))

    # Load bvals and bvecs and initialize a GradientTable object
    bvals, bvecs = read_bvals_bvecs(os.path.join(indir, "bvals"),
                                    os.path.join(indir, "bvecs"))
    gtab = gradient_table(bvals, bvecs)

    logging.info('Input loaded.')

    # Use the user supplied FA threshold (--fawm) rather than a fixed value
    wm_mask = fa_guided_mask(wm_mask, dti_fa, dti_mask,
                             fa_lower_thresh=args.fawm)
    logging.info('Fractional anisotropy based tissue mask created.')

    # Flip sign of x-coordinate if affine determinant is positive and rotate to worldspace
    gtab = fsl_gtab_to_worldspace(gtab, data.affine)
    dti_vecs = fsl_vectors_to_worldspace(dti_vecs)
    logging.info('Rotation to worldspace finished')

    # We need this Meta object for saving later. The compressed file is
    # tried first, then the uncompressed one; a FileNotFoundError from the
    # second attempt propagates unchanged.
    base_filename = _strip_nifti_suffix(os.path.join(indir, args.data))
    try:
        _, _, meta = dwmri.load(base_filename + '.nii.gz')
    except FileNotFoundError:
        _, _, meta = dwmri.load(base_filename + '.nii')

    if args.tissuemasks:
        nib.save(wm_mask, os.path.join(outdir, 'wm_mask.nii.gz'))

    # Reuse an existing response from the output folder, unless only the
    # response is requested: in that case it is always recalculated.
    response_path = os.path.join(outdir, "response.npz")
    if not args.responseonly and os.path.exists(response_path):
        fit = bd.ShResponse.load(response_path)
        logging.info('Existing response functions loaded.')
    else:
        model = bd.ShResponseEstimator(gtab, args.order)
        fit = model.fit(data, dti_vecs, wm_mask, verbose=args.verbose,
                        cpus=args.workers)
        fit.save(response_path)
        if args.responseonly:
            logging.info('Response functions recalculated and saved.')
        else:
            logging.info('Response functions estimated and saved.')

    # Deconvolution if 'responseonly' is not set
    if not args.responseonly:
        out, wmout = fit.fodf(data, pos=args.constraint,
                              mask=dti_mask, kernel=args.kernel,
                              verbose=args.verbose,
                              cpus=args.workers)
        logging.info(
            'Signal deconvolved with single tissue response function.')

        fields.save_tensor(os.path.join(outdir, "odf.nrrd"), out,
                           mask=dti_mask.get_fdata(), meta=meta)
        logging.info('fODFs saved.')

        # Save volumes
        fields.save_scalar(os.path.join(outdir, "wmvolume.nrrd"),
                           wmout, meta)
        logging.info('Volume fractions saved.')

    logging.info('Success!')


# Run the command line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

