#!/usr/bin/python
# -*- coding: utf-8 -*-

import argparse
import logging
import os

# Restrict OpenMP to a single thread. This must be set before numpy is
# imported; numpy is not imported directly here, presumably it is pulled in
# by the third-party imports below.
# NOTE(review): presumably this avoids CPU oversubscription when the fit is
# distributed over several worker processes (--workers) - confirm.
os.environ["OMP_NUM_THREADS"] = "1"

import nibabel as nib
from dipy.core.gradients import gradient_table
from dipy.io import read_bvals_bvecs

import bonndit as bd
from bonndit.utils.io import fsl_gtab_to_worldspace
from bonndit.utils import dwmri


def _strip_nifti_suffix(path):
    """Return *path* without a trailing ``.nii.gz``, ``.nii`` or ``.gz``.

    Only the literal suffixes are removed. ``str.rstrip`` must not be used
    for this: it strips a *set of characters*, so ``"brain.nii.gz"`` would be
    truncated to ``"bra"`` instead of ``"brain"``.
    """
    for suffix in ('.gz', '.nii'):
        if path.endswith(suffix):
            path = path[:-len(suffix)]
    return path


def _build_parser():
    """Create the argument parser of the kurtosis command line interface."""
    parser = argparse.ArgumentParser(
        description='This script fits the kurtosis model and computes a number of measures based on it.',
        add_help=False)

    parser.add_argument('indir',
                        help='Folder containing all required input files')

    parser.add_argument('-o', '--outdir',
                        help='Folder in which the output will be saved (default: same as indir)')

    inputfiles = parser.add_argument_group(
        'Custom input filenames (optional)',
        'It is not recommended to specify custom names for input files.')
    inputfiles.add_argument('-d', '--data', default='data.nii.gz',
                            help='Diffusion weighted data (default: data.nii.gz)')
    inputfiles.add_argument('-m', '--brainmask', default='mask.nii.gz',
                            help='Brain mask (default: mask.nii.gz)')

    options = parser.add_argument_group('Options (optional)', '')
    options.add_argument('-t', '--b0thresh', default=0, type=int,
                         help='Every b-value below this threshold is treated as 0')

    # add_help=False above, so the help flag is registered manually in order
    # to list it inside the "flags" group.
    flags = parser.add_argument_group('flags (optional)', '')
    flags.add_argument("-h", "--help", action="help",
                       help="Show this help message and exit")
    flags.add_argument('-v', '--verbose', action='store_true',
                       help='Activate progress bars and console logging')
    flags.add_argument('-R', '--fitonly', action='store_true',
                       help='Calculate and save only the kurtosis_fit functions')

    workers_group = parser.add_argument_group(
        'multiprocessing (optional)',
        'Configure the multiprocessing behaviour')
    workers_group.add_argument('-w', '--workers', default=None, type=int,
                               help='Number of cpus (default: all available cpus)')

    log = parser.add_argument_group('logging (optional)',
                                    'Configure the logging behaviour')
    log.add_argument('-L', '--loglevel', choices=['INFO', 'WARNING', 'ERROR'],
                     default='INFO',
                     help='Specify the logging level for the console')
    return parser


