#!/usr/bin/python
# -*- coding: utf-8 -*-

import argparse
import logging
import os

# Set OMP_NUM_THREADS to 1 before numpy gets imported (indirectly, through the
# nibabel/dipy/bonndit imports below); it has to be in the environment before
# those imports run, so this line must stay above them.
# NOTE(review): presumably this avoids thread oversubscription when the
# fit/deconvolution spreads work over several worker processes (--workers);
# confirm.
os.environ["OMP_NUM_THREADS"] = "1"

import nibabel as nib
from math import pi
from dipy.core.gradients import gradient_table
from dipy.io import read_bvals_bvecs

import bonndit as bd
from bonndit.deconv.shoredeconv import fa_guided_mask
from bonndit.utils.io import fsl_vectors_to_worldspace, fsl_gtab_to_worldspace, metadata_checker
from bonndit.utils import dwmri, fields

def _str2bool(value):
	"""Parse a command-line boolean such as "true"/"false"/"yes"/"no"/"1"/"0".

	Used as the ``type`` of ``--verbose``: without it, any value given on the
	command line (even "False") is a non-empty, truthy string, so console
	logging and progress bars could never be switched off.

	Raises:
		argparse.ArgumentTypeError: if *value* is not a recognised boolean.
	"""
	if isinstance(value, bool):
		return value
	normalized = value.strip().lower()
	if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
		return True
	if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
		return False
	raise argparse.ArgumentTypeError(
		'expected a boolean value (true/false), got "{}"'.format(value))


def _strip_nifti_ext(path):
	"""Return *path* without a trailing ``.gz`` and then ``.nii`` suffix.

	Whole suffixes are removed. ``str.rstrip`` would instead strip any trailing
	characters from the given set and turn e.g. ``dwi.nii.gz`` into ``dw``.
	"""
	for suffix in ('.gz', '.nii'):
		if path.endswith(suffix):
			path = path[:-len(suffix)]
	return path


def _require_files(indir, filenames):
	"""Check that every name in *filenames* is an existing file inside *indir*.

	Raises:
		FileNotFoundError: for the first missing file (the message is also
			logged as an error).
	"""
	for name in filenames:
		filepath = os.path.join(indir, name)
		if not os.path.isfile(filepath):
			msg = 'No such file or directory: "{}"'.format(filepath)
			logging.error(msg)
			raise FileNotFoundError(msg)


def _load_meta(base_filename):
	"""Return the ``meta`` object that ``dwmri.load`` reports for the data file.

	``base_filename + '.nii.gz'`` is tried first. If that raises
	FileNotFoundError, ``base_filename + '.nii'`` is tried, and any error from
	the second attempt propagates to the caller.
	"""
	try:
		_, _, meta = dwmri.load(base_filename + '.nii.gz')
	except FileNotFoundError:
		_, _, meta = dwmri.load(base_filename + '.nii')
	return meta


