#!/usr/bin/python3
# -*- coding: utf-8 -*-

# cython: profile=True
import cython
import random
import logging

import pandas as pd

from bonndit.tracking.tracking_prob import tracking_all
import argparse
import nrrd
import nibabel as nib
import os
import numpy as np
from plyfile import PlyElement, PlyData
import bonndit
from bonndit.utils.io import fsl_gtab_to_worldspace, fsl_flip_sign
import time

# Module-level list. NOTE(review): nothing else in the visible part of this
# script reads or writes it - confirm whether it is still needed.
path_tree = []


def load(key, args, filename, path, logging, func):
	"""Read an input file given by an argument or found in the input folder.

	If ``key`` is present in ``args`` the path stored under that attribute is
	read. Otherwise ``filename`` inside ``path`` is used as fallback.

	Args:
		key: Name of the attribute of ``args`` that may hold a file path.
		args: Parsed arguments; must support ``key in args`` and ``getattr``.
		filename: Default file name that is looked up in ``path``.
		path: Folder that is searched for ``filename``.
		logging: Logger like object; only its ``error`` method is used.
		func: Callable that receives the file path and returns the data.

	Returns:
		Whatever ``func`` returns for the selected file.

	Raises:
		FileNotFoundError: If the selected file does not exist.
		RuntimeError: If ``func`` fails while reading the file. The original
			error is attached as ``__cause__``.
	"""
	if key in args:
		source = getattr(args, key)
		label = key
		if not os.path.exists(source):
			logging.error('-{} flag is set incorrectly'.format(key))
			raise FileNotFoundError(
				'The file {} given for the {} argument does not exist.'.format(source, key))
	else:
		source = os.path.join(path, filename)
		label = filename
		if not os.path.exists(source):
			logging.error(
				'Neither {} is set as argument nor the {} is present in the input folder. Please specify'.format(
					key, filename))
			raise FileNotFoundError(
				'Neither the {} argument is set nor does {} exist.'.format(key, source))
	try:
		return func(source)
	except Exception as err:
		# Only real errors are translated; KeyboardInterrupt and SystemExit
		# are not subclasses of Exception and therefore pass through untouched.
		logging.error(
			'Something went wrong for the file {}. Maybe the data is corrupted or has wrong format.'.format(
				label))
		raise RuntimeError('Could not read {}.'.format(source)) from err


