from .fun import ispump
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score

import IPython,os,time,copy
import pandas as pd
import numpy as np
import seaborn as sns
import warnings; warnings.filterwarnings("ignore")
import matplotlib.cm as cm
import scipy.stats as sps

class model(object):
    '''Batch-wise regression tools (marginal R.Squared, OLS, logistic regression) for a pump.

    Every estimator repeatedly calls self.pm.go(), treats the returned value as a pandas
    DataFrame, folds that batch into its running estimates and, when disp is True, redraws
    a progress display after each replication.
    '''

    @staticmethod
    def demo():
        '''demo is used to demonstrate typical examples about this class.'''
        demostr = '''
import clubear as cb
#clubear.csv is generated by cb.manager.demo()
pm=cb.pump('clubear.csv')
pm.qlist=['age','height','logsales','price']
md=cb.model(pm).ols('logsales',tv=False)
md=cb.model(pm).mrs('logsales')
tk=cb.tank(pm)
tk.app(lambda x: 1*(x>2),'logsales')
md=cb.model(tk).logit('logsales')
'''
        print(demostr)

    def __init__(self, pm):
        '''Store the data source pm; print a warning (the object is still built) if it is not a pump.'''
        self.pm = pm

        # check whether the input is a pump first
        if not ispump(self.pm): print('model: The input seems not a valid pump.'); return

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _is_numeric(df,col):
        '''Return True when column col of df has a numeric (incl. bool) dtype.'''
        return pd.api.types.is_numeric_dtype(df.dtypes[col])

    def _check_niter_y(self,df,y,niter,where,min_niter=1):
        '''Validate niter and the response name y; return the stripped y, or None after printing why not.'''
        if not isinstance(niter,int): print(where+': The niter must be an int.'); return None
        if niter<min_niter: print(where+': The niter should be no less than '+str(min_niter)+'.'); return None
        if not isinstance(y,str): print(where+': The input y should be str.'); return None
        y=y.strip()
        if len(y)==0: print(where+': The input y cannot be empty.'); return None
        if y not in df.columns: print(where+': The input y was not found in pump df.'); return None
        if not self._is_numeric(df,y): print(where+': The input y is found to be non-numeric.'); return None
        return y

    def _numeric_xlist(self,df,y):
        '''Return the sorted names of all numeric columns of df other than y (names are not altered).'''
        return sorted(each for each in df.columns if self._is_numeric(df,each) and each!=y)

    @staticmethod
    def _drop_nan_rows(df,cols):
        '''Keep only the rows of df whose values in cols are all non-NaN.'''
        values=df[cols].to_numpy(dtype=float)
        return df[~np.isnan(values).any(axis=1)]

    @staticmethod
    def _newton_logit(X,Y,beta,tol=1.0e-6,max_iter=100):
        '''Newton-Raphson for the logistic log-likelihood of (X, Y), starting from beta.

        Returns the updated beta and the inverse of the per-observation Hessian used in
        the last step (needed by the caller for standard errors).
        '''
        ss,ncov=X.shape
        dist=1;miter=0
        while miter<=max_iter and dist>tol:
            pp=1./(1+np.exp(-X.dot(beta)))
            grad=X.T.dot(Y-pp)/ss
            # X' diag(p(1-p)) X / ss, plus a tiny ridge so the matrix stays invertible
            hess=(X*(pp*(1-pp))[:,None]).T.dot(X)/ss+np.eye(ncov)*1.0e-12
            INV=np.linalg.inv(hess)
            step=np.matmul(INV,grad)
            beta=beta+step
            miter=miter+1
            dist=np.max(np.abs(step))
        return beta,INV

    @staticmethod
    def _estimate_table(xlist,beta,SE,tstat,pvalue):
        '''Collect estimates into a DataFrame indexed by x; StandErr and pValue are reported in %.'''
        out=pd.DataFrame(list(zip(beta,SE*100,tstat,pvalue*100)))
        out.index=xlist
        out.columns=['Estimate','StandErr','tStat','pValue']
        return out

    @staticmethod
    def _set_display_options():
        '''Set the pandas display options used for the text tables (this changes global pandas state).'''
        pd.set_option('display.float_format', lambda x: '%.3f' % x)
        pd.set_option('display.max_columns',100)
        pd.set_option('display.width',100)

    @staticmethod
    def _print_table(out,estimate_note):
        '''Print the estimate table followed by the legend explaining its columns.'''
        print('');print(np.round(out, 3));print('')
        print('* Estimate: '+estimate_note)
        print('* StandErr: the standard error in %; t.Stat: the t-statistics = Estimate/Stand.Err.')
        print('* pValue: the p-value computed by normal approximation in %.')

    @staticmethod
    def _plot_tstat(out,y,titlestr,xlabel):
        '''Draw a horizontal bar chart of the t-statistics in out, colored by p-value.'''
        fig,ax = plt.subplots(); fig.set_figwidth(15); fig.set_figheight(7.5)
        ax.set_title(titlestr)
        ax.set_ylabel(y)
        ax.set_xlabel(xlabel)
        # pValue is stored in %, so 1-pValue lies in [0, 1] only when p < 1%; larger p-values
        # fall below 0, which matplotlib maps to the colormap's 'under' color.
        colors = cm.rainbow(1-out['pValue'])
        sns.barplot(x=list(out.tStat),y=list(out.index),palette=colors,ax=ax)
        plt.show()

    # --------------------------------------------------------------- estimators

    def mrs(self,y='',keep=None,niter=10,disp=True):
        '''mrs is used to produce marginal r.squared for all x w.r.t y.

        For each numeric column x the score is the squared (Pearson) correlation with y.
        For each non-numeric column it is 1 minus the pooled within-group variance of the
        standardized y. Scores are averaged over niter batches drawn by self.pm.go().

        Parameters
            y: name of the numeric response column (surrounding whitespace is ignored).
            keep: optional list of column names; when non-empty only these x are scored.
            niter: number of batches to average over (int >= 1).
            disp: if True, redraw a ranked bar chart after every batch.
        Returns
            pd.Series of scores in %, sorted in decreasing order and indexed by column name
            (y itself is reported as 100), or None after printing a message on invalid input.
        '''
        if not ispump(self.pm): print('model.mrs: The input seems not a valid pump.'); return
        df=self.pm.go();heads=list(df.columns)
        y=self._check_niter_y(df,y,niter,'model.mrs')
        if y is None: return
        qlist=[each for each in heads if self._is_numeric(df,each) and each!=y]
        clist=[each for each in heads if not self._is_numeric(df,each)]
        if not isinstance(disp,bool): print('model.mrs: The input disp should be a bool.'); return
        if keep is None: keep=[]
        if not isinstance(keep,list): print('model.mrs: The keep list should be a list'); return
        if len(keep)>0:
            qlist=[each for each in qlist if each in keep]
            clist=[each for each in clist if each in keep]
            heads=qlist+clist

        Mrq=pd.Series(np.zeros(len(heads)),heads)
        start_time=time.time()
        for n in range(niter):
            df=self.pm.go()
            # drop rows with missing y, then standardize y (ddof=0) so its total variance is 1
            df=df[df[y].notna()].copy()
            df[y]=(df[y]-np.mean(df[y]))/np.std(df[y])

            # numeric x: squared correlation on the rows where x is observed
            for each in qlist:
                pair=df[[y,each]].dropna()
                Mrq.loc[each]+=pair.corr()[y][each]**2

            # non-numeric x: 1 - (size-1)-weighted within-group variance of the standardized y
            for each in clist:
                grouped=df[[y,each]].groupby(each,observed=True)[y]
                dof=grouped.size()-1
                Mrq.loc[each]+=1-np.nansum(dof*grouped.var())/np.nansum(dof)

            R2=Mrq/(n+1)*100
            R2=R2[R2.notna()]
            R2.loc[y]=100
            order=np.argsort(-R2.values)
            yval=list(R2.values[order])
            xval=list(R2.index[order])

            # for dynamic outputs
            if not disp: continue

            maxnum=50 #the maximum number of heads for display
            elapsed_time=time.time()-start_time
            IPython.display.clear_output(wait=True)
            fig,ax = plt.subplots(); fig.set_figwidth(15)
            fig.set_figheight(0.25*min(len(xval),maxnum))

            progress='%.1f'%((n+1)/niter*100)+'%'
            progress=' '*(6-len(progress))+progress
            titlestr='Task accomplished = ['+progress+']; '
            titlestr=titlestr+'    Time elapsed = ['+('%.1f'%elapsed_time)+'] seconds.'
            ax.set_title(titlestr)
            ax.set_ylabel('Ranked variables by MRS')
            ax.set_xlabel('Marginal R.Squared in % for '+y)
            sns.barplot(x=yval[:maxnum],y=xval[:maxnum],palette="Set2",ax=ax)
            plt.show()

        return pd.Series(yval,xval)

    def ols(self,y='',niter=10,tv=False,disp=True):
        '''ols is used to produce ols estimate for a linear model.

        y is regressed on every other numeric column (no intercept is added). Cross-products
        are accumulated over niter batches drawn by self.pm.go(). Before each update the current
        estimate is scored on the new batch, so the reported R.Squared is out-of-sample.

        Parameters
            y: name of the numeric response column (surrounding whitespace is ignored).
            niter: number of batches (int >= 2; the first batch has no out-of-sample score).
            tv: if True print a text table, otherwise draw a bar chart of the t-statistics.
            disp: if False, display nothing and only return the final table.
        Returns
            pd.DataFrame indexed by x with columns Estimate, StandErr (in %), tStat and
            pValue (in %, normal approximation), or None after printing a message on invalid input.
        '''
        df=self.pm.go()
        # statistics only exist from the second batch on, so niter=1 could never produce a table
        y=self._check_niter_y(df,y,niter,'model.ols',min_niter=2)
        if y is None: return
        if not isinstance(tv,bool): print('model.ols: The input tv should be a bool.'); return
        if not isinstance(disp,bool): print('model.ols: The input disp should be a bool.'); return

        xlist=self._numeric_xlist(df,y)
        if len(xlist)==0: print('model.ols: At least one x should be included.'); return

        ncov=len(xlist) #without intercept
        XXS=np.eye(ncov)*1.0e-12;XYS=np.zeros(ncov)
        NS=0;RSS=0.0;RSS0=0.0;beta=np.zeros(ncov)
        start_time=time.time()
        for n in range(niter):
            df=self._drop_nan_rows(self.pm.go(),xlist+[y])
            if df.shape[0]==0: print('model.ols: Too many NaN found in data.'); return
            X=df[xlist].to_numpy(dtype=float);Y=df[y].to_numpy(dtype=float)

            # score the current beta on this batch before updating with it (out-of-sample)
            if n>0:
                RSS0+=np.sum((Y-np.mean(Y))**2)
                RSS+=np.sum((Y-np.matmul(X,beta))**2)
                stde=np.sqrt(RSS/NS)
                R2=100*(RSS0-RSS)/RSS0

            # update the ols estimate; the small ridge term keeps XXS invertible
            XXS+=np.matmul(X.T,X)+np.eye(ncov)*1.0e-6
            XYS+=np.matmul(X.T,Y)
            NS+=len(Y)
            INV=np.linalg.inv(XXS/NS)
            beta=np.matmul(INV,XYS/NS)

            if n==0: continue
            SE=np.sqrt(np.diag(INV))*stde/np.sqrt(NS)
            tstat=beta/SE
            pvalue=2*(1-sps.norm.cdf(np.abs(tstat)))
            out=self._estimate_table(xlist,beta,SE,tstat,pvalue)

            # for dynamic outputs
            if not disp: continue
            elapse_time=time.time()-start_time
            IPython.display.clear_output(wait=True)
            self._set_display_options()
            progress=np.round((n+1)/niter*100, 2)

            if tv:
                print('Time elapsed:',('%.1f'%elapse_time),'seconds',end=' ')
                print('with averaged subsample sizes', ('%.1f'%(NS/(n+1))), 'and R.Squared = ',('%.1f'%R2),'%')
                print('Task accomplished: ', ('%.1f'%progress),
                      '% for a total of ', niter, 'random replications.')
                self._print_table(out,'the ordinary least squares estimate for standardized x variables.')
            else:
                titlestr='Task accomplished = ['+str(n+1)+'/'+str(niter)+']; '
                titlestr=titlestr+'    Time elapsed = ['+('%.1f'%elapse_time)+'] seconds;'
                titlestr=titlestr+'    R.Squared = ['+('%.1f'%R2)+'%].'
                self._plot_tstat(out,y,titlestr,'tStat = ols estimates / standard error')

        return out

    def logit(self,y='',niter=10,tv=False,disp=True):
        '''logit is used to produce MLE for a logistic regression model.

        For each of niter batches drawn by self.pm.go(), the logistic MLE of y on every other
        numeric column (no intercept is added) is found by Newton-Raphson. The per-batch
        estimates are averaged. From the second batch on, the running average is scored on the
        new batch before updating, so the reported AUC is out-of-sample.

        Parameters
            y: name of the 0/1 response column (surrounding whitespace is ignored).
            niter: number of batches (int >= 1).
            tv: if True print a text table, otherwise draw a bar chart of the t-statistics.
            disp: if False, display nothing and only return the final table.
        Returns
            pd.DataFrame indexed by x with columns Estimate, StandErr (in %), tStat and
            pValue (in %, normal approximation), or None after printing a message on invalid input.
        '''
        df=self.pm.go()
        y=self._check_niter_y(df,y,niter,'model.logit')
        if y is None: return
        if not isinstance(tv,bool): print('model.logit: The input tv should be a bool.'); return
        if not isinstance(disp,bool): print('model.logit: The input disp should be a bool.'); return

        xlist=self._numeric_xlist(df,y)
        if len(xlist)==0: print('model.logit: At least one x should be included.'); return

        ncov=len(xlist);NS=0;auc_total=0.0
        beta0=np.zeros(ncov);BETA=np.zeros(ncov);SE=np.zeros(ncov);MyBeta=np.zeros(ncov)
        start_time=time.time()
        for n in range(niter):
            df=self._drop_nan_rows(self.pm.go(),xlist+[y])
            if df.shape[0]==0: print('model.logit: Too many NaN found in data.'); return

            X=df[xlist].to_numpy(dtype=float);Y=df[y].to_numpy(dtype=float);ss=len(Y);NS+=ss
            if not np.isin(Y,[0,1]).all(): print('model.logit: The response y should be a binary variable.'); return

            # per-batch MLE, warm-started from the running average of the earlier batches
            beta0,INV=self._newton_logit(X,Y,beta0)

            # score the averaged estimate on this batch before it is updated (out-of-sample)
            if n>0:
                auc_total+=roc_auc_score(Y,X.dot(MyBeta))
                AUC=100*auc_total/n

            # average the per-batch estimates; treating batches as independent, the standard
            # error of the average is sqrt(sum of per-batch variances)/(n+1)
            BETA=BETA+beta0;SE=SE+np.diag(INV)/ss
            MyBeta=BETA/(n+1);beta0=MyBeta
            MySE=np.sqrt(SE+1.0e-12)/(n+1);tstat=MyBeta/MySE
            pvalue=2*(1-sps.norm.cdf(np.abs(tstat)))
            out=self._estimate_table(xlist,MyBeta,MySE,tstat,pvalue)

            if n==0: continue
            if not disp: continue
            elapse_time=time.time()-start_time
            IPython.display.clear_output(wait=True)
            self._set_display_options()
            progress=np.round((n+1)/niter*100, 2)

            if tv:
                print('Time elapsed:',('%.1f'%elapse_time),'seconds',end=' ')
                print('with averaged subsample sizes', ('%.1f'%(NS/(n+1))))
                print('Task accomplished: ', ('%.1f'%progress),'% for a total of ', niter, 'random replications.')
                print('The out-of-sample AUC = ',('%.1f'%AUC),'%')
                self._print_table(out,'the maximum likelihood estimate for standardized x variables.')
            else:
                titlestr='Task accomplished = ['+str(n+1)+'/'+str(niter)+']; '
                titlestr=titlestr+'Time elapsed = ['+('%.1f'%elapse_time)+'] seconds;'
                titlestr=titlestr+'  out-of-sample AUC = ['+('%.1f'%AUC)+'%].'
                self._plot_tstat(out,y,titlestr,'tStat = averaged mle estimates / standard error')

        return out
        
