# -*- encoding:utf-8 -*-
"""
    金融时间序列分析模块，模块中方法的参数都为格式化好的kl，格式如下：

    eg:
                close	high	low	p_change	open	pre_close	volume	date	date_week	key	atr21	atr14
    2016-07-20	228.36	229.800	225.00	1.38	226.47	225.26	2568498	20160720	2	499	9.1923	8.7234
    2016-07-21	220.50	227.847	219.10	-3.44	226.00	228.36	4428651	20160721	3	500	9.1711	8.7251
    2016-07-22	222.27	224.500	218.88	0.80	221.99	220.50	2579692	20160722	4	501	9.1858	8.7790
    2016-07-25	230.01	231.390	221.37	3.48	222.27	222.27	4490683	20160725	0	502	9.2669	8.9298
    2016-07-26	225.93	228.740	225.63	-1.77	227.34	230.01	41833	20160726	1	503	9.1337	8.7541
"""
try:
    from collections.abc import Iterable
except ImportError:  # Python 2 fallback; removed from collections in 3.10
    from collections import Iterable

import numpy as np
import pandas as pd

from ultron.utilities.logger import kd_logger


def _df_dispatch(df, dispatch_func):
    """
    根据df的类型分发callable的执行方法，

    :param df: 格式化好的kl，或者字典，或者可迭代序列
    :param dispatch_func: 分发的可执行的方法
    """
    if isinstance(df, pd.DataFrame):
        # 参数只是pd.DataFrame
        return dispatch_func(df)
    elif isinstance(df, dict) and all(
        [isinstance(_df, pd.DataFrame) for _df in df.values()]):
        # 参数只是字典形式
        return [dispatch_func(df[df_key], df_key) for df_key in df]
    elif isinstance(df, Iterable) and all(
        [isinstance(_df, pd.DataFrame) for _df in df]):
        # 参数只是可迭代序列
        return [dispatch_func(_df) for _df in df]
    else:
        kd_logger.error('df type is error! {}'.format(type(df)))


def _df_dispatch_concat(df, dispatch_func):
    """
    根据df的类型分发callable的执行方法，如果是字典或者可迭代类型的返回值使用
    pd.concat连接起来

    :param df: 格式化好的kl，或者字典，或者可迭代序列
    :param dispatch_func: 分发的可执行的方法
    """

    if isinstance(df, pd.DataFrame):
        # 参数只是pd.DataFrame
        return dispatch_func(df)
    elif isinstance(df, dict) and all(
        [isinstance(_df, pd.DataFrame) for _df in df.values()]):
        # 参数只是字典形式
        return pd.concat([dispatch_func(df[df_key], df_key) for df_key in df],
                         axis=1)
    elif isinstance(df, Iterable) and all(
        [isinstance(_df, pd.DataFrame) for _df in df]):
        # 参数只是可迭代序列
        return pd.concat([dispatch_func(_df) for _df in df], axis=1)
    else:
        kd_logger.error('df type is error! {}'.format(type(df)))


def wave_change_rate(df):
    """
    Ratio of the mean daily amplitude to the mean absolute daily change.

    The daily amplitude is ``(high - low) / pre_close * 100``; multiplying by
    100 puts it in the same percentage unit as ``p_change``.

    eg:
        tsla = SymbolPd.make_kl_df('usTSLA')
        KLUtil.wave_change_rate(tsla)

        out:
        {'code': '', 'wave_rate': 1.794156}

    NOTE: earlier versions printed whether wave_rate exceeded 1.80, the
    threshold they used for "suitable for statistical arbitrage".

    :param df: formatted kl, or a dict / iterable of formatted kl
    :return: a dict with keys 'code' (the dict key, or '' otherwise) and
             'wave_rate' for a single kl; a list of such dicts for a dict or
             iterable input
    """

    def _ratio(kl, df_name=''):
        amplitude = (kl.high - kl.low) / kl.pre_close * 100
        abs_change = np.abs(kl['p_change'])
        rate = amplitude.mean() / abs_change.mean()
        return {'code': df_name, 'wave_rate': rate}

    return _df_dispatch(df, _ratio)


