"""
qclib auxiliary functions
"""

from qiskit import execute, Aer
import numpy as np
from scipy import sparse


def get_counts(circ):
    """
    Run a circuit on Aer's QASM simulator and return its measurement counts.

    Parameters
    ----------
    circ: QuantumCircuit
        Circuit to execute; it must contain measurement operations,
        otherwise there are no counts to report.

    Returns
    -------
    counts: the measurement histogram reported by
        ``Result.get_counts()`` for this execution
    """
    simulator = Aer.get_backend('qasm_simulator')
    outcome = execute(circ, simulator).result()
    return outcome.get_counts()


def get_state(circ):
    """
    Simulate a circuit and return the state vector it produces.

    Parameters
    ----------
    circ: QuantumCircuit
        Circuit to simulate on Aer's 'statevector_simulator' backend.

    Returns
    -------
    state_vector: the final state of the circuit, as reported by
        ``Result.get_statevector()``
    """
    simulator = Aer.get_backend('statevector_simulator')
    outcome = execute(circ, simulator).result()
    return outcome.get_statevector()

def transform_dataset(dataset, sort=False):
    """
        Auxiliary procedure for Park's and the adapted Trugenberger's
        methods. Flattens the dataset (a single feature vector or a
        sequence of feature vectors) and keeps only its non-zero entries,
        pairing each one with the binary string of its flattened index:

        [(b_0, x_0), (b_1, x_1), ..., (b_n, x_n)]

        where x_y is a non-zero (real or complex) feature value and b_y is
        the n_qbits-wide binary string of its position in the flattened
        dataset. Zero entries are skipped, which makes this suitable for
        storing sparse vectors into the state.
    :param dataset: List or numpy ndarray of values to be transformed,
                    either one-dimensional (a single feature vector) or a
                    sequence of feature vectors. Only the length of the
                    first row is used to size the encoding, so rows are
                    expected to have equal length.
    :param sort: If True, the pairs are ordered by the number of '1' bits
                 in their binary string (ascending; ties keep input order)
    :return: Tuple (pairs, n_qbits): the list of (binary_string, value)
             tuples and the number of qubits necessary to index every
             feature in the state
    :raises TypeError: if dataset is neither a list nor a numpy ndarray
    """
    if not isinstance(dataset, (list, np.ndarray)):
        raise TypeError("Expected List or Numpy " +
                        "ndarray types, but got {} instead".format(type(dataset)))

    temp_dataset = dataset

    # A one-dimensional input is a single feature vector: wrap it so the
    # rest of the procedure can always iterate over rows.
    if not isinstance(dataset[0], (list, np.ndarray)):
        if isinstance(dataset, list):
            temp_dataset = [dataset]
        else:
            temp_dataset = np.array([dataset.tolist()])

    dataset_size = len(temp_dataset)
    n_features = len(temp_dataset[0]) * dataset_size

    # Smallest n with 2**n >= n_features, i.e. ceil(log2(n_features)),
    # computed exactly with integers. Using floor under-counts whenever
    # n_features is not a power of two (3 features would need 2 bits).
    n_qbits = (n_features - 1).bit_length() if n_features > 1 else 0

    pairs = []
    feature_counter = 0
    for feature_vector in temp_dataset:
        for feature in feature_vector:
            if feature != 0:
                binary_state = format(feature_counter, 'b').zfill(n_qbits)
                pairs.append((binary_state, feature))
            feature_counter += 1

    if sort:
        # sorted() is stable, so pairs with the same number of '1' bits
        # keep their original relative order.
        pairs = sorted(pairs, key=lambda pair: pair[0].count("1"))

    return pairs, n_qbits


