#!/usr/bin/python
# -*- coding: utf-8 -*-

import argparse
import numpy as np
from scipy.ndimage import gaussian_filter
from bonndit.tracking.ItoW import Trafo
from bonndit.filter.filter import plysplitter, streamline_filter, intersection_dict, filter_mask, streamline_itow
import regex as re
import nrrd
from plyfile import PlyElement, PlyData
import logging
import os

def is_integer(n):
    """Return True if *n* can be parsed as a number with an integral value.

    *n* is passed to ``float()``; strings such as ``"5"`` or ``"5.0"`` yield
    True, while ``"5.5"``, ``"abc"``, ``""`` or ``None`` yield False.
    """
    try:
        # is_integer must be *called*; the bare bound method is always truthy,
        # which would make every parseable number look like an integer.
        return float(n).is_integer()
    except (TypeError, ValueError, OverflowError):
        return False

def main():
	"""Command line entry point: filter a streamline bundle and save the result.

	Steps, in order:

	1. Parse the command line and set up logging to ``./peak-modelling.log``
	   (plus a console handler when ``--verbose`` is truthy).
	2. Parse the optional ``--exclusion`` / ``--exclusionc`` specifications.
	3. Read the ``x``, ``y``, ``z`` and ``seedpoint`` vertex columns and the
	   ``endindex`` fiber column from the input ply file.
	4. Read the nrrd file given by ``-m`` and build the ``Trafo`` object from
	   its ``space directions`` and ``space origin`` header entries.
	5. Split the vertices into streamlines, apply the exclusion filters and,
	   if ``--minlen`` is given, drop streamlines with too few points.
	6. Build a density mask, optionally write its thresholded (0/1) version to
	   ``-f``, then threshold and Gaussian-smooth it and pass it to
	   ``filter_mask`` to cut the streamlines.
	7. Transform the streamlines with ``streamline_itow``, keep only those with
	   more than 50 points and write them to the output ply file.

	Raises:
		Exception: if an exclusion slice is malformed, the ply file cannot be
			read, or it has no ``seedpoint`` column.
		ValueError: if no streamline survives the filters.
	"""
	# Local import: the running sum of fiber lengths is only needed here.
	from itertools import accumulate

	parser = argparse.ArgumentParser(
		description='Given a bundle various checkings on each fiber can be done. All filters are only applied if they '
					'are set. Filteroptions are exclusion slices, exclusion regions, min length, min density.',
		formatter_class=argparse.ArgumentDefaultsHelpFormatter)

	parser.add_argument('-i', help='Infile should contain a bundle in ply format where the fourth column contains '
								   'information about the seedpoint.', default=argparse.SUPPRESS, required=True)
	parser.add_argument('-m', help='Metadata from fodf file.', default=argparse.SUPPRESS, required=True)
	parser.add_argument('-o', help='Outfile', default=argparse.SUPPRESS, required=True)
	parser.add_argument('-f', help='Outfile mask', default=argparse.SUPPRESS)
	parser.add_argument('--mask', help='Min Streamline density. Starting from the seed the streamline is terminated '
									   'at the first intersection with a low density region', default=5)
	parser.add_argument('--exclusion', help='Exclusion slice. Can be set via x<10. Multiple regions can be seperated '
											'by whitespace.', default=argparse.SUPPRESS)
	parser.add_argument('--exclusionc',
						help='Exclusion region. Can be given via 5<x<10,4<y<20,3<z<30. Multiple regions can be seperated by whitespace.', default=argparse.SUPPRESS)
	parser.add_argument('--minlen', help='Min length of a streamline in terms of steps.', default=argparse.SUPPRESS)
	parser.add_argument('-v', '--verbose', default=True)

	args = parser.parse_args()

	logging.basicConfig(filename=os.path.join("./", 'peak-modelling.log'),
						format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
						datefmt='%y-%m-%d %H:%M',
						level=logging.INFO,
						filemode='w')
	if args.verbose:
		# Mirror INFO (and higher) messages to the console, with a shorter
		# format than the one used for the log file.
		console = logging.StreamHandler()
		console.setLevel(logging.INFO)
		console.setFormatter(logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
		logging.getLogger('').addHandler(console)

	# Options declared with default=argparse.SUPPRESS are absent from the
	# namespace unless given, so they are tested with "in args" throughout.
	exclusion = []
	if "exclusion" in args:
		# "x<10 y>3" -> [['x', '<', 10], ['y', '>', 3]]
		exclusion = [re.split('(<|>)', spec) for spec in re.split(' ', args.exclusion)]
		for ex in exclusion:
			if ex[0] not in ['x', 'y', 'z']:
				raise Exception('Only x,y,z is allowed')
			# A spec without an operator splits into a single element.
			if len(ex) < 3 or ex[1] not in ['<', '>']:
				raise Exception('Only >,< is allowed')
			if not is_integer(ex[2]):
				raise Exception('Only integers are allowed')
			try:
				ex[2] = int(ex[2])
			except ValueError as err:
				raise Exception('Only integers are allowed') from err

	exclusionc = []
	if "exclusionc" in args:
		# "5<x<10,4<y<20" -> [[['5', '<', 'x', '<', '10'], ['4', '<', 'y', '<', '20']]]
		exclusionc = [
			[re.split('(<|>)', part) for part in re.split(',', region)]
			for region in re.split(' ', args.exclusionc)
		]

	with open(args.i, 'rb') as f:
		try:
			plydata = PlyData.read(f)
		except Exception as err:
			logging.error("Plyfile seems to be corrupted.")
			raise Exception("Plyfile seems to be corrupted.") from err
		num_verts = plydata['vertices'].count
		num_fiber = plydata['fiber'].count
		vertices = np.zeros(shape=[num_verts, 4], dtype=np.float64)
		endindex = np.zeros(shape=[num_fiber], dtype=np.float64)
		vertices[:, 0] = plydata['vertices'].data['x']
		vertices[:, 1] = plydata['vertices'].data['y']
		vertices[:, 2] = plydata['vertices'].data['z']
		try:
			vertices[:, 3] = plydata['vertices'].data['seedpoint']
		except Exception as err:
			logging.error("Seedpoint is not in data")
			raise Exception("Seedpoint is not in data") from err
		endindex[:] = plydata['fiber'].data['endindex']
	logging.info("Loaded Streamlines")

	# Only the header of the nrrd file is used; its first axis is skipped
	# wherever per-axis metadata is read.
	_, meta = nrrd.read(args.m)
	trafo = Trafo(np.float64(meta['space directions'][1:]), np.float64(meta['space origin']))
	logging.info("Created ItoW transformation")

	# Split the flat vertex array into individual streamlines via endindex.
	logging.info("Split streamlines")
	streamlines_test = plysplitter(vertices, endindex, args.verbose)

	logging.info("Apply exclusion region filters")
	keep = streamline_filter(streamlines_test, exclusionc, exclusion, trafo, args.verbose)
	streamlines_test = np.array(streamlines_test, dtype=object)[keep]

	if "minlen" in args:
		logging.info("Exclude short streamlines")
		lengths = np.array([line.shape[0] for line in streamlines_test])
		streamlines_test = streamlines_test[lengths > int(args.minlen)]

	# Density mask on the grid described by the nrrd header.
	min_density = float(args.mask)
	logging.info("Create density mask")
	mask = np.zeros(meta['sizes'][1:])
	mask = np.array(intersection_dict(mask, streamlines_test, args.verbose))

	# Binary version of the mask: 1 where the density reaches the threshold.
	mask2 = np.zeros(meta['sizes'][1:])
	mask2[mask >= min_density] = 1

	# Drop the first axis from the header so it matches the 3D mask.
	meta['space directions'] = meta['space directions'][1:]
	meta['kinds'] = meta['kinds'][1:]
	if "f" in args:
		logging.info("Save density mask")
		nrrd.write(args.f, mask2, meta)

	mask[mask < min_density] = 0
	mask = gaussian_filter(mask, sigma=1)
	logging.info("Cut off streamlines at the first intersection with boundaries of density mask")
	streamlines_test = filter_mask(streamlines_test, mask, args.verbose)
	# Streamlines reduced to a single point (or less) are discarded.
	has_points = [line.shape[0] > 1 for line in streamlines_test]
	streamlines_test = np.array(streamlines_test, dtype=object)[has_points]

	streamlines_test = streamline_itow(streamlines_test, trafo)

	# Filter out short streamlines: only those with more than min_points
	# vertices are written to the output file.
	min_points = 50
	lengths = [line.shape[0] for line in streamlines_test]
	long_enough = np.array(lengths) > min_points
	if not long_enough.any():
		logging.error("No streamlines left after filtering.")
		raise ValueError("No streamlines left after filtering.")
	streamlines_test = np.array(streamlines_test, dtype=object)[long_enough]
	streamlines_test = np.concatenate(streamlines_test)
	# Running sum of the kept lengths gives the end index of each fiber.
	endindexes = list(accumulate(n for n in lengths if n > min_points))

	# Save the file
	tracks = PlyElement.describe(
		np.array([tuple(x) for x in streamlines_test], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('start', 'f4')]),
		'vertices',
		comments=[])
	endindex = PlyElement.describe(np.array(list(endindexes), dtype=[('endindex', 'i4')]), 'fiber')
	PlyData([tracks, endindex]).write(args.o)
	logging.info(f"Streamlines have been saved to {args.o}")


# Run the command line interface only when executed as a script.
if __name__ == "__main__":
	main()