def p_change_stats(df):
    """
    Summary statistics of up days versus down days, based on ``p_change``.

    eg:
        tsla = SymbolPd.make_kl_df('usTSLA')
        KLUtil.p_change_stats(tsla)

        out:
        {'code': '', 'up_mean': 1.861, 'up_count': 260,
         'down_mean': -1.906, 'down_count': 244,
         'chg_mean': 0.977, 'chg_ratio': 1.066}

    Days with ``p_change == 0`` are counted in neither group.

    :param df: formatted kl, or a dict / iterable of formatted kl
    :return: a dict with keys 'code', 'up_mean', 'up_count', 'down_mean',
             'down_count', 'chg_mean' (|up_mean / down_mean|) and 'chg_ratio'
             (up_count / down_count) for a single kl; a list of such dicts for
             a dict or iterable input
    """

    def _stats(kl, df_name=''):
        change = kl['p_change']
        ups = change[change > 0]
        downs = change[change < 0]

        up_mean, down_mean = ups.mean(), downs.mean()
        up_count, down_count = ups.count(), downs.count()

        return {
            'code': df_name,
            'up_mean': up_mean,
            'up_count': up_count,
            'down_mean': down_mean,
            'down_count': down_count,
            'chg_mean': abs(up_mean / down_mean),
            'chg_ratio': up_count / down_count,
        }

    return _df_dispatch(df, _stats)


def date_week_wave(df):
    """
    Mean daily amplitude grouped by the ``date_week`` column.

    The daily amplitude is ``(high - low) / pre_close * 100``; multiplying by
    100 puts it in the same percentage unit as ``p_change``.

    eg:
        tsla = SymbolPd.make_kl_df('usTSLA')
        KLUtil.date_week_wave(tsla)

        out:
            date_week
            0    3.8144
            1    3.3326
            2    3.3932
            3    3.3801
            4    2.9923
            Name: wave, dtype: float64

    :param df: formatted kl, or a dict / iterable of formatted kl
    :return: pd.Series named '<name>wave' indexed by date_week for a single
             kl; a pd.DataFrame with one such column per kl for a dict or
             iterable input
    """

    def _mean_wave_by_weekday(kl, df_name=''):
        wave_col = '{}wave'.format(df_name)
        # assign() works on a copy, so the caller's kl is left untouched
        with_wave = kl.assign(
            **{wave_col: (kl.high - kl.low) / kl.pre_close * 100})
        return with_wave.groupby('date_week')[wave_col].mean()

    return _df_dispatch_concat(df, _mean_wave_by_weekday)


def date_week_win(df):
    """
    Up-day counts and up-day ratio grouped by the ``date_week`` column.

    A day counts as a "win" when ``p_change > 0``; zero-change days count as
    non-wins.

    eg:
        tsla = SymbolPd.make_kl_df('usTSLA')
        KLUtil.date_week_win(tsla)

        out:
                    0   1   win
            date_week
            0       44  51  0.5368
            1       55  48  0.4660
            2       48  57  0.5429
            3       44  57  0.5644
            4       53  47  0.470

    Columns 0 / 1 are the non-win / win day counts (from pd.crosstab); the
    '<name>win' column is the mean of the 0/1 win flag (pivot_table's default
    aggregation), i.e. the fraction of up days.

    :param df: formatted kl, or a dict / iterable of formatted kl
    :return: pd.DataFrame indexed by date_week
    """

    def _win_by_weekday(kl, df_name=''):
        win_col = '{}win'.format(df_name)
        # work on a copy so the caller's kl is not modified
        frame = kl.copy()
        frame[win_col] = (frame['p_change'] > 0).astype(int)

        counts = pd.crosstab(frame.date_week, frame[win_col])
        win_ratio = frame.pivot_table([win_col], index='date_week')
        return pd.concat([counts, win_ratio], axis=1)

    return _df_dispatch_concat(df, _win_by_weekday)