def replace_all_values_with(new_value, dataset):
    """
        Build a new list of tuples in which every value has been
        replaced by new_value. Each input tuple is unpacked as (v, b),
        where v is the value and b is the binary pattern associated to
        it; the output tuples are (new_value, b). The input list is not
        modified.
        NOTE(review): transform_dataset in this module returns
        (binary_string, value) pairs, whose second element is the value,
        not the pattern -- confirm which tuple order callers pass here.
    :param new_value: Value that replaces v in every tuple (v, b)
    :param dataset: List of (v, b) tuples whose values are to be replaced
    :return: new list of (new_value, b) tuples, in the input order
    """
    return [(new_value, pattern) for _, pattern in dataset]


def build_list_of_quibit_objects(quantum_register):
    """
        Build a list of Qubit objects to be used as input to some
        procedure of the qiskit framework. The qubits are listed in
        reverse index order: the last qubit of the register first and
        qubit 0 last.
    :param quantum_register: Quantum register with the qubits (anything
                             exposing `size` and integer indexing)
    :return: Qubits list, highest index first
    """
    last_index = quantum_register.size - 1
    return [quantum_register[idx] for idx in range(last_index, -1, -1)]


def verify_interval_in_state_vector(statevector, start, finish):
    """
        Check whether the state vector has at least one non-zero
        amplitude at an index in the half-open range [start, finish).
        Scanning stops at the first non-zero entry found.
    :param statevector: state vector (indexable) to be inspected
    :param start: first index of the interval (inclusive)
    :param finish: end of the interval (exclusive)
    :return: True if a non-zero entry was found, False otherwise
             (including when the interval is empty)
    """
    return any(statevector[idx] != 0 for idx in range(start, finish))


def verify_trigonometric_interval(value):
    """
        Clamp a value into [-1, 1], the range accepted by the inverse
        trigonometric functions (arcsin / arccos), so that small numerical
        overshoots do not produce invalid inputs.
    :param value: Real value to be evaluated
    :return: value itself if it lies in [-1, 1]; 1 if it is larger;
             -1 if it is smaller
    """
    if value > 1:
        return 1
    if value < -1:
        return -1
    return value


def _count_ones(pattern):
    """Sort key: number of '1' characters in the first element of pattern."""
    bits = pattern[0]
    return sum(1 for bit in bits if bit == "1")


def random_sparse(nbits, density):
    '''
    Creates a random input for sparse quantum state preparation.

    nbits: int number of qubits
    density: float in [0,1], passed to scipy.sparse.random as the
             density of non-zero entries among the 2**nbits amplitudes

    returns
    bin_data: [(binary_string_k, float_k)] k = 0 ... n, where
              binary_string_k is the nbits-wide binary index of a non-zero
              amplitude and the float_k are scaled to unit Euclidean norm.
              The list is ordered by the number of '1' bits in
              binary_string_k.
    '''
    data = sparse.random(2 ** nbits, 1, density, format="dok")

    rows, _ = data.nonzero()
    values = np.array([data[k, 0] for k in rows], dtype=float)

    # Euclidean norm of the non-zero entries. Computed with numpy so that
    # the function does not rely on the scipy.sparse.linalg submodule being
    # reachable as an attribute after only `from scipy import sparse`.
    length = np.linalg.norm(values)

    bin_data = [(format(k, "0" + str(nbits) + "b"), value / length)
                for k, value in zip(rows, values)]

    bin_data.sort(key=_count_ones)
    return bin_data

def _double_sparse_binary(nbits, log_size, p_1, p_0):
    """
    Draw 2**log_size distinct random bit strings of length nbits.

    Each bit is sampled independently with
    ``np.random.choice(2, nbits, p=[p_1, p_0])``, so a bit is '0' with
    probability p_1 and '1' with probability p_0. Despite its name, p_1
    here is the probability of a ZERO bit; double_sparse passes
    ``(1 - p_1, p_1)`` accordingly. Duplicate draws are rejected and the
    strings are returned in the order they were first drawn.

    Raises ValueError when fewer than 2**log_size distinct strings can be
    produced (e.g. log_size > nbits, or one probability is zero), a case
    in which the rejection loop would otherwise never terminate.
    """
    target = 2 ** log_size

    # Only bit values with non-zero probability can ever be drawn, so at
    # most n_symbols ** nbits distinct strings are reachable.
    n_symbols = int(p_1 > 0) + int(p_0 > 0)
    if target > n_symbols ** nbits:
        raise ValueError(
            "Cannot draw {} distinct {}-bit strings with bit probabilities {}"
            .format(target, nbits, [p_1, p_0]))

    bin_data = []
    seen = set()  # O(1) duplicate check; bin_data keeps the draw order
    while len(bin_data) < target:
        bits = np.random.choice(2, nbits, p=[p_1, p_0]).tolist()
        binary = ''.join(map(str, bits))

        if binary not in seen:
            seen.add(binary)
            bin_data.append(binary)
    return bin_data


