#!/usr/bin/python
# -*- coding: utf-8 -*-

import logging
import argparse
import os
import nrrd
import numpy as np
from bonndit.pmodels.model_avg import model_avg


def _str2bool(value):
	"""Convert a command-line token into a real boolean.

	argparse hands option values over as strings and every non-empty string
	(including '0' and 'False') is truthy, so boolean options need explicit
	parsing to be switchable from the command line.

	Args:
		value: The raw option value (a string, or an already-parsed bool).

	Returns:
		True or False.

	Raises:
		argparse.ArgumentTypeError: If the token is not a recognised boolean.
	"""
	if isinstance(value, bool):
		return value
	token = str(value).strip().lower()
	if token in ('1', 'true', 't', 'yes', 'y', 'on'):
		return True
	if token in ('0', 'false', 'f', 'no', 'n', 'off'):
		return False
	raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
	"""Command-line entry point for model selection / averaging.

	Parses the command line, loads the fODF volume and three low-rank
	approximation files with ``nrrd.read``, hands them to ``model_avg`` and
	writes the result (and, on request, the per-model probabilities) with
	``nrrd.write``. A log file ``peak-modelling.log`` is (re)created in the
	current working directory on every run.

	Raises:
		SystemExit: On invalid command-line arguments (raised by argparse).
		ValueError: If the fODF volume is not 4D with 16 entries on its first
			axis, or if the three approximation files do not have exactly
			1, 2 and 3 entries on their second axis.
	"""
	parser = argparse.ArgumentParser(
		description='This script performs a model selection or averaging according to '
					'Reducing Model Uncertainty in Crossing Fiber Tractography, Gruen et al. (2021)',
		formatter_class=argparse.ArgumentDefaultsHelpFormatter)

	parser.add_argument('-f', '--fodf', required=True,
						help='4D input file containing fODFs in masked higher-order tensor format (1+#fODF coefficients,x,y,z)', default=argparse.SUPPRESS)
	parser.add_argument('-i', '--infile', nargs='+', required=True,
						help='Three infiles containing low rank approx of rank 1 2 3', default=argparse.SUPPRESS)
	parser.add_argument('-t', '--type',
						help='selection or averaging', default='averaging')
	# Boolean options: parsed explicitly so that "-s 0" / "-v False" really
	# disable the feature; giving the flag without a value enables it.
	parser.add_argument('-s', '--save_prob', type=_str2bool, nargs='?', const=True,
						help='save probabilities', default=False)
	parser.add_argument('-o', '--outfile', required=True,
						help='5D output file with the approximation result (4,r,x,y,z)', default=argparse.SUPPRESS)
	parser.add_argument('-a', '--a', type=float, default=1.0,
						help='Parameter for Distribution')
	parser.add_argument('-b', '--b', type=float, default=20.0,
						help='Parameter for Distribution')
	parser.add_argument('-v', '--verbose', type=_str2bool, nargs='?', const=True,
						help='also print log messages to the console', default=True)

	args = parser.parse_args()
	# Exactly three approximation files are indexed below; fail early with a
	# usage message instead of an IndexError halfway through the run.
	if len(args.infile) < 3:
		parser.error(f"-i/--infile needs three files (rank 1, 2 and 3), got {len(args.infile)}")
	print(args)

	logging.basicConfig(filename=os.path.join("./", 'peak-modelling.log'),
						format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
						datefmt='%y-%m-%d %H:%M',
						level=logging.INFO,
						filemode='w')

	if args.verbose:
		# Mirror INFO-and-above records to the console (StreamHandler defaults
		# to sys.stderr) using a shorter format than the log file.
		console = logging.StreamHandler()
		console.setLevel(logging.INFO)
		console.setFormatter(logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
		logging.getLogger('').addHandler(console)

	if len(args.infile) > 3:
		logging.warning("More than three infiles given; only the first three are used.")

	# Load fODF input file. Dimensionality is checked first so that shape[0]
	# is only inspected on an array that is known to be 4D.
	fodf, meta = nrrd.read(args.fodf)
	if fodf.ndim != 4:
		msg = "fodf have to be in 3d space. Hence, fodf has to be 4d."
		logging.error(msg)
		raise ValueError(msg)
	if fodf.shape[0] != 16:
		msg = "fodf has to be 4th order tensor."
		logging.error(msg)
		raise ValueError(msg)
	logging.info("4th order fODF loaded")

	# Read the low rank approximations; the files may be given in any order
	# and are identified by the length of their second axis.
	approximations = [nrrd.read(path)[0] for path in args.infile[:3]]
	component_counts = [arr.shape[1] if arr.ndim > 1 else 0 for arr in approximations]
	# Require each of 1, 2 and 3 exactly once (e.g. three rank-2 files must
	# be rejected rather than leaving rank 1 and rank 3 undefined).
	if sorted(component_counts) != [1, 2, 3]:
		msg = "The given low rank approximations have to be of rank 1, 2, 3."
		logging.error(msg)
		raise ValueError(msg)
	logging.info("low rank approximations of rank 1,2,3 are loaded")
	by_count = dict(zip(component_counts, approximations))
	rank_1, rank_2, rank_3 = by_count[1], by_count[2], by_count[3]

	# Stack all three approximations into one zero-padded array:
	# index 0 holds rank 3, index 1 rank 2, index 2 rank 1.
	spatial_shape = rank_3.shape[2:]
	low_rank = np.zeros((3, 4, 3) + spatial_shape)
	low_rank[0] = rank_3
	low_rank[1, :, :2] = rank_2
	low_rank[2, :, :1] = rank_1
	# Both buffers are passed to model_avg and used afterwards as its results.
	output = np.zeros(rank_3.shape)
	modelling = np.zeros((3,) + spatial_shape)
	logging.info(f"Calculating {args.type} model")
	model_avg(output, low_rank, fodf, args.type, modelling, np.float64(args.a), np.float64(args.b), args.verbose)
	output = output.reshape((4, 3) + fodf.shape[1:])

	if args.save_prob:
		# swap first and last entry along axis 0 (the right-hand side is a
		# copy, so the assignment is a safe in-place swap)
		modelling[[0, 2]] = modelling[[2, 0]]
		# Insert "probabilities" in front of the extension. Unlike a plain
		# str.replace this never returns the outfile name itself, so the
		# probabilities cannot be overwritten by the main output below.
		root, ext = os.path.splitext(args.outfile)
		prob_file = f"{root}probabilities{ext}"
		nrrd.write(prob_file, np.float32(modelling), meta)
		logging.info(f"Successfully created probabilities and saved to {prob_file}")

	# update meta file: the output has one more leading axis than the fODF
	# input, so one extra NaN row is prepended to the space directions.
	# NOTE(review): assumes the input header carries 'space', 'space origin'
	# and a 'space directions' entry with one row per fODF axis - confirm.
	newmeta = {k: meta[k] for k in ['space', 'space origin']}
	newmeta['kinds'] = ['list', 'list', 'space', 'space', 'space']
	newmeta['space directions'] = np.vstack(([np.nan, np.nan, np.nan], meta['space directions']))
	nrrd.write(args.outfile, np.float32(output), newmeta)
	logging.info(f"Successfully created {args.type} model and saved to {args.outfile}")

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
	main()
