import tensorflow as tf
from keras.layers import Layer
import numpy as np

class OnnxConv3D(Layer):
    """Grouped 3D convolution emulated with one Keras ``Conv3D`` per group.

    The input is split along its channel axis into ``groups`` equal slices.
    Each slice is convolved by its own ``Conv3D``, and the per-group outputs
    are concatenated back along the channel axis.

    Args:
        filters: Total number of output filters across all groups.
        kernel_size, strides, padding, activation, dilation_rate, data_format:
            Forwarded unchanged to every per-group ``Conv3D``.
        full_weights: ``[kernel]`` or ``[kernel, bias]`` for the whole grouped
            convolution. The kernel is sliced on its last axis (output filters)
            and its axis ``-2`` is read as the input channels per group.
            Presumably the layout is
            ``(kd, kh, kw, in_channels // groups, filters)``; confirm with the
            caller.
        use_bias: Whether a bias is present in ``full_weights[1]`` and used.
        name: Base name; the per-group convs are named ``f"{name}_{i}"``.
        groups: Number of convolution groups.
    """

    def __init__(self, filters, kernel_size, strides, padding, full_weights, use_bias, activation,
                 dilation_rate, name, groups, data_format="channels_first", **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.use_bias = use_bias
        self.activation = activation
        self.dilation_rate = dilation_rate
        self._name = name
        self.groups = groups
        self.data_format = data_format

        kernel = full_weights[0]
        # Input channels consumed by each group (kernel axis -2); used to slice inputs.
        self.channels_per_group = kernel.shape[-2]
        # Output filters produced by each group. This differs from
        # channels_per_group whenever the per-group channel multiplier != 1,
        # so the kernel's output axis and the bias must be sliced with it.
        self.filters_per_group = int(filters // groups)

        self.convs = []
        for i in range(groups):
            out_slice = slice(self.filters_per_group * i, self.filters_per_group * (i + 1))
            group_weights = [kernel[..., out_slice]]
            if use_bias:
                # Only supply a bias when the sub-conv actually has one.
                group_weights.append(full_weights[1][out_slice])
            self.convs.append(
                tf.keras.layers.Conv3D(filters=self.filters_per_group, kernel_size=kernel_size,
                                       strides=strides, padding=padding, use_bias=use_bias,
                                       activation=activation, weights=group_weights,
                                       dilation_rate=dilation_rate, data_format=data_format,
                                       name=name + "_" + str(i)))

        channel_axis = 1 if data_format == "channels_first" else -1
        self.concat = tf.keras.layers.Concatenate(axis=channel_axis)

    def call(self, inputs, **kwargs):
        """Convolve each channel group separately and concatenate the results."""
        step = self.channels_per_group
        if self.data_format == "channels_first":
            results = [conv(inputs[:, step * i:step * (i + 1), ...])
                       for i, conv in enumerate(self.convs)]
        else:
            results = [conv(inputs[..., step * i:step * (i + 1)])
                       for i, conv in enumerate(self.convs)]
        if len(results) == 1:
            # Concatenate rejects a single input in some Keras versions.
            return results[0]
        return self.concat(results)

    def _get_full_weights(self):
        """Reassemble ``[kernel]`` or ``[kernel, bias]`` from the per-group convs.

        Reads each sub-conv's own weights instead of relying on the ordering
        of ``self.weights``, so a bias-less layer is handled correctly.
        Requires the sub-convs to be built (i.e. the layer has been called).
        """
        per_group = [conv.get_weights() for conv in self.convs]
        kernel = np.concatenate([w[0] for w in per_group], axis=-1)
        if not self.use_bias:
            return [kernel]
        bias = np.concatenate([w[1] for w in per_group], axis=-1)
        return [kernel, bias]

    def get_config(self):
        """Return the constructor arguments, including the reassembled weights.

        NOTE(review): ``full_weights`` holds numpy arrays, which are not
        JSON-serializable; confirm how callers persist this config.
        """
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "use_bias": self.use_bias,
            "activation": self.activation,
            "dilation_rate": self.dilation_rate,
            "full_weights": self._get_full_weights(),
            "name": self._name,
            "groups": self.groups,
            "data_format": self.data_format,
        })
        return config