import random
random.seed(0)
import numpy as np
np.random.seed(0)
import tensorflow as tf
import onnx_graphsurgeon as gs
from onnx2tf.utils.common_functions import (
    get_constant_or_variable,
    print_node_info,
    inverted_operation_enable_disable,
    make_tf_node_info,
)


@print_node_info
@inverted_operation_enable_disable
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """Erf

    Builds the TensorFlow equivalent of an ONNX ``Erf`` node and registers it
    in ``tf_layers_dict`` under the name of the node's first output.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph

    kwargs: dict
        replace_erf_to_pseudo_erf: bool, optional
            When truthy, the output is built from abs/divide/exp and basic
            arithmetic (a polynomial approximation of erf) instead of
            ``tf.math.erf``. Treated as False when the key is absent.

    Notes
    -----
    Nothing is returned. The following keys are written to
    ``tf_layers_dict[<output name>]``: 'optype', 'shape', 'dtype',
    'tf_node' and 'tf_node_info'.
    """
    # Flag stored by the producer of the input tensor; defaults to True when
    # the producer has no entry (or no such key) in tf_layers_dict.
    before_op_output_shape_trans = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)

    graph_node_input = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_output: gs.Variable = graph_node.outputs[0]

    shape = graph_node_output.shape
    dtype = graph_node_output.dtype

    # Preserving Graph Structure (Dict)
    tf_layers_dict[graph_node_output.name] = {
        'optype': graph_node.op,
        'shape': shape,
        'dtype': dtype,
    }

    # Optional switch: fall back to the native erf op when it is not supplied
    # instead of failing with a KeyError.
    replace_erf_to_pseudo_erf = kwargs.get('replace_erf_to_pseudo_erf', False)

    # Generation of TF OP
    # Variables are resolved to the TF node registered by their producer;
    # anything else is used directly as the input value.
    input_tensor = tf_layers_dict[graph_node_input.name]['tf_node'] \
        if isinstance(graph_node_input, gs.Variable) else graph_node_input

    if not replace_erf_to_pseudo_erf:
        tf_layers_dict[graph_node_output.name]['tf_node'] = \
            tf.math.erf(
                x=input_tensor,
                name=graph_node.name,
            )
    else:
        # Polynomial approximation of erf (coefficients of
        # Abramowitz & Stegun formula 7.1.26).
        # https://stackoverflow.com/questions/457408/is-there-an-easily-available-implementation-of-erf-for-python
        eps = 1e-11
        x_abs = tf.math.abs(input_tensor)
        # Sign of the input expressed as x / (|x| + eps); eps keeps the
        # denominator non-zero at x == 0. x_abs is reused here so that only
        # one Abs op is emitted.
        sign = tf.math.divide(input_tensor, x_abs + eps)
        a1 =  0.254829592
        a2 = -0.284496736
        a3 =  1.421413741
        a4 = -1.453152027
        a5 =  1.061405429
        p  =  0.3275911
        t = 1.0/(1.0 + p*x_abs)
        # Horner evaluation of the degree-5 polynomial in t, computed on |x|;
        # the sign is restored afterwards because erf is an odd function.
        y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*tf.math.exp(-x_abs*x_abs)
        erf_tensor = sign*y
        tf_layers_dict[graph_node_output.name]['tf_node'] = erf_tensor

    # Generation of Debug Info
    # NOTE: 'tf_op_type' is reported as tf.math.erf for both branches.
    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': tf.math.erf,
                'tf_inputs': {
                    'x': input_tensor,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
                },
            }
        )
