# -*- coding: utf-8 -*-
"""
Created on Mon Jan 25 14:29:51 2021

@author: j.reul

The module incorporates the class "Core", which inherits functionality from the 
classes "Estimation", "Simulation" and "PostAnalysis".
It defines the core attributes, structure and type of the
discrete choice models to be simulated or estimated. 
Two different types of discrete choice models are differentiated:
    - Multinomial logit models (MNL)
    - Mixed logit models (MXL) with discrete points as parameters

"""

import numpy as np
import pandas as pd
import pickle
import os

from .estimation import Estimation
from .post_analysis import PostAnalysis
from .simulation import Simulation
from . import config

class Core(Estimation, Simulation, PostAnalysis):
    """
    The class "Core" defines the attributes, structure and type of the
    discrete choice model (MNL or MXL) being built.

    The class inherits functionality from the classes "Estimation",
    "Simulation" and "PostAnalysis".
    """

    def __init__(self, **kwargs):
        """
        Object initialization of class "Core".

        Parameters
        ----------
        kwargs model_type : str
            If 'simulation', previously estimated parameters and cost
            assumptions are taken from the config module. Any other value
            (default: 'estimation') loads and prepares the input data.

        kwargs param : dict
            Holds the names of all attributes in utility function.
            list(param) = ['constant', 'variable'] --> Distinction between variables
            that are constant over alternatives or vary, respectively.
            list(param['constant']) = ['fixed', 'random'] --> Distinction
            between variables that are not randomly distributed within
            a group of decision-makers ('fixed') and those that are randomly
            distributed with a discrete distribution ('random').
            Required for estimation; optional for simulation.

        kwargs max_space : int
            Maximum number of data points within parameter space.

        kwargs alt : int
            Number of discrete choice alternatives. Required for estimation.

        kwargs equal_alt : int
            Maximum number of equal alternatives (with different attributes)
            in a single observation/choice set.
            E.g. mode choice: Available are two buses and one car. This would
            lead to a maximum of two equal alternatives (two buses).
            Required for estimation.

        kwargs norm_alt : int
            Defines, which alternative shall be normalized. Defaults to 0.

        kwargs include_weights : boolean
            If True, the model searches for a column in the input data, which
            is called "weight". This column indicates the weight of each
            observation in the input data. Defaults to True.

        kwargs data_name : str
            Filename (without extension) of the base data in the folder
            "InputData" next to this module. "<data_name>.csv" (","-separated)
            is read if it exists, otherwise "<data_name>.pickle" is loaded.
            Required for estimation.

        kwargs data_index : array-like
            Row positions of the base data, which shall be used for
            estimation. Omitted or empty: all rows are used.

        kwargs initial_point_name : str
            Filename (without extension) of the pickle file in the folder
            "ModelParam", which holds the estimated MNL parameters.

        kwargs dc_type : str
            Determines the model type for simulation: MNL or MXL model.
            Depends on which estimated model parameters are available.

        kwargs dict_specific_travel_cost : dictionary
            External specification of transport costs. Relevant for the
            simulation of mode choice. Falls back to
            config.dict_specific_travel_cost if omitted or empty.

        kwargs cc_cost : optional
            External value stored as self.cc_cost in simulation mode.
            Falls back to config.cc_cost_2020 if omitted (None or False).

        Raises
        ------
        ValueError
            In estimation mode, if param, alt, equal_alt or data_name is
            missing or empty, or if the imported .csv file has only one
            column (most likely a wrong separator).
        FileNotFoundError
            If neither the .csv nor the .pickle input file exists.

        Returns
        -------
        None.

        """

        self.model_type = kwargs.get('model_type', 'estimation')

        if self.model_type == 'simulation':

            self.asc_offset_hh_cars = config.asc_offset_hh_cars

            # Load previously estimated model parameters,
            # if model-type is simulation.
            self.initial_point_cars = config.initial_point_cars
            self.initial_point_mode = config.initial_point_mode

            self.log_param = config.log_param

            # Externally specified travel costs override the config defaults.
            dict_specific_travel_cost_ext = kwargs.get('dict_specific_travel_cost', {})
            if dict_specific_travel_cost_ext is not None and len(dict_specific_travel_cost_ext) > 0:
                self.dict_specific_travel_cost = dict_specific_travel_cost_ext
            else:
                self.dict_specific_travel_cost = config.dict_specific_travel_cost

            # Fall back to the config default only if no value was passed.
            # A plain truthiness test would silently replace an explicit
            # cc_cost=0 by the default (and raise for array-valued input).
            cc_cost_ext = kwargs.get('cc_cost')
            if cc_cost_ext is None or cc_cost_ext is False:
                self.cc_cost = config.cc_cost_2020
            else:
                self.cc_cost = cc_cost_ext

            self.param = kwargs.get('param', {})
            if self.param:
                self.no_constant_fixed = len(self.param['constant']['fixed'])
                self.no_constant_random = len(self.param['constant']['random'])
                self.no_variable_fixed = len(self.param['variable']['fixed'])
                self.no_variable_random = len(self.param['variable']['random'])

        else:
            # Define paths relative to this module's directory. Each path
            # ends with a separator so that filenames can be appended directly.
            PATH_MODULE = os.path.dirname(__file__)
            sep = os.path.sep
            self.data_name = kwargs.get("data_name", False)
            self.initial_point_name = kwargs.get("initial_point_name", False)
            self.PATH_InputData = PATH_MODULE + sep + 'InputData' + sep
            self.PATH_ModelParam = PATH_MODULE + sep + 'ModelParam' + sep
            self.PATH_Visualize = PATH_MODULE + sep + 'Visualizations' + sep

            # Random or fixed within parameter space of Mixed Logit.
            self.param = kwargs.get('param', False)
            self.count_c = kwargs.get('alt', False)
            self.count_e = kwargs.get('equal_alt', False)

            # Check, if necessary values have been specified. Truthiness also
            # rejects None, "" and {} which would otherwise fail later with a
            # less helpful TypeError/KeyError.
            if not self.param:
                raise ValueError('Argument -param- needs to be specified!')
            if not self.count_c:
                raise ValueError('Argument -alt- needs to be specified!')
            if not self.count_e:
                raise ValueError('Argument -equal_alt- needs to be specified!')
            if not self.data_name:
                raise ValueError('Argument -data_name- needs to be specified!')

            self.no_constant_fixed = len(self.param['constant']['fixed'])
            self.no_constant_random = len(self.param['constant']['random'])
            self.no_variable_fixed = len(self.param['variable']['fixed'])
            self.no_variable_random = len(self.param['variable']['random'])

            # Accept any array-like (list, Index, ndarray); None means "all rows".
            data_index_ext = kwargs.get("data_index")
            if data_index_ext is None:
                self.data_index = np.array([])
            else:
                self.data_index = np.asarray(data_index_ext)

            self.include_weights = kwargs.get("include_weights", True)

            if self.initial_point_name:
                # NOTE: pickle.load can execute arbitrary code; only load
                # trusted parameter files.
                with open(self.PATH_ModelParam + self.initial_point_name + ".pickle", 'rb') as handle:
                    self.initial_point = pickle.load(handle)

            print('Data wrangling.')

            # Prefer the .csv file and fall back to the .pickle file only if
            # the .csv does not exist. The separator check lives in "else" so
            # its ValueError (and any real parse error) reaches the caller
            # instead of being masked by a failed pickle fallback.
            try:
                self.data = pd.read_csv(self.PATH_InputData + self.data_name + ".csv", sep = ",")
            except FileNotFoundError:
                with open(self.PATH_InputData + self.data_name + ".pickle", 'rb') as handle:
                    self.data = pickle.load(handle)
            else:
                if len(self.data.columns) == 1:
                    raise ValueError('Check the separator of the imported .csv-files. Should be ","-separated.')

            self.data = self.data.reset_index(drop=True)

            # Get data-points from indicated (positional) indices.
            if self.data_index.size:
                self.data = self.data.iloc[self.data_index]

            print('Length of dataset:', len(self.data))

            # Observation weights: the "weight" column if present and
            # requested, otherwise a weight of 1 for every observation.
            if self.include_weights and "weight" in self.data.columns:
                self.weight_vector = self.data["weight"].values.copy()
            else:
                self.weight_vector = np.ones(shape=len(self.data), dtype=np.float64)
            # Choices and availabilities, indexed [alternative c][equal
            # alternative e][observation], filled from the columns
            # "choice_<c>_<e>" and "av_<c>_<e>".
            self.choice = np.zeros((self.count_c,self.count_e,len(self.data)), dtype=np.int64)
            self.av = np.zeros((self.count_c,self.count_e,len(self.data)), dtype=np.float64)
            for c in range(self.count_c):
                for e in range(self.count_e):
                    self.choice[c][e] = self.data["choice_" + str(c) + "_" + str(e)].values
                    self.av[c][e] = self.data["av_" + str(c) + "_" + str(e)].values