#!/usr/bin/python
# -*- coding: utf-8 -*-

import logging
import os
import argparse
import numpy as np
from bonndit.directions.fodfapprox import approx_all_spherical
import nrrd

# Command-line spellings (compared case-insensitively) that switch a flag off.
_FALSE_STRINGS = frozenset({'', '0', 'false', 'f', 'no', 'n', 'off'})


def _str_to_bool(value):
    """Convert a command-line string into a bool.

    The strings in ``_FALSE_STRINGS`` map to ``False``.  Every other value
    maps to ``True`` so that invocations which previously enabled verbose
    output keep doing so.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in _FALSE_STRINGS


def _build_parser():
    """Create the argument parser of the low-rank approximation script."""
    parser = argparse.ArgumentParser(
        description='This script performs a low-rank approximation of fODFs '
                    'that are given in a higher-order tensor format, as described in '
                    '"Estimating Crossing Fibers: A Tensor Decomposition Approach" '
                    'by Schultz/Seidel (2008).',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Input and output are mandatory.  SUPPRESS only keeps a meaningless
    # "(default: None)" out of the help text.
    parser.add_argument('-i', '--infile', nargs='+', required=True,
                        help='4D input file containing fODFs in masked higher-order tensor format (1+#fODF coefficients,x,y,z)',
                        default=argparse.SUPPRESS)
    parser.add_argument('-o', '--outfile', required=True,
                        help='5D output file with the approximation result (4,r,x,y,z)',
                        default=argparse.SUPPRESS)
    parser.add_argument('--init',
                        help='nrrd file whose data is used to initialize the '
                             'approximation; it is reshaped to (4,r,#voxels)')
    parser.add_argument('-r', type=int, default=3, help='rank')
    # Accept "-v", "-v true" and "-v false"; without a type any supplied
    # string (including "False") would be truthy.
    parser.add_argument('-v', '--verbose', nargs='?', const=True, default=True,
                        type=_str_to_bool,
                        help='also write log messages to the console')
    return parser


def _configure_logging(verbose):
    """Log to ./low-rank-k-approx.log and, if *verbose*, to the console."""
    logging.basicConfig(filename=os.path.join("./", 'low-rank-k-approx.log'),
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%y-%m-%d %H:%M',
                        level=logging.INFO,
                        filemode='w')
    if not verbose:
        return
    # Mirror INFO messages (and above) on sys.stderr using a shorter format.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(
        '%(name)-12s: %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console)


def main():
    """Run the low-rank approximation command-line tool.

    Steps:
      1. Parse the command line and set up logging.
      2. Read every input nrrd file and average them voxel-wise.
      3. Check that the averaged data is 4D with 16 entries on the first axis.
      4. Call ``approx_all_spherical`` which fills the output buffer in place.
      5. Write the (4, r, x, y, z) result as float32 nrrd, using the header
         of the last input file for the spatial metadata.

    Raises:
        SystemExit: on invalid command-line arguments (raised by argparse).
        ValueError: if the inputs differ in shape or the averaged data does
            not have the expected format.
        KeyError: if the header of the last input file lacks 'space',
            'space origin' or 'space directions'.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.r < 1:
        parser.error('rank (-r) has to be a positive integer')
    rank = args.r

    _configure_logging(args.verbose)

    # Load fODF input files; the header of the last file is kept for output.
    loaded = [nrrd.read(path) for path in args.infile]
    volumes = [volume for volume, _ in loaded]
    meta = loaded[-1][1]

    shapes = {volume.shape for volume in volumes}
    if len(shapes) != 1:
        msg = "all input files need the same shape, got %s" % sorted(shapes)
        logging.error(msg)
        raise ValueError(msg)

    fodfs = np.mean(volumes, axis=0)
    logging.info("Averaged fODF shape: %s", fodfs.shape)

    # Check the dimensionality first so that the shape[0] test below is
    # applied to data with the expected layout.
    if fodfs.ndim != 4:
        msg = "fodf have to be in 3d space. Hence, fodf has to be 4d."
        logging.error(msg)
        raise ValueError(msg)
    if fodfs.shape[0] != 16:
        msg = "fodf has to be 4th order tensor."
        logging.error(msg)
        raise ValueError(msg)

    logging.info("fODF is loaded and has the correct format")

    # Build the output header before the expensive computation so that a
    # missing header field fails early instead of discarding the result.
    newmeta = {k: meta[k] for k in ['space', 'space origin']}
    newmeta['kinds'] = ['list', 'list', 'space', 'space', 'space']
    # The output has one more non-spatial axis than the input, hence one
    # additional NaN row in front of the input's space directions.
    newmeta['space directions'] = np.vstack(([np.nan, np.nan, np.nan], meta['space directions']))

    logging.info("Creating low-rank %s approximation", rank)

    # Number of voxels (product of the three spatial dimensions).
    num_voxels = fodfs[0].size

    # Use one flag for both loading the start values and telling the solver
    # about them (an empty --init string counts as "not given").
    use_init = bool(args.init)
    if use_init:
        output, _ = nrrd.read(args.init)
        output = np.float64(output.reshape((4, rank, num_voxels)))
    else:
        output = np.zeros((4, rank, num_voxels))

    # The result is written into ``output`` in place.  The 4D array is handed
    # over here, exactly as before; only ``output`` uses the flattened voxel
    # axis.
    approx_all_spherical(output, np.float64(fodfs), int(0), np.float32(0), rank, int(use_init), args.verbose)
    logging.info("Low-rank %s approximation is calculated", rank)

    output = output.reshape((4, rank) + fodfs.shape[1:])
    nrrd.write(args.outfile, np.float32(output), newmeta)
    logging.info("Low-rank %s approximation is saved to %s", rank, args.outfile)

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