def double_sparse(nbits, log_size, p_1):
    r"""
    Build a random "double sparse" input for state preparation: few
    non-zero amplitudes, each indexed by a bit string with few ones.

    Parameters
    ----------
    nbits (int): number of qubits (length of every bit string)
    log_size (int): log_2 of the number of non-zero amplitudes
    p_1 (float): probability of each bit of a pattern being equal to one

    Returns
    -------
    List of (p_k, x_k) pairs describing \sum_{k} x_k |p_k>: the x_k are
    uniform random values in [0, 1) scaled to unit Euclidean norm, the
    p_k are distinct bit strings, and the list is ordered by the number
    of '1' bits in p_k.
    """
    n_amplitudes = 2 ** log_size
    amplitudes = np.random.rand(n_amplitudes)
    amplitudes = amplitudes * (1 / np.linalg.norm(amplitudes))

    # The helper takes the probability of a zero bit first, then of a one bit.
    patterns = _double_sparse_binary(nbits, log_size, 1 - p_1, p_1)
    pairs = list(zip(patterns, amplitudes))

    pairs.sort(key=_count_ones)
    return pairs


def _compute_matrix_angles(feature, norm):
    """
        Compute the angles of the matrix U3 necessary for encoding
        the amplitude (and, for complex features, the phase) of a
        feature into the state.

        For complex features (any value that is an instance of Python's
        `complex`) all three angles are computed. For any other value only
        alpha is computed; beta and phi stay 0.
    :param feature: Complex or float, feature to be stored
    :param norm: remaining norm to be used to compute the angles
    :return: the angles alpha, beta and phi of the operator U3
    """
    alpha = 0
    beta = 0
    phi = 0

    if isinstance(feature, complex):
        # |feature ** 2| == |feature| ** 2: despite its name, `phase` holds the
        # squared magnitude of the feature, not its phase.
        phase = np.abs(np.power(feature, 2))

        # Presumably a guard against floating-point drift making norm slightly
        # smaller than |feature|**2. NOTE(review): rounding norm to 4 decimals
        # only helps if it raises norm back to >= phase; otherwise the sqrt
        # below receives a negative argument and yields nan.
        if (norm - phase) < 0:
            norm = np.around(norm, decimals=4)

        # cos(alpha / 2) = sqrt(1 - |feature|**2 / norm), clamped to [-1, 1]
        # before arccos. NOTE(review): norm == 0 divides by zero here.
        cos_value = np.sqrt((norm - phase) / norm)
        cos_value = verify_trigonometric_interval(cos_value)
        alpha = 2 * (np.arccos(cos_value))
        # arccos(-Re(feature) / |feature|) lies in [0, pi]; the branch below
        # extends it to [0, 2*pi) using the sign of the imaginary part.
        # NOTE(review): a zero complex feature makes this 0 / 0 (nan).
        beta = np.arccos(- feature.real / np.sqrt(np.abs(np.power(feature, 2))))

        if feature.imag < 0:
            beta = 2 * np.pi - beta

        phi = - beta

    else:
        # Real feature: sin(alpha / 2) = -feature / sqrt(norm), clamped to
        # [-1, 1] before arcsin.
        sin_value = - feature / np.sqrt(norm)
        sin_value = verify_trigonometric_interval(sin_value)
        alpha = 2 * (np.arcsin(sin_value))

    return alpha, beta, phi