def _build_parser():
	"""Create the argparse parser holding all mtdeconv command-line options."""
	parser = argparse.ArgumentParser(
		description='This script computes fiber orientation distribution '
			'functions (fODFs) as described in "Versatile, Robust and '
			'Efficient Tractography With Constrained Higher Order '
			'Tensor fODFs" by Ankele et al. (2017). It is assumed that'
			' the input data is saved using FSL.', add_help=False)

	parser.add_argument('indir',
		help='Folder containing all required input files')
	parser.add_argument('-o', '--outdir',
		help='Folder in which the output will be saved (default: same as indir)')

	inputfiles = parser.add_argument_group(
		'Custom input filenames',
		'It is not recommended to specify custom names for input files.')
	inputfiles.add_argument('-d', '--data', default='data.nii.gz',
		help='Diffusion weighted data (default: data.nii.gz)')
	inputfiles.add_argument('-e', '--dtivecs', default='dti_V1.nii.gz',
		help='First eigenvectors of a DTI model (default: dti_V1.nii.gz)')
	inputfiles.add_argument('-a', '--dtifa', default='dti_FA.nii.gz',
		help='Fractional anisotropy values from a DTI model (default: dti_FA.nii.gz)')
	inputfiles.add_argument('-m', '--brainmask', default='nodif_brain_mask.nii.gz',
		help='Brain mask (default: nodif_brain_mask.nii.gz)')
	inputfiles.add_argument('-Ma', '--masks', default='fast_first.nii.gz',
		help='Tissue masks containing white matter, gray matter and csf mask. (default: fast_first.nii.gz)')
	inputfiles.add_argument('-W', '--wmmask', default='fast_pve_2.nii.gz',
		help='White matter mask (default: fast_pve_2.nii.gz)')
	inputfiles.add_argument('-G', '--gmmask', default='fast_pve_1.nii.gz',
		help='Gray matter mask (default: fast_pve_1.nii.gz)')
	inputfiles.add_argument('-F', '--csfmask', default='fast_pve_0.nii.gz',
		help='Cerebrospinal fluid mask (default: fast_pve_0.nii.gz)')
	inputfiles.add_argument('-E', '--response', default='response.npz',
		help='Precalculated response function. If it does not exist, it will '
			'be calculated and saved under this name. (default: response.npz)')

	flags = parser.add_argument_group('flags (optional)', '')
	flags.add_argument('-h', '--help', action='help',
		help='Show this help message and exit')
	# nargs='?' keeps "-v <bool>" working and also accepts a bare "-v".
	flags.add_argument('-v', '--verbose', type=_str2bool, nargs='?',
		const=True, default=True,
		help='Activate progress bars and console logging; pass "-v false" '
			'to disable them (default: true)')
	flags.add_argument('-R', '--responseonly', action='store_true',
		help='Calculate and save only the response functions')
	flags.add_argument('-M', '--tissuemasks', action='store_true',
		help='Output the DTI improved tissue masks (csf/gm/wm)')

	shoreopts = parser.add_argument_group(
		'shore options (optional)',
		'Optional arguments for the computation of the shore response functions')
	shoreopts.add_argument('-k', '--kernel', choices=["rank1", "delta"],
		default="rank1", type=str,
		help='Kernel type (default: rank1)')
	shoreopts.add_argument('-r', '--order', default=4, type=int,
		help='Order of the shore basis (default: 4)')
	shoreopts.add_argument('-z', '--zeta', default=700, type=float,
		help='Radial scaling factor (default: 700)')
	shoreopts.add_argument('-t', '--tau', default=1 / (4 * pi ** 2), type=float,
		help='q-scaling (default: 1 / (4 * math.pi ** 2))')
	shoreopts.add_argument('-f', '--fawm', default=0.7, type=float,
		help='White matter fractional anisotropy threshold (default: 0.7)')

	deconvopts = parser.add_argument_group('deconvolution options (optional)', '')
	deconvopts.add_argument('-C', '--constraint', choices=['hpsd', 'nonneg', 'none'],
		default='hpsd',
		help='Constraint for the fODFs (default: hpsd)')

	# Named mp_opts so the stdlib "multiprocessing" module name is not shadowed.
	mp_opts = parser.add_argument_group(
		'multiprocessing (optional)',
		'Configure the multiprocessing behaviour (only supported for Python 3)')
	mp_opts.add_argument('-w', '--workers', default=None, type=int,
		help='Number of cpus (default: all available cpus)')

	log = parser.add_argument_group('logging (optional)', 'Configure the logging behaviour')
	log.add_argument('-L', '--loglevel', choices=['INFO', 'WARNING', 'ERROR'],
		default='INFO',
		help='Specify the logging level for the console')
	return parser