def _build_parser():
	"""Create the argument parser of the probabilistic tracking script.

	Options whose default is ``argparse.SUPPRESS`` are absent from the parsed
	namespace unless the user supplies them; ``main`` relies on this when it
	tests ``'name' in args``.

	Returns:
		argparse.ArgumentParser: Parser with all options registered.
	"""
	parser = argparse.ArgumentParser(
		description='This script performs tracking along a multi vector field as described in '
		'Reducing Model Uncertainty in Crossing Fiber Tractography, Gruen et al. (2021)',
		formatter_class=argparse.RawTextHelpFormatter)

	# All files needed
	input_help = (
		'Inputfolder should contain: \n'
		'\t- lowrank.nrrd \n'
		'\t\tMultidirectionfield, where the first dimension defines the length and the \n'
		'\t\tunit direction of the vector, second dimension defines different directions \n'
		'\t\tand remaining dimensions define the coordinate. \n'
		'\t\tIf the file is named differently, use the --infile argument \n'
		'\t- wmvolume.nrrd \n'
		'\t\tThe white matter mask, which is an output of mtdeconv. \n'
		'\t\tIf the file is named differently, use the --wmvolume argument \n'
		'\t- seedpoints.pts \n'
		'\t\tThe seed point file in world coordinates. First 3 dimensions of row give \n'
		'\t\tworld coordinates. Additionally a initial direction can be set by appending \n'
		'\t\t3 columns to each row denoting the direction in (x,y,z) space. \n'
		'\t\tIf the file is named differently, use the --seedpoints argument. \n'
		'\t- fodf.nrrd \n'
		'\t\tOnly needed if --interpolation is TrilinearFODF (the default). \n'
		'\t\tIf the file is named differently, use the --data argument. \n\n'
		'If --ukf is set to MultiTensor, the input folder should also contain: \n'
		'\t- bvals \n'
		'\t\tA text file which contains the bvals for each gradient direction. \n'
		'\t\tIf the file is named differently, use the --ukf_bvals argument \n'
		'\t- bvecs \n'
		'\t\tA text file which contains all gradient directions in the format Ax3 \n'
		'\t\tIf the file is named differently, use the --ukf_bvecs argument \n'
		'\t- data.nrrd \n'
		'\t\tThe file with the raw data. \n'
		'\t\tIf the file is named differently, use the --ukf_data argument. \n'
		'\t- baseline.nrrd \n'
		'\t\tFile with b0 measurements. \n'
		'\t\tIf the file is named differently, use the --ukf_baseline argument \n\n'
		'If --ukf is set to LowRank, Watson or Bingham, the input folder should also contain: \n'
		'\t- fodf.nrrd \n'
		'\t\tThe fodf.nrrd output from mtdeconv. \n'
		'\t\tIf the file is named differently, use the --ukf_data argument. \n\n'
		'Streamlines are buffered on disk in the file given by --disk_file unless --disk is passed.'
	)
	parser.add_argument('-i', help=input_help, default=argparse.SUPPRESS)
	parser.add_argument(
		'--infile',
		help='5D (4,3,x,y,z) Multivectorfield, where the first dimension gives the length \n'
		'and the direction of the vector, the second dimension denotes different directions',
		default=argparse.SUPPRESS)
	parser.add_argument('--wmvolume', help='WM Mask - output of mtdeconv', default=argparse.SUPPRESS)
	parser.add_argument('--act', help='5tt output of 5ttgen. Will perform act if supplied.', default=argparse.SUPPRESS)
	parser.add_argument(
		'--seedpoints',
		help='Seedspointfile: Each row denotes a seed point, where the first  3 columns give the \n'
		' seed point in (x,y,z). Further 3 additional columns can specified to define a initial \n'
		'direction. Columns should be seperated by whitespace.',
		default=argparse.SUPPRESS)

	# General Tracking Parameters
	parser.add_argument('--wmmin', help='Minimum WM density before tracking stops', default=0.15)
	parser.add_argument(
		'--sw_save',
		help="Only each x step is saved. Default everystep is saved. Reduces memoryconsumption greatly",
		default=1)
	parser.add_argument('-sw', help='Stepwidth for Euler integration', default=0.9)
	parser.add_argument(
		'-o', help='Filename for output file in ply or tck format.', required=True, default=argparse.SUPPRESS)
	parser.add_argument('-mtlength', help='Maximum track steps.', default=300)
	parser.add_argument('--samples', help='Samples per seed.', default=1)
	parser.add_argument('-max_angle', help='Max angle over the last 30 mm of the streamline', default=130)
	parser.add_argument('--var', help='Variance for probabilistic direction selection..', default=6)
	parser.add_argument('--exp', help='Expectation for probabilistic direction selection.', default=3)

	parser.add_argument(
		'--interpolation',
		help='decide between FACT interpolation and Trilinear interpolation.',
		default='TrilinearFODF')
	parser.add_argument(
		'--sigma_1',
		help='Only useful if interpolation is set to TrilinearFODF and dist>0. Controls sigma1 for low-rank approx',
		default='0')
	parser.add_argument(
		'--data',
		help='Only useful if interpolation is set to TrilinearFODF. fODF file to interpolate; \n'
		'defaults to fodf.nrrd in the input folder.',
		default=argparse.SUPPRESS)
	parser.add_argument(
		'--sigma_2',
		help='Only useful if interpolation is set to TrilinearFODF and dist>0. Controls sigma2 for low-rank approx.',
		default='0')
	parser.add_argument(
		'--dist',
		help='Only useful if interpolation is set to TrilinearFODF. Radius of points to include',
		default=0)
	parser.add_argument(
		'--rank', help='Only useful if interpolation is set to TrilinearFODF. Rank of low-rank approx.',
		default='3')

	# NOTE(review): the default below looks copied from --interpolation; it is
	# kept unchanged because the values accepted by the tracker are not visible here.
	parser.add_argument(
		'--integration',
		help='Decide between Euler integration and Euler integration. ',
		default='TrilinearFODF')
	parser.add_argument(
		'-prob', '--prob',
		help='Decide between Laplacian, Gaussian, Scalar, ScalarNew, Deterministic and Deterministic2 ',
		default='ScalarNew')

	# Arguments for saving each streamline on disk.
	# Both switches use store_false: the feature is ON by default and passing the
	# flag turns it OFF. The actions are kept for backward compatibility; the
	# help texts describe the behaviour that is actually implemented.
	parser.add_argument(
		'--disk',
		help='By default streamlines are written to --disk_file instead of using ram. \n'
		'Pass this flag to disable the disk file.',
		action='store_false')
	random_suffix = ''.join(
		random.choice([chr(i) for i in range(ord('a'), ord('z') + 1)]) for _ in range(10))
	parser.add_argument(
		'--disk_file', help='Name of disk file. If not set a random filename is chosen',
		default='/tmp/.stream' + random_suffix)
	parser.add_argument(
		'--disk_delete',
		help='By default the disk file is deleted after finish. Pass this flag to keep it, \n'
		'so that further streamlines can be appended if more streamlines are needed.',
		action='store_false')

	# Arguments for ukf
	parser.add_argument(
		'--ukf',
		help='UKF model: MultiTensor, LowRank, Watson or Bingham. \n'
		'The following --ukf_* arguments are only relevant if this flag is set.',
		default=argparse.SUPPRESS)
	parser.add_argument('--ukf_data', help='File containing the raw data for ukf.', default=argparse.SUPPRESS)
	parser.add_argument(
		'--ukf_bvals', help='File containing the bvals for each gradient direction',
		default=argparse.SUPPRESS)
	parser.add_argument('--ukf_bvecs', help='File containing the bvecs ', default=argparse.SUPPRESS)
	parser.add_argument('--ukf_baseline', help='File containing the baseline', default=argparse.SUPPRESS)
	parser.add_argument('--ukf_fodf_order', help='order of fODF. Only 4 and 8 are supported. Default is 4', default=4)
	parser.add_argument('--ukf_pnoise', help='Process noise', nargs='+', default=argparse.SUPPRESS)
	parser.add_argument('--ukf_mnoise', help='Measurement noise', nargs='+', default=argparse.SUPPRESS)

	# Arguments for Watson fitting
	parser.add_argument(
		'--kappa_range',
		help='For Watson fitting: range of initial kappa values to randomly sample from.', default='39.9,40')
	parser.add_argument(
		'--max_sampling_angle',
		help='For Watson fitting: max angle to watson peak to randomly sample from.', default=90)
	parser.add_argument(
		'--max_kappa',
		help='For Watson fitting: max kappa value to sample from, higher values get clipped', default=80)
	parser.add_argument(
		'--min_kappa',
		help='For Watson fitting: min kappa value to sample from, if lower the tracking is stopped', default=1)
	parser.add_argument(
		'--watson_prob_dir_selection',
		help='If added, in the Watson direction selection the distribution is chosen as in ScalarNew followed by the sampling.',
		action='store_true')

	parser.add_argument('-v', '--verbose', default=True)
	parser.add_argument('-r', '--exclusion', default=argparse.SUPPRESS)
	parser.add_argument('--min_len', default=0)
	parser.add_argument('-ri', '--inclusion', default=argparse.SUPPRESS)
	parser.add_argument('-fsl', default="")
	parser.add_argument('--features', nargs='+', default="")
	return parser


