#!/usr/bin/python
# -*- coding: utf-8 -*-

import numpy as np
from bonndit.utils.esh import sym_to_esh_matrix, esh_to_sym_matrix, get_order
import nibabel as nib
import argparse
from collections import OrderedDict
import nrrd

def create_mapping(imag, order):
    """Return the coefficient index permutation applied during conversion.

    The coefficients are treated as consecutive bands, one per even degree
    ``0, 2, ..., order``; the band of degree ``d`` holds ``2 * d + 1`` entries
    (1, 5, 9, ...).

    Args:
        imag: If truthy, the entries inside every band are listed in reverse
            order (the command-line help calls this inverting the m
            indexing). If falsy, the identity mapping is returned.
        order: Highest degree to cover.

    Returns:
        A tuple of integer indices with one entry per coefficient. For
        ``order == 4`` this has 15 entries.
    """
    band_sizes = [2 * degree + 1 for degree in range(0, order + 1, 2)]
    if not imag:
        # Size the identity mapping from the order instead of assuming the
        # 15 coefficients of order 4.
        return tuple(range(sum(band_sizes)))

    mapping_coeffs = []
    for size in band_sizes:
        start = len(mapping_coeffs)
        # Walk the band from its last index down to its first one.
        mapping_coeffs.extend(range(start + size - 1, start - 1, -1))
    return tuple(mapping_coeffs)


def mrtrix2bonndit(input, output, imag):
    """Convert a spherical-harmonics image into a bonndit tensor NRRD file.

    The image is read with ``nib.load``, its coefficients are reordered via
    ``create_mapping`` and multiplied by the matrix returned from
    ``esh_to_sym_matrix``. The result is written with ``nrrd.write`` with the
    coefficient axis moved to the front and a leading channel of ones.

    Args:
        input: Path of the image to load.
        output: Path of the NRRD file to write.
        imag: Forwarded to ``create_mapping`` to select the coefficient
            reordering.

    Raises:
        Exception: If the input cannot be loaded; the loader error is chained
            as the cause.
        AssertionError: If the determinant of the image affine is not
            positive.
    """
    try:
        img = nib.load(input)
    except Exception as err:
        # Only ordinary errors are wrapped, so Ctrl-C / SystemExit still
        # propagate; the original error is kept as __cause__.
        raise Exception('%s could not be loaded' % input) from err
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    # ``not (... > 0)`` also rejects a NaN determinant, like the assert did.
    if not np.linalg.det(img.affine) > 0:
        raise AssertionError('Affine matrix is not positive definite, e.g. the orientation is not RAS.')
    fodfs = img.get_fdata()
    # The order is detected from the coefficients of the first voxel.
    order = get_order(fodfs[0,0,0])
    conversion_matrix = esh_to_sym_matrix(order)
    fodfs = fodfs[..., create_mapping(imag, order)]
    # Apply the basis change to the coefficient vector of every voxel.
    fodfs = (conversion_matrix @ fodfs[..., np.newaxis])[...,0]
    # Coefficients become the first axis of the written array.
    fodfs = np.moveaxis(fodfs, 3,0)
    # Prepend a channel filled with ones; bonndit2mrtrix skips this first
    # entry when reading (presumably a mask/weight channel - TODO confirm).
    fodfs = np.vstack((np.ones((1, *fodfs.shape[1:])), fodfs))

    # NOTE(review): the rows of the affine's 3x3 block are stored as the
    # per-axis space directions, and bonndit2mrtrix reads them back the same
    # way. Confirm this matches the convention expected by other NRRD readers.
    meta = OrderedDict((
             ('dimension', 4),
             ('space', 'right-anterior-superior'),
             ('sizes', np.array(fodfs.shape)),
             ('space directions',np.vstack(([np.nan, np.nan, np.nan], img.affine[:3,:3]))),
             ('kinds', ['???', 'space', 'space', 'space']),
             ('space origin', img.affine[:3,3])))

    nrrd.write(output, fodfs, meta)

def bonndit2mrtrix(input, output, imag):
    """Convert a bonndit tensor NRRD file into a spherical-harmonics image.

    The file is read with ``nrrd.read``. The first entry along axis 0 is
    skipped, the remaining coefficients are multiplied by the matrix returned
    from ``sym_to_esh_matrix`` and reordered via ``create_mapping``. The
    result is saved as a ``nib.Nifti1Image``.

    Args:
        input: Path of the NRRD file to load.
        output: Path of the image file to write.
        imag: Forwarded to ``create_mapping`` to select the coefficient
            reordering.

    Raises:
        Exception: If the input cannot be loaded; the loader error is chained
            as the cause.
    """
    try:
        fodfs, meta = nrrd.read(input)
    except Exception as err:
        # Only ordinary errors are wrapped, so Ctrl-C / SystemExit still
        # propagate; the original error is kept as __cause__.
        raise Exception('%s could not be loaded' % input) from err
    # The order is detected from the first voxel, ignoring the leading channel.
    order = get_order(fodfs[1:,0,0,0])
    conversion_matrix = sym_to_esh_matrix(order)
    # Move the coefficient axis to the end so the matrix product acts per voxel.
    fodfs = np.moveaxis(fodfs, 0,3 )

    fodfs = (conversion_matrix @ fodfs[..., 1:, np.newaxis])[..., create_mapping(imag, order),0]
    # Rebuild the 4x4 affine from the NRRD header; the first row of
    # 'space directions' belongs to the coefficient axis and is dropped.
    affine = np.eye(4)
    affine[:3,:3] = meta['space directions'][1:]
    affine[:3,3] = meta['space origin']
    img = nib.Nifti1Image(fodfs, affine)
    nib.save(img, output)


def main():
    """Parse the command line and run the matching conversion.

    Inputs whose name ends in ``nii.gz`` or ``nii`` are converted with
    ``mrtrix2bonndit``; every other input is converted with
    ``bonndit2mrtrix``. The ``--midx`` / ``--no-midx`` flag (default: on) is
    forwarded as the ``imag`` argument.
    """
    parser = argparse.ArgumentParser(
        description='Helper script to convert between spherical harmonics (tournier 2007 basis - readable by dipy and mrtrix) and \
                     tensor basis used by bonndit. Per default m indexing is inverted. ',)

    # Both paths are mandatory: without them the script previously crashed
    # with an AttributeError / inside the writer instead of a usage message.
    parser.add_argument('-i', help='Input filename', required=True)
    parser.add_argument('-o', help='Output filename', required=True)
    parser.add_argument('--midx', help='Invert m indexing', action=argparse.BooleanOptionalAction, default=True)
    args = parser.parse_args()
    # Uncompressed NIfTI inputs take the same route as compressed ones.
    if args.i.endswith(('nii.gz', 'nii')):
        mrtrix2bonndit(args.i, args.o, args.midx)
    else:
        bonndit2mrtrix(args.i, args.o, args.midx)

# Run the command-line interface only when executed as a script, not on import.
if __name__=="__main__":
    main()
