# -*- coding: utf-8 -*-
"""Utility functions for all event stream to SPADL converters.

Each converter should implement 'convert_to_actions', which converts its
provider-specific event data to the SPADL format.

"""
import pandas as pd  # type: ignore
from pandera.typing import DataFrame

from . import config as spadlconfig


def _fix_clearances(actions: DataFrame) -> DataFrame:
    """Set the end location of every clearance to the start of the next action.

    The ``end_x``/``end_y`` of each row whose ``type_id`` is the clearance
    action type are overwritten with the ``start_x``/``start_y`` of the row
    that follows it. The final row has no successor and is paired with
    itself, so a clearance in the last row ends where it started.

    The input DataFrame is modified in place and is also returned.
    """
    is_clearance = actions.type_id == spadlconfig.actiontypes.index('clearance')

    # Row-wise successor of every action; the last row borrows its own values.
    successor = actions.shift(-1)
    successor[-1:] = actions[-1:]

    for end_col, start_col in (('end_x', 'start_x'), ('end_y', 'start_y')):
        actions.loc[is_clearance, end_col] = successor.loc[is_clearance, start_col].values

    return actions


def _fix_direction_of_play(actions: DataFrame, home_team_id: int) -> DataFrame:
    """Mirror the coordinates of every action not performed by the home team.

    For each row whose ``team_id`` differs from ``home_team_id``, every x
    coordinate becomes ``field_length - x`` and every y coordinate becomes
    ``field_width - y``. Presumably this is done so that both teams are
    expressed as playing in the same direction.

    The input DataFrame is modified in place and is also returned.
    """
    is_away = (actions.team_id != home_team_id).values

    # Each axis: the columns on that axis and the pitch extent to mirror them against.
    mirror_axes = (
        (('start_x', 'end_x'), spadlconfig.field_length),
        (('start_y', 'end_y'), spadlconfig.field_width),
    )
    for columns, extent in mirror_axes:
        for col in columns:
            actions.loc[is_away, col] = extent - actions.loc[is_away, col].values

    return actions


# Thresholds used by _add_dribbles to decide whether the gap between two
# consecutive actions of the same team is filled with a synthetic dribble.

# Minimum distance between the end of one action and the start of the next,
# in the units of the SPADL x/y coordinates (presumably meters; confirm in config).
min_dribble_length: float = 3.0
# Maximum such distance; larger gaps do not produce a dribble.
max_dribble_length: float = 60.0
# Gaps must be strictly shorter than this, measured on the time_seconds column.
max_dribble_duration: float = 10.0


def _add_dribbles(actions: DataFrame) -> DataFrame:
    """Insert a synthetic dribble between consecutive actions where the ball moved.

    A dribble is inserted between row ``i`` and row ``i + 1`` when all of the
    following hold:

    - both rows belong to the same game and the same period;
    - both rows are performed by the same team;
    - the distance from the end location of row ``i`` to the start location
      of row ``i + 1`` lies within ``[min_dribble_length, max_dribble_length]``;
    - ``time_seconds`` of row ``i + 1`` minus that of row ``i`` is less than
      ``max_dribble_duration``.

    The new dribble starts at the end of row ``i`` and ends at the start of
    row ``i + 1``. It takes the game, period, team, player and (if present)
    timestamp of row ``i + 1``. It is recorded as a successful foot action,
    timed halfway between both rows.

    Parameters
    ----------
    actions : DataFrame
        SPADL actions. Pairing is done on row order, so the frame is assumed
        to be sorted chronologically within each game (TODO confirm at callers).

    Returns
    -------
    DataFrame
        A new DataFrame with the original actions plus the inserted dribbles,
        sorted by ``game_id``, ``period_id`` and original action order. The
        ``action_id`` column is renumbered ``0 .. len - 1``. The input frame is
        not modified.
    """
    # Pair every row with its successor. The last row is compared against an
    # all-zero placeholder row, so (presumably, since real ids are non-zero)
    # it is never paired.
    next_actions = actions.shift(-1, fill_value=0)

    # A dribble must never bridge two different games. Without this check, the
    # last action of one game and the first of the next could be joined when
    # team and period ids happen to match, because their time difference is
    # negative and so passes the duration test.
    same_game = actions.game_id == next_actions.game_id
    same_period = actions.period_id == next_actions.period_id
    same_team = actions.team_id == next_actions.team_id

    dx = actions.end_x - next_actions.start_x
    dy = actions.end_y - next_actions.start_y
    sq_dist = dx ** 2 + dy ** 2
    # Compare squared distances to avoid a square root.
    far_enough = sq_dist >= min_dribble_length ** 2
    not_too_far = sq_dist <= max_dribble_length ** 2

    dt = next_actions.time_seconds - actions.time_seconds
    same_phase = dt < max_dribble_duration

    dribble_idx = (
        same_game & same_period & same_team & far_enough & not_too_far & same_phase
    )

    dribbles = pd.DataFrame()
    prev = actions[dribble_idx]
    nex = next_actions[dribble_idx]
    dribbles['game_id'] = nex.game_id
    dribbles['period_id'] = nex.period_id
    # Fractional id places the dribble directly after its preceding action
    # when sorting below; ids are renumbered afterwards.
    dribbles['action_id'] = prev.action_id + 0.1
    dribbles['time_seconds'] = (prev.time_seconds + nex.time_seconds) / 2
    if 'timestamp' in actions.columns:
        dribbles['timestamp'] = nex.timestamp
    dribbles['team_id'] = nex.team_id
    dribbles['player_id'] = nex.player_id
    dribbles['start_x'] = prev.end_x
    dribbles['start_y'] = prev.end_y
    dribbles['end_x'] = nex.start_x
    dribbles['end_y'] = nex.start_y
    dribbles['bodypart_id'] = spadlconfig.bodyparts.index('foot')
    dribbles['type_id'] = spadlconfig.actiontypes.index('dribble')
    dribbles['result_id'] = spadlconfig.results.index('success')

    actions = pd.concat([actions, dribbles], ignore_index=True, sort=False)
    actions = actions.sort_values(['game_id', 'period_id', 'action_id']).reset_index(drop=True)
    actions['action_id'] = range(len(actions))
    return actions
