'''
Created on Dec 12, 2017

@author: fan

This module communicates with dataandgrid and pulls from there the grids,
shocks, and choices a model instance needs. It is the gateway to the grids.

These arrays are not called "data", because the data are generated
dynamically by the model.

'''

import logging
import numpy as np

import dataandgrid.genchoices as genchoices
import dataandgrid.genshocks as genshocks
import dataandgrid.genstates as genstates

import modelhh.functions.constraints as constraints

import pyfan.amto.array.mesh as supmesh

logger = logging.getLogger(__name__)


class GenModelData:
    """
    Generate Model Instances from Parameter Specifications.

    Collects grid, shock, and estimation parameters from ``param_inst`` and
    uses the ``dataandgrid`` helpers to build the state, shock, and choice
    arrays a model instance needs.

    Parameters
    ----------
    param_inst :
        Parameter container exposing the dicts ``model_option``,
        ``esti_param``, ``data_param`` and ``grid_param`` with the keys read
        in ``__init__``.
    utoday_inst :
        Object providing ``get_cash(A=..., eps_tt=..., k_tt=..., b_tt=...)``,
        which returns a 2-tuple whose first element is cash.
    bdgt_inst :
        Budget instance, forwarded to ``genstates.state_grids``.
    """

    def __init__(self, param_inst, utoday_inst, bdgt_inst):
        self.param_inst = param_inst

        # model option
        self.choice_set_list = param_inst.model_option['choice_set_list']

        # model instances
        self.utoday_inst = utoday_inst
        self.bdgt_inst = bdgt_inst

        esti_param = param_inst.esti_param
        grid_param = param_inst.grid_param

        self.kappa = esti_param['kappa']
        self.K_DEPRECIATION = esti_param['K_DEPRECIATION']

        # A is treated as a fixed parameter here, taken from data_param
        self.A = param_inst.data_param['A']

        # grid lengths
        self.len_k_start = grid_param['len_k_start']
        self.len_choices = grid_param['len_choices']
        self.len_states = grid_param['len_states']
        self.len_shocks = grid_param['len_shocks']

        # bounds on capital and net borrowing/saving
        self.max_kapital = grid_param['max_kapital']
        self.min_kapital = grid_param['min_kapital']
        self.max_netborrsave = grid_param['max_netborrsave']
        self.min_netborrsave = grid_param['min_netborrsave']

        # shocks used for the state space (see gen_states_data)
        self.max_eps = grid_param['max_eps']
        self.min_eps = grid_param['min_eps']
        self.mean_eps = grid_param['mean_eps']
        self.std_eps = grid_param['std_eps']
        self.len_eps = grid_param['len_eps']
        self.drawtype_eps = grid_param['drawtype_eps']
        self.seed_eps = grid_param['seed_eps']

        # *_eps_E shocks, used by gen_minterp_data
        self.max_eps_E = grid_param['max_eps_E']
        self.min_eps_E = grid_param['min_eps_E']
        self.mean_eps_E = grid_param['mean_eps_E']
        self.std_eps_E = grid_param['std_eps_E']
        self.len_eps_E = grid_param['len_eps_E']

        # interest rates (estimated parameters)
        self.R_INFORM_SAVE = esti_param['R_INFORM_SAVE']
        self.R_INFORM_BORR = esti_param['R_INFORM_BORR']
        # NOTE(review): R_INFORM is the informal *saving* rate; the informal
        # borrowing rate is stored but not used anywhere in this class.
        self.R_INFORM = self.R_INFORM_SAVE
        self.R_FORMAL_SAVE = esti_param['R_FORMAL_SAVE']
        self.R_FORMAL_BORR = esti_param['R_FORMAL_BORR']

        # BN*_P parameters and their start values, forwarded to choices_kb_each
        self.BNF_SAVE_P = esti_param['BNF_SAVE_P']
        self.BNF_BORR_P = esti_param['BNF_BORR_P']
        self.BNI_LEND_P = esti_param['BNI_LEND_P']
        self.BNI_BORR_P = esti_param['BNI_BORR_P']

        self.BNF_SAVE_P_startVal = grid_param['BNF_SAVE_P_startVal']
        self.BNF_BORR_P_startVal = grid_param['BNF_BORR_P_startVal']
        self.BNI_LEND_P_startVal = grid_param['BNI_LEND_P_startVal']
        self.BNI_BORR_P_startVal = grid_param['BNI_BORR_P_startVal']

    def gen_states_data(self, k_tt_v, b_tt_v, fb_f_max_btp_v, eps_tt_v=None):
        """
        Generate input arrays for mjall.

        Mesh the state vectors with the shock vector. Each output is returned
        as a 2d, 1-column array.

        Parameters
        ----------
        k_tt_v, b_tt_v, fb_f_max_btp_v : array_like
            State vectors, not yet meshed with shocks.
        eps_tt_v : array_like, optional
            Shock vector. When None (the default), shocks are drawn with
            ``genshocks.stateSpaceShocks`` using ``mean_eps``, ``std_eps``,
            ``len_shocks``, ``seed_eps`` and ``drawtype_eps``.

        Returns
        -------
        k_tt, b_tt, fb_f_max_btp, eps_tt : ndarray
            Meshed states and shocks, each of shape (-1, 1).
        """
        # draw_type=1 for shock, these shocks are at quantiles, solved at max points
        # so that interpolation would not be extrapolation
        logger.info('A. States: Mesh State and Shock, and to 2d 1 col')

        # Draw shocks only when the caller did not provide them; a supplied
        # eps_tt_v used to be silently overwritten here.
        if eps_tt_v is None:
            eps_tt_v = genshocks.stateSpaceShocks(
                self.mean_eps, self.std_eps, self.len_shocks,
                seed=self.seed_eps, draw_type=self.drawtype_eps)

        return self.mesh_states_shocks(
            eps_tt_v, k_tt_v, b_tt_v, fb_f_max_btp_v)

    def gen_mjall_data(self, data=None, data_map=None):
        """
        Build the state and choice arrays used by mjall.

        Parameters
        ----------
        data : ndarray, optional
            2d array of observed states. When given, the columns
            ``data_map['k']``, ``data_map['b']`` and ``data_map['eps']`` are
            used directly as states (not meshed with shocks). When None,
            state grids are generated and meshed with drawn shocks.
        data_map : dict, optional
            Column-index map for ``data``; required when ``data`` is given.

        Returns
        -------
        tuple
            ``(k_tt_v, b_tt_v, fb_f_max_btp_v, eps_tt, k_tt, b_tt,
            fb_f_max_btp)`` followed by the 18 choice arrays returned by
            ``genchoices.choices_kb_each`` (all of its outputs except the
            first), in the same order.

        Raises
        ------
        ValueError
            If ``data`` is given without ``data_map``.
        """
        # Validate before doing any work so the failure is explicit.
        if data is not None and data_map is None:
            raise ValueError('data_map is required when data is given')

        logger.info('A. Return Observed States')
        logger.info('observed state vectors NOT meshed with shocks:k_tt_v, b_tt_v, fb_f_max_btp_v')

        if data is not None:
            k_tt_v = data[:, data_map['k']]
            b_tt_v = data[:, data_map['b']]
            # Only k_tt_v is passed, so gen_kb_states returns just the bound.
            fb_f_max_btp_v = self.gen_kb_states(k_tt_v=k_tt_v)
            eps_tt_v = data[:, data_map['eps']]

            # Observed rows are used as-is, only reshaped to 2d 1-col.
            k_tt = np.reshape(k_tt_v, (-1, 1))
            b_tt = np.reshape(b_tt_v, (-1, 1))
            fb_f_max_btp = np.reshape(fb_f_max_btp_v, (-1, 1))
            eps_tt = np.reshape(eps_tt_v, (-1, 1))
        else:
            k_tt_v, b_tt_v, fb_f_max_btp_v = self.gen_kb_states()
            k_tt, b_tt, fb_f_max_btp, eps_tt = self.gen_states_data(
                k_tt_v, b_tt_v, fb_f_max_btp_v)

        cash, _ = self.utoday_inst.get_cash(A=self.A, eps_tt=eps_tt,
                                            k_tt=k_tt, b_tt=b_tt)

        logger.info('B. Choices: choices_kb_each returns 2d M by N Mat')

        _, \
        ib_i_ktp, is_i_ktp, fb_f_ktp, fs_f_ktp, \
        ibfb_i_ktp, fbis_i_ktp, \
        none_ktp, \
        ibfb_f_imin_ktp, fbis_f_imin_ktp, \
        ib_i_btp, is_i_btp, fb_f_btp, fs_f_btp, \
        ibfb_i_btp, fbis_i_btp, \
        none_btp, \
        ibfb_f_imin_btp, fbis_f_imin_btp = \
            genchoices.choices_kb_each(len_choices=self.len_choices,
                                       cont_choice_count=2,
                                       cash=cash,
                                       k_tt=k_tt,
                                       fb_f_max_btp=fb_f_max_btp,
                                       R_INFORM=self.R_INFORM, R_FORMAL_BORR=self.R_FORMAL_BORR,
                                       R_FORMAL_SAVE=self.R_FORMAL_SAVE,
                                       DELTA_DEPRE=self.K_DEPRECIATION, borr_constraint_KAPPA=self.kappa,
                                       BNF_SAVE_P=self.BNF_SAVE_P, BNF_SAVE_P_startVal=self.BNF_SAVE_P_startVal,
                                       BNF_BORR_P=self.BNF_BORR_P, BNF_BORR_P_startVal=self.BNF_BORR_P_startVal,
                                       BNI_LEND_P=self.BNI_LEND_P, BNI_LEND_P_startVal=self.BNI_LEND_P_startVal,
                                       BNI_BORR_P=self.BNI_BORR_P, BNI_BORR_P_startVal=self.BNI_BORR_P_startVal,
                                       choice_set_list=self.choice_set_list,
                                       K_interp_range={'K_max': self.max_kapital},
                                       B_interp_range={'B_max': self.max_netborrsave})

        return k_tt_v, b_tt_v, fb_f_max_btp_v, \
               eps_tt, k_tt, b_tt, \
               fb_f_max_btp, \
               ib_i_ktp, is_i_ktp, fb_f_ktp, fs_f_ktp, \
               ibfb_i_ktp, fbis_i_ktp, \
               none_ktp, \
               ibfb_f_imin_ktp, fbis_f_imin_ktp, \
               ib_i_btp, is_i_btp, fb_f_btp, fs_f_btp, \
               ibfb_i_btp, fbis_i_btp, \
               none_btp, \
               ibfb_f_imin_btp, fbis_f_imin_btp

    def gen_kb_states(self, k_tt_v=None, b_tt_v=None):
        """
        Get state vectors and the formal borrowing bound.

        k_tt_v, b_tt_v could be data or simulation last round.

        When both ``k_tt_v`` and ``b_tt_v`` are None, state grids are
        generated with ``genstates.state_grids`` (per the original note,
        these are already meshed). That call also returns an updated
        ``param_inst``, which is stored on ``self``. The tuple
        ``(k_tt_v, b_tt_v, fb_f_max_btp_v)`` is returned.

        Otherwise only ``fb_f_max_btp_v`` is returned. It is the borrowing
        bound computed from ``k_tt_v`` via
        ``constraints.get_borrow_constraint``.

        Raises
        ------
        ValueError
            If ``b_tt_v`` is given without ``k_tt_v``; the bound needs capital.
        """
        generate_grids = k_tt_v is None and b_tt_v is None

        if generate_grids:
            k_tt_v, b_tt_v, self.param_inst = genstates.state_grids(
                self.param_inst, self.bdgt_inst,
                self.max_kapital, self.min_kapital, self.max_netborrsave,
                self.len_states, self.len_k_start,
                fixed_unif_grid=True, seed=1230)
        elif k_tt_v is None:
            raise ValueError('k_tt_v is required when b_tt_v is given')

        # Get borrowing bound vector
        fb_f_max_btp_v = constraints.get_borrow_constraint(
            self.kappa, k_tt_v, self.R_FORMAL_BORR)

        if generate_grids:
            return k_tt_v, b_tt_v, fb_f_max_btp_v
        return fb_f_max_btp_v

    def mesh_states_shocks(self, eps_tt_v, k_tt_v, b_tt_v, fb_f_max_btp_v):
        """
        Mesh the state vectors with the shock vector.

        The state columns ``[k_tt_v, b_tt_v, fb_f_max_btp_v]`` and the shock
        column ``[eps_tt_v]`` are combined with ``supmesh.two_mat_mesh``.

        Returns
        -------
        k_tt, b_tt, fb_f_max_btp, eps_tt : ndarray
            The meshed columns, each reshaped to (-1, 1).
        """
        logger.info('Mesh States and Shocks')
        logger.info('self.k_tt_v, self.b_tt_v obtained with constructor init gen_kb_states() invoke')
        mat_one = np.column_stack([k_tt_v, b_tt_v, fb_f_max_btp_v])
        mat_two = np.column_stack([eps_tt_v])

        logger.debug('mat_one:[k_tt_v, b_tt_v, fb_f_max_btp_v]')
        logger.debug('mat_two:[eps_tt_v]')
        mat_one_mesh, mat_two_mesh = supmesh.two_mat_mesh(
            mat_one, mat_two, return_joint=False, return_single_col=False)

        k_tt = np.reshape(mat_one_mesh[:, 0], (-1, 1))
        b_tt = np.reshape(mat_one_mesh[:, 1], (-1, 1))
        fb_f_max_btp = np.reshape(mat_one_mesh[:, 2], (-1, 1))
        eps_tt = np.reshape(mat_two_mesh[:, 0], (-1, 1))

        # Build the joined state matrix once, and only if debug output will
        # actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            states = np.column_stack((k_tt, b_tt, fb_f_max_btp, eps_tt))
            logger.debug('states join: k_tt, b_tt, fb_f_max_btp, eps_tt')
            logger.debug(['shape(k_tt, b_tt, fb_f_max_btp, eps_tt)',
                          np.shape(states)])
            logger.debug('states:\n%s', states)

        return k_tt, b_tt, fb_f_max_btp, eps_tt

    def gen_minterp_data(self, k_tt_v, b_tt_v, fb_f_max_btp_v):
        """
        Mesh the given states with the ``*_eps_E`` shocks.

        Draws ``len_eps_E`` shocks with mean ``mean_eps_E`` and standard
        deviation ``std_eps_E``, using ``draw_type=0`` and no explicit seed.
        The shocks are then meshed with the state vectors via
        ``mesh_states_shocks``.

        Returns
        -------
        B_Veps, K_Veps, eps_Veps : ndarray
            Note the order: B comes before K, unlike ``mesh_states_shocks``.
        """
        epsvec_E = genshocks.stateSpaceShocks(
            self.mean_eps_E, self.std_eps_E, self.len_eps_E,
            draw_type=0)
        K_Veps, B_Veps, _, eps_Veps = self.mesh_states_shocks(
            epsvec_E,
            k_tt_v, b_tt_v, fb_f_max_btp_v)

        return B_Veps, K_Veps, eps_Veps