def _affine_from_nrrd(directions, origin):
	"""Assemble a homogeneous 4x4 matrix from nrrd header entries.

	Args:
		directions: Values written into the upper left 3x3 block.
		origin: Values written into the first three rows of the last column.

	Returns:
		numpy.ndarray: 4x4 matrix whose last row is ``[0, 0, 0, 1]``.
	"""
	affine = np.zeros((4, 4))
	affine[:3, :3] = directions
	affine[:3, 3] = origin
	affine[3, 3] = 1
	return affine


def _feature_columns(requested):
	"""Map every supported per-vertex feature to its output column.

	Args:
		requested: Feature names in the order given via ``--features``.

	Returns:
		dict: Column index per feature (``-2`` if it was not requested) and,
		under ``'len'``, the total number of feature columns. ``prob_others``
		occupies three consecutive columns.
	"""
	def column(name, offset=0):
		return requested.index(name) + offset if name in requested else -2

	return {
		'chosen_angle': column('angle'),
		'seedpoint': column('seedpoint'),
		'prob_chosen': column('prob_chosen'),
		'fa': column('fa'),
		'prob_others_0': column('prob_others'),
		'prob_others_1': column('prob_others', 1),
		'prob_others_2': column('prob_others', 2),
		'len': len(requested) + 2 if 'prob_others' in requested else len(requested)
	}


