# -*- coding: utf-8 -*-
import numpy as np
from numpy import concatenate as join
import awkward as ak
from awkward import unflatten as build

from dewloosh.math.arraysetops import unique2d


class TopologyArray(ak.Array):
    """
    A jagged :class:`awkward.Array` assembled from NumPy blocks of rows.

    Every :class:`numpy.ndarray` passed positionally is treated as a 2d block.
    The blocks are stacked in the given order, and each row keeps the width
    (number of columns) of the block it came from, so blocks with different
    widths can be combined. In this case the values are cast to ``int``, and
    positional arguments that are not NumPy arrays are ignored.

    If no NumPy array is among the positional arguments, all arguments are
    forwarded to :class:`awkward.Array`.

    Parameters
    ----------
    *args
        One or more 2d NumPy arrays, a single 1d NumPy array together with
        ``cuts``, or whatever :class:`awkward.Array` accepts.
    behavior, with_name, check_valid, cache, kernels
        Passed on to :class:`awkward.Array` in every code path.
    cuts : array-like, optional
        Row lengths used to split a single 1d NumPy array into rows.
        Ignored when the NumPy inputs are 2d.
    **kwargs
        Forwarded to :class:`awkward.Array` only when no NumPy array is given;
        ignored otherwise.

    Raises
    ------
    ValueError
        If a NumPy input is not 2d (except a single 1d array with ``cuts``).
    """

    def __init__(self, *args, behavior=None, with_name=None,
                 check_valid=False, cache=None, kernels=None,
                 cuts=None, **kwargs):
        ak_kwargs = dict(behavior=behavior, with_name=with_name,
                         check_valid=check_valid, cache=cache, kernels=kernels)
        topo = [arg for arg in args if isinstance(arg, np.ndarray)]
        if not topo:
            # Nothing to stack: defer to awkward, but keep the keyword
            # arguments this signature captured instead of dropping them.
            super().__init__(*args, **ak_kwargs, **kwargs)
            return
        if cuts is not None and len(topo) == 1 and topo[0].ndim == 1:
            # A flat array with explicitly given row lengths.
            topo1d = topo[0].astype(int)
            counts = np.asarray(cuts, dtype=int)
        else:
            if any(arr.ndim != 2 for arr in topo):
                raise ValueError("Only 2 dimensional arrays are supported!")
            # Concatenate the blocks row-major and record one count per row,
            # equal to the width of the block that row belongs to.
            topo1d = np.concatenate([arr.ravel() for arr in topo]).astype(int)
            counts = np.concatenate(
                [np.full(arr.shape[0], arr.shape[1], dtype=int) for arr in topo])
        super().__init__(get_data(topo1d, cuts=counts), **ak_kwargs)

    def __array_function__(self, func, types, args, kwargs):
        """
        NumPy ``__array_function__`` hook (NEP 18).

        ``np.unique`` is routed to ``unique2d`` from
        ``dewloosh.math.arraysetops``. Every other function is handled by
        :class:`awkward.Array`.
        """
        if func is np.unique:
            return unique2d(*args, **kwargs)
        return super().__array_function__(func, types, args, kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """NumPy ufunc hook; delegates unchanged to :class:`awkward.Array`."""
        return super().__array_ufunc__(ufunc, method, *inputs, **kwargs)


def shape(arr):
    """Return the leading (at most two) entries of ``arr.shape``."""
    leading = arr.shape[:2]
    return leading
def cut(shp):
    """
    Return an int64 array of length ``shp[0]`` filled with ``shp[1]``.

    Used to produce one row length per row of a 2d block.
    """
    n_rows, width = shp[0], shp[1]
    return np.full(n_rows, width, dtype=np.int64)
def flatten(arr):
    """Return the one-dimensional copy produced by ``arr.flatten()``."""
    flat = arr.flatten()
    return flat


def get_data(data, cuts=None):
    """
    Normalize topology input into something :class:`awkward.Array` accepts.

    Parameters
    ----------
    data : numpy.ndarray or list of numpy.ndarray
        - An array with 0 or 2 dimensions is returned unchanged.
        - A 1d array is split into consecutive rows of lengths ``cuts``.
        - A list of 2d arrays, possibly with different numbers of columns,
          is stacked row by row into one jagged array.

        Any other object is returned unchanged.
    cuts : array-like, optional
        Row lengths; required when ``data`` is a 1d array.

    Returns
    -------
    numpy.ndarray or awkward.Array
        ``data`` itself, or the result of ``awkward.unflatten``.

    Raises
    ------
    ValueError
        If an array has more than 2 dimensions (or a list item is not 2d),
        if ``data`` is 1d and ``cuts`` is missing, or if ``data`` is an
        empty list.
    """
    if isinstance(data, np.ndarray):
        if data.ndim > 2:
            raise ValueError("Only 2 dimensional arrays are supported!")
        if data.ndim == 1:
            if cuts is None:
                raise ValueError("'cuts' is required for 1 dimensional data!")
            data = build(data, np.asarray(cuts))
    elif isinstance(data, list):
        if not data:
            raise ValueError("At least one array is required!")
        if any(len(arr.shape) != 2 for arr in data):
            raise ValueError("Only 2 dimensional arrays are supported!")
        # Build from one flat buffer plus per-row counts. Alternatives that
        # were tried:
        # - a round trip through nested Python lists (ak.from_iter) works,
        #   but is slow;
        # - ak.concatenate creates a layout that caused failures in later
        #   operations.
        flat = join([flatten(arr) for arr in data])
        counts = join([cut(shape(arr)) for arr in data])
        data = build(flat, counts)
    return data


if __name__ == '__main__':
    # Demo: stack a 3x2 block and a 3x3 block into one jagged array
    # with 6 rows.
    two_wide = np.array([[0, 1], [1, 2], [2, 3]])
    three_wide = np.array([[0, 1, 4], [1, 2, 5], [2, 3, 6]])
    topo = TopologyArray(two_wide, three_wide)
    # Row 1 comes from the first block, row 4 from the second one.
    for row, col in ((1, 1), (4, 2)):
        print(topo[row, col])