def main():
    """Fit the kurtosis model to a DWI dataset and save derived measures.

    Command line entry point. Reads ``bvals``, ``bvecs``, the diffusion
    weighted data and the brain mask from ``indir``. An existing
    ``kurtosis_fit.npz`` in ``outdir`` is re-used unless ``--fitonly`` is
    given, in which case the fit is always recomputed and nothing else is
    written. Without ``--fitonly`` the maps ``da``, ``dr``, ``dm``, ``ka``,
    ``kr``, ``km`` and ``fa`` are saved as ``.nii`` files in ``outdir``.
    A log file ``kurtosis.log`` is written to ``outdir``.

    Raises:
        FileNotFoundError: if ``bvals`` or ``bvecs`` is missing in ``indir``,
            or if ``dwmri.load`` finds neither ``<data>.nii.gz`` nor
            ``<data>.nii``.
        ValueError: if the b0 threshold is smaller than the minimum b-value.
    """
    args = _build_parser().parse_args()

    # Fall back to the input folder when no output folder is given
    indir = args.indir
    outdir = args.outdir or indir

    # exist_ok avoids the race between an existence check and the creation
    os.makedirs(outdir, exist_ok=True)

    levels = {'INFO': logging.INFO,
              'WARNING': logging.WARNING,
              'ERROR': logging.ERROR}

    # Logging setup for file
    logging.basicConfig(filename=os.path.join(outdir, 'kurtosis.log'),
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%y-%m-%d %H:%M',
                        level=levels[args.loglevel],
                        filemode='w')

    # Console logging if verbose flag is set
    if args.verbose:
        # handler writing messages of the selected level or higher to stderr
        console = logging.StreamHandler()
        console.setLevel(levels[args.loglevel])
        # use a format which is simpler for console use
        console.setFormatter(logging.Formatter(
            '%(name)-12s: %(levelname)-8s %(message)s'))
        # add the handler to the root logger
        logging.getLogger('').addHandler(console)

    # Raise error and log if bvals/bvecs not found. This is done inside bd.load
    # for data and brainmask
    for f in ['bvals', 'bvecs']:
        filepath = os.path.join(indir, f)
        if not os.path.isfile(filepath):
            msg = 'No such file or directory: "{}"'.format(filepath)
            logging.error(msg)
            raise FileNotFoundError(msg)

    # Load mask
    mask = bd.load(os.path.join(indir, args.brainmask)).get_fdata()

    # Load diffusion weighted data
    data = bd.load(os.path.join(indir, args.data))

    # Load bvals and bvecs and initialize a GradientTable object
    bvals, bvecs = read_bvals_bvecs(os.path.join(indir, "bvals"),
                                    os.path.join(indir, "bvecs"))
    gtab = gradient_table(bvals, bvecs, b0_threshold=args.b0thresh)
    min_bval = min(gtab.bvals)
    if gtab.b0_threshold < min_bval:
        msg = "The specified b0 threshold is {}. The minimum b-value in " \
              "the gradient table is {}. Please specify an appropriate " \
              "b0 threshold.".format(gtab.b0_threshold, min_bval)
        # Log before raising, like the bvals/bvecs check above
        logging.error(msg)
        raise ValueError(msg)

    logging.info('Input loaded.')

    # Flip sign of x-coordinate if affine determinant is positive and rotate to worldspace
    gtab = fsl_gtab_to_worldspace(gtab, data.affine)
    logging.info('Rotation to worldspace finished')

    # Try the compressed file first, then the uncompressed one. The suffix
    # has to be removed literally; see _strip_nifti_suffix.
    # NOTE(review): `meta` is not used anywhere below in this function -
    # confirm whether this second load of the data can be dropped.
    base_filename = _strip_nifti_suffix(os.path.join(indir, args.data))
    try:
        _, _, meta = dwmri.load(base_filename + '.nii.gz')
    except FileNotFoundError:
        _, _, meta = dwmri.load(base_filename + '.nii')

    # Re-use an existing fit unless 'fitonly' forces a recalculation
    fit_path = os.path.join(outdir, "kurtosis_fit.npz")
    if not args.fitonly and os.path.exists(fit_path):
        fit = bd.DkiFit.load(fit_path)
        logging.info('Existing kurtosis fits loaded.')
    else:
        model = bd.DkiModel(gtab)
        fit = model.fit(data.get_fdata(), mask, verbose=args.verbose,
                        cpus=args.workers, desc='Estimate DKI tensors')
        fit.save(fit_path)
        if args.fitonly:
            logging.info('Kurtosis fit recalculated and saved.')
        else:
            logging.info('Kurtosis fit estimated and saved.')

    # Calculate measures if 'fitonly' is not set
    if not args.fitonly:
        logging.info('Calculating kurtosis measures. This may take a while.')
        # (attribute of the fit object, output filename)
        measures = (
            ('diffusivity_axial', "da.nii"),
            ('diffusivity_radial', "dr.nii"),
            ('diffusivity_mean', "dm.nii"),
            ('kurtosis_axial', "ka.nii"),
            ('kurtosis_radial', "kr.nii"),
            ('kurtosis_mean', "km.nii"),
            ('fractional_anisotropy', "fa.nii"),
        )
        for attribute, filename in measures:
            image = nib.Nifti1Image(getattr(fit, attribute),
                                    affine=data.affine)
            nib.save(image, os.path.join(outdir, filename))

    logging.info('Success!')


# Run the command line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