def date_week_mean(df):
    """
    Mean ``p_change`` grouped by the ``date_week`` column.

    eg:
        tsla = SymbolPd.make_kl_df('usTSLA')
        KLUtil.date_week_mean(tsla)

        out:
                   _p_change
        date_week
        0             0.0626
        1             0.0475
        2             0.0881
        3             0.2691
        4            -0.2838

    :param df: formatted kl, or a dict / iterable of formatted kl
    :return: pd.DataFrame indexed by date_week with one '<name>_p_change'
             column per kl (the column is '_p_change' for a lone DataFrame)
    """

    def _date_week_mean(kl, df_name=''):
        weekday_means = kl.groupby('date_week')['p_change'].mean()
        return weekday_means.to_frame('{}_p_change'.format(df_name))

    return _df_dispatch_concat(df, _date_week_mean)


def bcut_change_vc(df, bins=None):
    """
    Count how many days' ``p_change`` falls into each bin and each bin's share.

    eg:
        tsla = SymbolPd.make_kl_df('usTSLA')
        KLUtil.bcut_change_vc(tsla)

        out:
                p_change	rate
        (0, 3]	209	0.4147
        (-3, 0]	193	0.3829
        (3, 7]	47	0.0933
        (-7, -3]	44	0.0873
        (-10, -7]	6	0.0119
        (7, 10]	3	0.0060
        (10, inf]	1	0.0020
        (-inf, -10]	1	0.0020

    :param df: formatted kl, or a dict / iterable of formatted kl
    :param bins: bin edges passed to ``pd.cut``; defaults to
                 [-np.inf, -10, -7, -3, 0, 3, 7, 10, np.inf]
    :return: pd.DataFrame indexed by bin, ordered by count (value_counts
             order). For a lone DataFrame the columns are 'p_change' (count)
             and 'rate' (share of total). For dict input they are '<name>' and
             '<name>rate' per kl.
    """

    def _bcut_change_vc(p_df, df_name=''):
        # Materialise the counts under an explicit 'p_change' column: since
        # pandas 2.0, value_counts() names its result 'count', so relying on
        # the Series name (the old behaviour) raises AttributeError.
        dww = pd.cut(p_df.p_change, bins=bins).value_counts().to_frame(
            name='p_change')
        counts = dww['p_change'].values
        # Share of each bin in the total number of counted days
        dww['{}rate'.format(df_name)] = counts / counts.sum()
        if df_name:
            dww.rename(columns={'p_change': '{}'.format(df_name)},
                       inplace=True)
        return dww

    if bins is None:
        bins = [-np.inf, -10, -7, -3, 0, 3, 7, 10, np.inf]
    return _df_dispatch_concat(df, _bcut_change_vc)


def qcut_change_vc(df, q=10):
    """
    Split ``p_change`` into quantile bins with ``pd.qcut`` and list the bins
    in ascending order.

    eg:
        tsla = SymbolPd.make_kl_df('usTSLA')
        KLUtil.qcut_change_vc(tsla)

        out:
            change
        0	[-10.45, -3.002]
        1	(-3.002, -1.666]
        2	(-1.666, -0.93]
        3	(-0.93, -0.396]
        4	(-0.396, 0.065]
        5	(0.065, 0.48]
        6	(0.48, 1.102]
        7	(1.102, 1.922]
        8	(1.922, 3.007]
        9	(3.007, 11.17]

    :param df: formatted kl, or a dict / iterable of formatted kl
    :param q: passed straight through to ``pd.qcut``: either the number of
              quantiles (default 10, i.e. deciles) or an array of quantiles
              such as [0, .25, .5, .75, 1.]
    :return: pd.DataFrame with one '<name>change' column per kl holding the
             bins, indexed 0..n_bins-1
    """

    def _qcut_change_vc(p_df, df_name=''):
        # Take the bin labels from the value_counts() index (ordered by
        # frequency) and wrap them in a Series so they can be sorted.
        bins = pd.qcut(p_df.p_change, q).value_counts().index.values
        # Order the bins from the most negative to the most positive change
        bins = pd.Series(bins).sort_values()
        dww = pd.DataFrame(bins)
        # Renumber from 0 after sorting. Use the actual number of bins rather
        # than q: q may be an array of quantiles, for which np.arange(0, q)
        # raises.
        dww.index = np.arange(len(dww))
        dww.columns = ['{}change'.format(df_name)]
        return dww

    return _df_dispatch_concat(df, _qcut_change_vc)