def main():
	"""Run probabilistic tracking from the command line and write the result.

	Parses the command line, loads the multi vector field, the white matter
	mask, the seed points and the optional UKF, ACT and fODF inputs, calls
	``tracking_all`` and stores the streamlines as ``ply`` or ``tck`` file.
	Progress is logged to ``./prob-tracking.log`` and, if verbose, to stderr.

	Raises:
		ValueError: If ``-i`` is missing, the output file has an unsupported
			extension, an input has the wrong shape or format, or ``--ukf``
			names an unknown model.
		FileNotFoundError: If the input folder does not exist.
		Exception: Whatever ``load`` raises for unreadable input files.
	"""
	args = _build_parser().parse_args()
	logging.basicConfig(
		filename=os.path.join("./", 'prob-tracking.log'),
		format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
		datefmt='%y-%m-%d %H:%M',
		level=logging.INFO,
		filemode='w')

	if args.verbose:
		# Mirror INFO messages or higher to sys.stderr with a shorter format.
		console = logging.StreamHandler()
		console.setLevel(logging.INFO)
		console.setFormatter(logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
		logging.getLogger('').addHandler(console)

	# Validate the arguments before any expensive work is started.
	if 'i' not in args:
		logging.error('-i has to be set. ')
		raise ValueError('-i has to be set.')
	if not os.path.exists(args.i):
		logging.error('Path to input folder is not correct.')
		raise FileNotFoundError('Input folder {} does not exist.'.format(args.i))
	if not args.o.endswith(('ply', 'tck')):
		# Without this check nothing would be written after the whole tracking.
		logging.error('Output file has wrong format. Only ply and tck are accepted.')
		raise ValueError('Unsupported output file {}. Only ply and tck are accepted.'.format(args.o))

	# Multi vector field
	vector_field, meta = load('infile', args, 'lowrank.nrrd', args.i, logging, nrrd.read)
	if len(vector_field.shape) != 5:
		logging.error("The input multivector field has to have 5 dimensions.")
		raise ValueError('The input multivector field has to have 5 dimensions.')
	if vector_field.shape[0] != 4:
		logging.error("Wrong dimension on first axis. Has to contain 4 values.")
		raise ValueError('Wrong dimension on first axis. Has to contain 4 values.')
	trafo_data = _affine_from_nrrd(meta['space directions'][-3:], meta['space origin'][-3:])
	logging.info("Multivectorfield loaded")

	# White matter mask, either nrrd (default) or nifti
	wm_path = getattr(args, 'wmvolume', None)
	if wm_path is None or wm_path.endswith('nrrd'):
		wm_mask, wm_header = load('wmvolume', args, 'wmvolume.nrrd', args.i, logging, nrrd.read)
		trafo_mask = _affine_from_nrrd(wm_header['space directions'], wm_header['space origin'])
	elif wm_path.endswith('nii.gz'):
		wm_image = load('wmvolume', args, 'wmvolume.nrrd', args.i, logging, nib.load)
		trafo_mask = wm_image.affine
		wm_mask = wm_image.get_fdata()
	else:
		logging.error("WMvolume has wrong format. Only nii.gz and nrrd are accepted.")
		raise ValueError('WMvolume has wrong format. Only nii.gz and nrrd are accepted.')
	if vector_field.shape[-3:] != wm_mask.shape:
		logging.error("Vectorfield (x,y,z) and wm mask have to have same dimensions.")
		raise ValueError('Vectorfield (x,y,z) and wm mask have to have same dimensions.')
	logging.info("WM Mask loaded")

	# Seed points. --seedpoints is absent from args when not given, hence getattr.
	if getattr(args, 'seedpoints', None) == "test_CC":
		dir_path = os.path.dirname(bonndit.__file__)
		seeds = np.loadtxt(os.path.join(dir_path, "data/CC.pts"))
	else:
		seeds = load('seedpoints', args, 'seedpoints.pts', args.i, logging, np.loadtxt)
	logging.info("Seedfile loaded")

	# Optional UKF inputs
	if 'ukf' in args:
		logging.info("UKF Flag is set. Continue with loading UKF data")
		if args.ukf == 'MultiTensor':
			data, meta2 = load('ukf_data', args, 'data.nrrd', args.i, logging, nrrd.read)
			data = np.moveaxis(data, 0, -1)
			logging.info("raw data loaded")
			bvals = load('ukf_bvals', args, 'bvals', args.i, logging, np.loadtxt)
			logging.info("bvals loaded")
			bvecs = load('ukf_bvecs', args, 'bvecs', args.i, logging, np.loadtxt)
			bvecs = fsl_flip_sign(bvecs.T, meta2['space directions'][-3:, -3:]).T
			logging.info("bvecs loaded")
			baseline, _ = load('ukf_baseline', args, 'baseline.nrrd', args.i, logging, nrrd.read)
			logging.info("b0 loaded")
		elif args.ukf in ['LowRank', 'Watson', 'Bingham']:
			data, _ = load('ukf_data', args, 'fodf.nrrd', args.i, logging, nrrd.read)
			data = np.moveaxis(data, 0, -1)
			# Drop the first entry of the (moved) first nrrd axis.
			data = data[:, :, :, 1:]
			logging.info("raw data loaded")
			bvals = None
			bvecs = None
			baseline = None
		else:
			# Stop here; otherwise the code below fails on the undefined data.
			logging.error('Only MultiTensor, LowRank, Watson and Bingham are possible options for the ukf flag.')
			raise ValueError('Unknown value {} for the ukf flag.'.format(args.ukf))

	if args.disk:
		logging.info(
			'Streamlines are temporarily written to {} which is {} after creating streamines'.format(
				args.disk_file, 'kept' if not args.disk_delete else 'deleted'))

	if os.path.exists(args.fsl):
		trafo_fsl = np.loadtxt(args.fsl)
	else:
		if args.fsl:
			logging.warning('-fsl file {} does not exist. Using the identity instead.'.format(args.fsl))
		trafo_fsl = np.identity(4)

	if "act" in args:
		act = load('act', args, 'act.nii.gz', args.i, logging, nib.load)
		# The affine of the act image replaces the one of the wm mask.
		trafo_mask = act.affine
	else:
		act = None

	tracking_parameters = {
		'prob': args.prob,
		'trafo_data': trafo_data,
		'min_len': int(args.min_len),
		'space directions': meta['space directions'][-3:],
		'space origin': meta['space origin'],
		'trafo_fsl': trafo_fsl,
		'ukf': args.ukf if 'ukf' in args else False,
		'variance': float(args.var),
		'expectation': float(args.exp),
		'trafo_mask': trafo_mask,
		'stepsize': float(args.sw),
		'max_track_length': int(args.mtlength),
		'integration': args.integration,
		'interpolation': args.interpolation,
		'max_angle': float(args.max_angle),
		'samples': int(args.samples),
		'wmmin': float(args.wmmin),
		'wm_mask': np.float64(wm_mask),
		'act': act.get_fdata() if act is not None else None,
		'verbose': args.verbose,
		'sw_save': int(args.sw_save)
	}
	postprocessing = {
		'inclusion': np.loadtxt(args.inclusion) if "inclusion" in args else "",
		'exclusion': np.loadtxt(args.exclusion) if "exclusion" in args else "",
	}

	if args.interpolation == 'TrilinearFODF':
		trilinear_parameters = {
			'data': load('data', args, 'fodf.nrrd', args.i, logging, nrrd.read)[0],
			'trafo_data': trafo_data,
			'sigma_1': float(args.sigma_1),
			'sigma_2': float(args.sigma_2),
			'trafo': meta['space directions'][-3:],
			'r': float(args.dist),
			'rank': int(args.rank)
		}
	else:
		trilinear_parameters = {}

	ukf_parameter = {}
	if "ukf" in args:
		ukf_parameter = {
			'trafo_data': trafo_data,
			'data': np.array(data, dtype=np.float64),
			'dim_model': vector_field.shape[1] * 5 if args.ukf == "MultiTensor" else vector_field.shape[1] * 4,
			'model': 'fodf' if args.ukf == "LowRank" else 'MultiTensor',
			'gradients': bvecs.astype(np.float64).T if bvecs is not None else "",
			'baseline': baseline,
			'order': int(args.ukf_fodf_order),
			'b': bvals,
			'b0': 3000,
			'process noise': np.array(args.ukf_pnoise, dtype=np.float64) if 'ukf_pnoise' in args else "",
			'measurement noise': np.array(args.ukf_mnoise, dtype=np.float64) if 'ukf_mnoise' in args else ""
		}
		if args.ukf == "MultiTensor":
			# Accept the gradients in either orientation, but they must match the data.
			if ukf_parameter['data'].shape[-1] != ukf_parameter['gradients'].shape[0]:
				if ukf_parameter['data'].shape[-1] != ukf_parameter['gradients'].T.shape[1]:
					logging.error('Data has to have the same dimension in the last column of bvecs')
					raise ValueError('Data has to have the same dimension in the last column of bvecs')
				ukf_parameter['gradients'] = ukf_parameter['gradients'].T
			if ukf_parameter['gradients'].shape[0] != ukf_parameter['b'].shape[0]:
				logging.error('bvals have to have the same dimension in the last column of bvecs')
				raise ValueError('bvals have to have the same dimension in the last column of bvecs')

	saving = {
		'features': _feature_columns(args.features),
		'file': args.disk_file if args.disk else "",
	}

	logging.info("Start fiber tracking")
	paths, paths_len = tracking_all(
		np.float64(vector_field), np.float64(wm_mask), np.float64(seeds),
		tracking_parameters, postprocessing, ukf_parameter, trilinear_parameters, logging, saving)
	# Free the large inputs before the results are collected.
	del tracking_parameters, postprocessing, ukf_parameter, trilinear_parameters, vector_field, wm_mask, seeds, act

	# args.disk always exists (store_false sets it), so its value has to be
	# tested; the disk file is only written when it is True.
	if args.disk:
		paths_len = pd.read_csv(args.disk_file + 'len', header=None, dtype=int).values.flatten()
		paths = pd.read_csv(args.disk_file, header=None, delimiter=" ").values

		if args.disk_delete:
			os.remove(args.disk_file)
			os.remove(args.disk_file + 'len')
	paths_len = np.cumsum(paths_len)

	# Column names of the requested features, ordered by their column index.
	features = [0] * saving['features']['len']
	for name, column in saving['features'].items():
		if column >= 0 and name != 'len':
			features[column] = name

	if args.o.endswith('ply'):
		vertex_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] + [(x, 'f4') for x in features]
		vertices = np.array([tuple(x) for x in paths], dtype=vertex_dtype)
		tracks = PlyElement.describe(vertices, 'vertices', comments=[])
		endindex = PlyElement.describe(np.array(paths_len, dtype=[('endindex', 'i4')]), 'fiber')
		PlyData([tracks, endindex]).write(args.o)
	elif args.o.endswith('tck'):
		# Only the coordinates are stored in tck files, features are dropped.
		paths = np.array(paths)[:, :3]
		streamlines = nib.streamlines.ArraySequence(np.split(paths, paths_len))
		mytractogram = nib.streamlines.tractogram.Tractogram(streamlines, affine_to_rasmm=np.identity(4))
		tractogram = nib.streamlines.tck.TckFile(mytractogram)
		tractogram.save(args.o)
	logging.info(f"Output file has been written to {args.o}")

# Run the command line interface only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
	main()
