from utils import convertToGeoFormat, coord_to_radians, max_dist
import pandas as pd
from sklearn.metrics.pairwise import haversine_distances
import numpy as np


def get_coord_per_end(dispersion_df: pd.DataFrame, multi_sum) -> list:
    """
    Gather the per-API coordinates of every address listed in dispersion_df.

    Parameters:
    -----------
    dispersion_df: DataFrame
        Frame with an ``end_completo`` column; one output entry is produced
        per value of that column, in the same order.
    multi_sum: DataFrame
        Coordinates indexed by a MultiIndex whose first level holds the
        address (rows are selected with ``multi_sum.xs(address)``).

    Return:
    ------
    coord_per_end: list
        One inner list per address; each inner list holds one tuple per row
        selected for that address, with values in ``multi_sum``'s column
        order.
    """
    return [
        [tuple(row) for row in multi_sum.xs(address).values]
        for address in dispersion_df.end_completo
    ]


def get_mean_per_end(multi_sum) -> tuple:
    """
    Average the latitudes and longitudes returned by all APIs for each address.

    Parameters:
    ----------
    multi_sum: pd.DataFrame
        Coordinates with ``latitude`` and ``longitude`` columns, indexed by a
        MultiIndex that contains an ``end_completo`` level.

    Return:
    ------
    tuple of two lists
        ``(lon_mean, lat_mean)``: the mean longitude and the mean latitude of
        each address, ordered like the unique ``end_completo`` values (order
        of first appearance). NaN values are skipped by the mean.
    """
    lon_mean = []
    lat_mean = []
    for end in multi_sum.index.get_level_values('end_completo').unique():
        # Select this address's rows once (by the same named level we iterate)
        # and reuse the selection for both columns.
        rows = multi_sum.xs(end, level='end_completo')
        lon_mean.append(rows.longitude.mean())
        lat_mean.append(rows.latitude.mean())
    return lon_mean, lat_mean


def get_worst_GeoAPI(multi_sum):
    """
    For each address, find the GeoAPI whose coordinate lies farthest (by
    haversine distance) from the mean coordinate of all APIs for that address.

    This function is used in get_dispersion.

    Parameters:
    ----------
    multi_sum: pd.DataFrame
        Coordinates indexed by a MultiIndex with ``end_completo`` and
        ``GeoAPI`` levels. Columns are expected in (latitude, longitude)
        order, which is the order ``haversine_distances`` requires.

    Return:
    ------
    worst_api : list
        The worst GeoAPI label for each address, ordered like the unique
        ``end_completo`` values.
    """
    worst_api = []
    for end in multi_sum.index.get_level_values('end_completo').unique():
        # Only this address's rows; the remaining index level is GeoAPI.
        per_api = multi_sum.xs(end, level='end_completo')
        # NOTE(review): coord_to_radians is assumed to return radians as
        # haversine_distances expects -- confirm in utils.
        mean_coord = coord_to_radians([tuple(per_api.mean().values)])
        api_coords = coord_to_radians([tuple(row) for row in per_api.values])
        # Shape (n_apis, 1): the flat nanargmax index is a row position
        # inside ``per_api``, so the label must come from per_api's own
        # index, not from the global list of APIs (which may differ in
        # membership/order from this address's rows).
        distances = haversine_distances(api_coords, mean_coord)
        worst_row = np.nanargmax(distances)
        worst_api.append(per_api.index.get_level_values('GeoAPI')[worst_row])
    return worst_api


def get_dispersion(geocoded_data, metrics: list) -> pd.DataFrame:
    """
    Build a per-address summary of how the coordinates returned by the
    different geocoding APIs disagree.

    Parameters:
    ----------
    geocoded_data: DataFrame or dict
        Geocoded results. A dict is first turned into a DataFrame; a
        DataFrame is deep-copied. Both are then passed through
        ``convertToGeoFormat``. Any other object is only deep-copied
        (``.copy(deep=True)``) and must already be in geo format. The
        resulting frame must provide the columns:
            geometry - point geometry of each geocoded result
            GeoAPI - name of the API that produced the row
            end_completo - full address
        NOTE(review): ``geometry`` is presumably built by
        ``convertToGeoFormat`` from the input's latitude/longitude -- confirm.

    metrics: list of str (a single str is also accepted)
        Dispersion metrics to compute. Supported values:
            "DistanceFromMean" - adds a ``Raio`` column holding
            ``max_dist(coord_to_radians(coords))`` computed over each
            address's per-API coordinates.

    Return:
    ------
    DataFrame:
        One row per address (order of first appearance) with columns
        ``Raio`` (only if requested), ``worst_api``, ``latitude``,
        ``longitude`` and ``end_completo``. ``latitude``/``longitude`` are
        the per-address means across APIs.

    Raises:
    ------
    ValueError
        If ``metrics`` contains an unsupported metric name.
    """
    supported_metrics = ("DistanceFromMean",)
    # A bare string would otherwise be iterated character by character;
    # list() also lets generators be inspected more than once.
    metrics = [metrics] if isinstance(metrics, str) else list(metrics)
    unsupported = [m for m in metrics if m not in supported_metrics]
    if unsupported:
        # Fail fast with a clear message instead of an opaque length
        # mismatch from DataFrame.insert later on.
        raise ValueError(
            f"Unsupported dispersion metric(s): {unsupported!r}; "
            f"expected any of {supported_metrics!r}")

    if isinstance(geocoded_data, dict):
        # dict.copy() accepts no ``deep`` argument; the DataFrame
        # constructor copies dict data by default, so no alias is kept.
        geo_df = convertToGeoFormat(pd.DataFrame(geocoded_data))
    elif isinstance(geocoded_data, pd.DataFrame):
        geo_df = convertToGeoFormat(geocoded_data.copy(deep=True))
    else:
        geo_df = geocoded_data.copy(deep=True)

    # Missing or empty geometries become None (NaN) instead of raising.
    geo_df['longitude'] = geo_df.geometry.apply(
        lambda p: p.x if p is not None and not p.is_empty else None)
    geo_df['latitude'] = geo_df.geometry.apply(
        lambda p: p.y if p is not None and not p.is_empty else None)

    # (address, API) -> coordinates, in (latitude, longitude) column order.
    multi_sum = geo_df.set_index(['end_completo', 'GeoAPI'])[
        ['latitude', 'longitude']]
    dispersion_df = pd.DataFrame({
        'end_completo':
            multi_sum.index.get_level_values('end_completo').unique(),
    })

    # Each insert at position 0 pushes earlier columns right, giving the
    # final order [Raio?, worst_api, latitude, longitude, end_completo].
    lon, lat = get_mean_per_end(multi_sum)
    dispersion_df.insert(0, "longitude", lon, allow_duplicates=True)
    dispersion_df.insert(0, "latitude", lat, allow_duplicates=True)
    dispersion_df.insert(0, "worst_api", get_worst_GeoAPI(multi_sum),
                         allow_duplicates=True)

    if "DistanceFromMean" in metrics:
        # Computed once even if the metric name is listed several times.
        radius = [max_dist(coord_to_radians(coord))
                  for coord in get_coord_per_end(dispersion_df, multi_sum)]
        dispersion_df.insert(0, 'Raio', radius, allow_duplicates=True)

    return dispersion_df