def _setup_logging(outdir, loglevel, verbose):
	"""Log to ``<outdir>/mtdeconv.log`` and, if *verbose*, also to the console."""
	levels = {'INFO': logging.INFO,
		'WARNING': logging.WARNING,
		'ERROR': logging.ERROR}
	level = levels[loglevel]

	# The log file is overwritten on every run (filemode='w').
	logging.basicConfig(filename=os.path.join(outdir, 'mtdeconv.log'),
		format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
		datefmt='%y-%m-%d %H:%M',
		level=level,
		filemode='w')

	if verbose:
		# StreamHandler writes to sys.stderr by default; use a shorter format there.
		console = logging.StreamHandler()
		console.setLevel(level)
		console.setFormatter(logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
		logging.getLogger('').addHandler(console)


def _estimate_response(args, indir, outdir, response_path):
	"""Load the inputs, build FA-guided tissue masks and fit the response functions.

	The fitted response is saved to *response_path*, the same path the
	reuse check in main() looks at, so a later run can pick it up.

	Returns:
		tuple: ``(fit, data, dti_mask)``, i.e. the fitted response, the diffusion
		weighted image and the metadata-checked brain mask.
	"""
	# Prefer the combined fast+first segmentation; otherwise fall back to the
	# individual fast partial volume maps.
	required = [args.brainmask, args.dtifa, args.dtivecs, args.data, 'bvals', 'bvecs']
	use_fast_first = os.path.isfile(os.path.join(indir, args.masks))
	if not use_fast_first:
		required += [args.csfmask, args.gmmask, args.wmmask]
	_require_files(indir, required)

	dti_fa = bd.load(os.path.join(indir, args.dtifa))
	dti_mask = bd.load(os.path.join(indir, args.brainmask))

	if use_fast_first:
		masks = bd.load(os.path.join(indir, args.masks))
		mask_data = masks.get_fdata()
		# 4th-axis volumes 3/0/2 are used as csf/gm/wm, as in the original layout
		# of the fast+first file.
		csf_mask = nib.Nifti1Image(mask_data[:, :, :, 3], masks.affine)
		gm_mask = nib.Nifti1Image(mask_data[:, :, :, 0], masks.affine)
		wm_mask = nib.Nifti1Image(mask_data[:, :, :, 2], masks.affine)
	else:
		csf_mask = bd.load(os.path.join(indir, args.csfmask))
		gm_mask = bd.load(os.path.join(indir, args.gmmask))
		wm_mask = bd.load(os.path.join(indir, args.wmmask))

	# First eigenvectors of a precalculated diffusion tensor
	dti_vecs = bd.load(os.path.join(indir, args.dtivecs))
	data = bd.load(os.path.join(indir, args.data))

	# Check that the metadata of all inputs matches the metadata of the data
	dti_fa, dti_mask, csf_mask, gm_mask, wm_mask, dti_vecs = metadata_checker(
		data.affine, dti_fa, dti_mask, csf_mask, gm_mask, wm_mask, dti_vecs)

	bvals, bvecs = read_bvals_bvecs(os.path.join(indir, "bvals"),
		os.path.join(indir, "bvecs"))
	gtab = gradient_table(bvals, bvecs)
	logging.info('Input loaded.')

	# FA-guided tissue masks: WM gets a lower FA bound (--fawm), GM and CSF an
	# upper FA bound of 0.2; see fa_guided_mask for the exact selection rule.
	wm_mask = fa_guided_mask(wm_mask, dti_fa, dti_mask,
		tissue_threshold=0.95,
		fa_lower_thresh=args.fawm)
	gm_mask = fa_guided_mask(gm_mask, dti_fa, dti_mask,
		tissue_threshold=0.95,
		fa_upper_thresh=0.2)
	csf_mask = fa_guided_mask(csf_mask, dti_fa, dti_mask,
		tissue_threshold=0.95,
		fa_upper_thresh=0.2)

	if args.tissuemasks:
		nib.save(wm_mask, os.path.join(outdir, 'wm_mask.nii.gz'))
		nib.save(gm_mask, os.path.join(outdir, 'gm_mask.nii.gz'))
		nib.save(csf_mask, os.path.join(outdir, 'csf_mask.nii.gz'))

	logging.info('Fractional anisotropy based tissue masks created.')

	# Flip sign of x-coordinate if affine determinant is positive and rotate to worldspace
	gtab = fsl_gtab_to_worldspace(gtab, data.affine)
	dti_vecs = fsl_vectors_to_worldspace(dti_vecs)
	logging.info('Rotation to worldspace finished')

	model = bd.ShoreMultiTissueResponseEstimator(gtab, args.order,
		args.zeta, args.tau)
	fit = model.fit(data, dti_vecs, wm_mask, gm_mask, csf_mask,
		verbose=args.verbose, cpus=args.workers)
	fit.save(response_path)
	logging.info('Response functions estimated and saved.')
	return fit, data, dti_mask


def _deconvolve(args, outdir, fit, data, dti_mask):
	"""Deconvolve *data* with *fit* and write fODFs, residuals and volumes to *outdir*."""
	# The dwmri Meta object is needed to write the nrrd outputs. It is loaded
	# before the deconvolution so that a missing file fails early.
	meta = _load_meta(_strip_nifti_ext(os.path.join(args.indir, args.data)))

	out, wmout, gmout, csfout, residuals = fit.fodf(data, pos=args.constraint,
		mask=dti_mask, kernel=args.kernel,
		verbose=args.verbose,
		cpus=args.workers)
	logging.info('Signal deconvolved with multiple tissue response functions.')

	residuals = nib.Nifti1Image(residuals, affine=data.affine)
	nib.save(residuals, os.path.join(outdir, 'residuals.nii.gz'))
	fields.save_tensor(os.path.join(outdir, "fodf.nrrd"), out,
		mask=dti_mask.get_fdata(), meta=meta)
	logging.info('fODFs saved.')

	# Save volume fractions
	volumes = (("wmvolume.nrrd", wmout),
		("gmvolume.nrrd", gmout),
		("csfvolume.nrrd", csfout))
	for filename, volume in volumes:
		fields.save_scalar(os.path.join(outdir, filename), volume, meta)
	logging.info('Volume fractions saved.')


def main():
	"""Command-line entry point of mtdeconv.

	Parses the arguments and sets up logging. It then reuses
	``<outdir>/<--response>`` when that file exists and deconvolution is
	requested; otherwise it estimates the multi-tissue SHORE response
	functions and saves them there. Unless ``--responseonly`` is set, it
	finally computes and saves the fODFs, residuals and tissue volume
	fractions.

	Raises:
		FileNotFoundError: if a required input file is missing.
	"""
	args = _build_parser().parse_args()

	indir = args.indir
	outdir = args.outdir or indir
	# Create outdir if it does not exist yet
	os.makedirs(outdir, exist_ok=True)

	_setup_logging(outdir, args.loglevel, args.verbose)

	logging.info('mtdeconv has been called with:')
	param_string = ('Order: {} Zeta: {}, Tau: {}, FAWM: {}, Constraint: {}, '
		'Kernel {}')
	logging.info(param_string.format(args.order, args.zeta, args.tau,
		args.fawm, args.constraint, args.kernel))

	response_path = os.path.join(outdir, args.response)
	# An existing response is only reused when deconvolution is wanted;
	# --responseonly always (re)computes it.
	if not args.responseonly and os.path.exists(response_path):
		# Gradients are taken from the saved response; bvals/bvecs must still exist.
		_require_files(indir, [args.brainmask, args.data, 'bvals', 'bvecs'])

		fit = bd.ShoreMultiTissueResponse.load(response_path)
		logging.info('Existing response functions loaded.')

		dti_mask = bd.load(os.path.join(indir, args.brainmask))
		data = bd.load(os.path.join(indir, args.data))
		logging.info('Input loaded.')
	else:
		fit, data, dti_mask = _estimate_response(args, indir, outdir, response_path)

	if not args.responseonly:
		_deconvolve(args, outdir, fit, data, dti_mask)

	logging.info('Success!')


if __name__ == '__main__':
	# Run the command-line tool only when this file is executed as a script.
	main()
