# Kornpob Bhirombhakdi
# kbhirombhakdi@stsci.edu

from hstgrism.grismapcorr import GrismApCorr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

class Contamination:
    """
    Contamination is a class handling computation for grism contamination.
    - Object = object of interest
    - Contaminate = another object in the FoV of Object and contaminating it
    - trace_object_csv = trace.csv of Object (generated by HSTGRISM flow); anything accepted by pandas.read_csv (path or buffer)
    - trace_contaminate_csv = trace.csv of Contaminate
    - halfdy_object = integer for extraction aperture halfdy width of Object (excluding trace; that is, total width 2*halfdy_object + 1).
     - if halfdy_object = None, contamination region would not be computed.
    - instrument = string for available instrument list in hstgrism.grismapcorr.GrismApCorr.
     - if not None, self.compute() will grab aperture correction values (apcorr) from GrismApCorr table given instrument.
       This requires halfdy_object to be specified; otherwise self.compute() raises ValueError.
     - self.apcorr to access apcorr as a dict with key as halfdy in pixel unit and value as 1D array parallel to self.combine_df.ww_contaminate
       - Note: where ww_contaminate is None, it is replaced by the minimum finite ww_contaminate for the apcorr computation only.
         This causes no conflict because apcorr_contaminate_in/out are set back to None for those rows.
     - self.combine_df after running self.compute() will also have columns apcorr_contaminate_in and apcorr_contaminate_out.
    Use self.compute() to start computation after properly specifying information.
    - self.combine_df has one row per Object trace point (xg_object, yg_object, ww_object), paired with the Contaminate
      trace point in the same pixel column (yg_contaminate, ww_contaminate), or None where the Contaminate trace does not reach.
    - if halfdy_object is not None, contamination region would be computed (halfdy_contaminate_in, halfdy_contaminate_out).
    Use self.save(container) to save outputs to ./savefolder/saveprefix_savesuffix.extension
    - savefolder, saveprefix is controlled by Container class (see hstgrism.container.Container)
    - savesuffix.extension is preset:
     - contaminateTrace.csv = merge trace of Contaminate to Object trace frame, accessible by self.combine_df
    Use self.show(save=..., container=...) to show traces, extraction region of Object, and contamination region
    - if save=True, the plot is saved at ./savefolder/saveprefix_contamination.plotformat where savefolder, saveprefix, and plotformat are defined by container (see hstgrism.container.Container)
    """
    def __init__(self,trace_object_csv,trace_contaminate_csv,halfdy_object=None,instrument=None):
        self.trace_object = pd.read_csv(trace_object_csv)
        self.trace_contaminate = pd.read_csv(trace_contaminate_csv)
        self.halfdy_object = halfdy_object
        # Kept so compute() can look up the aperture correction for this instrument.
        self.instrument = instrument

    @staticmethod
    def _to_grid(trace):
        """Return (xg, yg, ww): integer pixel coordinates and wavelengths of a trace table.

        Pixel = ceil(trace offset (xh, yh) + reference position). The reference x and y
        are read from the first two entries of the `xyref` column.
        """
        xref = trace.xyref.iloc[0]
        yref = trace.xyref.iloc[1]
        xg = np.ceil(trace.xh.values + xref).astype(int)
        yg = np.ceil(trace.yh.values + yref).astype(int)
        return xg,yg,trace.ww.values

    @staticmethod
    def _match_contaminate(xg_object,xg_contaminate,yg_contaminate,ww_contaminate):
        """For each Object pixel column, pick the Contaminate (yg, ww) in that same column.

        Returns two object arrays parallel to xg_object, holding None where the Contaminate
        trace does not cover the column. If several Contaminate points fall in one column,
        the first one is used, so the output length always equals len(xg_object).
        """
        first_index = {}
        for k,x in enumerate(xg_contaminate):
            first_index.setdefault(x,k)
        n = len(xg_object)
        yg_wrt_object = np.full(n,None,dtype=object)
        ww_wrt_object = np.full(n,None,dtype=object)
        for ii,x in enumerate(xg_object):
            k = first_index.get(x)
            if k is not None:
                yg_wrt_object[ii] = yg_contaminate[k]
                ww_wrt_object[ii] = ww_contaminate[k]
        return yg_wrt_object,ww_wrt_object

    def _halfdy_contaminate(self,yg_object,yg_contaminate):
        """Return (halfdy_in, halfdy_out) object arrays; None where yg_contaminate is None.

        With separation d = |yg_contaminate - yg_object|:
        halfdy_in = d - halfdy_object - 1 and halfdy_out = d + halfdy_object.
        """
        n = len(yg_object)
        halfdy_in = np.full(n,None,dtype=object)
        halfdy_out = np.full(n,None,dtype=object)
        for ii,(y,yc) in enumerate(zip(yg_object,yg_contaminate)):
            if yc is None:
                continue
            separation = np.abs(yc - y)
            halfdy_in[ii] = separation - self.halfdy_object - 1
            halfdy_out[ii] = separation + self.halfdy_object
        return halfdy_in,halfdy_out

    def _apcorr_table(self,halfdy_in,halfdy_out,ww_contaminate):
        """Return {halfdy: apcorr array parallel to ww_contaminate} for every distinct halfdy.

        The aperture is a diameter of (2*halfdy + 1) pixels, converted to arcsec with the
        instrument's plate scale from the GrismApCorr table.
        NOTE(review): halfdy_in can be <= -1 when the Contaminate trace lies inside the Object
        aperture, which gives a non-positive aperture size; GrismApCorr behaviour there is unverified.
        """
        # dict.fromkeys drops duplicates while keeping first-seen order.
        halfdy_list = [h for h in dict.fromkeys(np.concatenate((halfdy_in,halfdy_out))) if h is not None]
        if not halfdy_list:
            return {}
        scale = GrismApCorr().table[self.instrument]['scale']
        # None -> NaN, then fill with the smallest finite wavelength. Those rows are reset to
        # None in the apcorr_contaminate_* columns, so the filled value is never reported.
        tww = ww_contaminate.astype(float)
        tww[np.isnan(tww)] = tww[np.isfinite(tww)].min()
        apcorr = {}
        for halfdy in halfdy_list:
            apsizepix = np.full(tww.shape,halfdy*2 + 1,dtype=float)
            apsizearcsec = apsizepix * scale
            apcorrobj = GrismApCorr(instrument=self.instrument,apsize=apsizearcsec,wave=tww.copy(),aptype='diameter',apunit='arcsec',waveunit='A')
            apcorrobj.compute()
            apcorr[halfdy] = apcorrobj.data['apcorr']
        return apcorr

    def _pick_apcorr(self,halfdy_column):
        """Per row ii, return self.apcorr[halfdy][ii], or None where halfdy is None."""
        out = np.full(len(halfdy_column),None,dtype=object)
        for ii,halfdy in enumerate(halfdy_column):
            if halfdy is not None:
                out[ii] = self.apcorr[halfdy][ii]
        return out

    def _require_combine_df(self,action):
        """Raise RuntimeError if compute() has not produced self.combine_df yet."""
        if not hasattr(self,'combine_df'):
            raise RuntimeError('run self.compute() before self.{0}().'.format(action))

    def compute(self):
        """
        Build self.combine_df and, when configured, the contamination region and apcorr.

        1. Convert both traces to integer pixel coordinates.
        2. Pair every Object pixel column with the Contaminate trace point in that column.
        3. If halfdy_object is set, add halfdy_contaminate_in/out.
        4. If instrument is also set, build self.apcorr and add apcorr_contaminate_in/out.

        Raises ValueError if instrument is set while halfdy_object is None, because apcorr
        is computed over the contamination region.
        """
        if self.instrument is not None and self.halfdy_object is None:
            raise ValueError('halfdy_object must be specified to compute apcorr for instrument {0}.'.format(self.instrument))
        xg_object,yg_object,ww_object = self._to_grid(self.trace_object)
        xg_contaminate,yg_contaminate,ww_contaminate = self._to_grid(self.trace_contaminate)
        yg_wrt_object,ww_wrt_object = self._match_contaminate(xg_object,xg_contaminate,yg_contaminate,ww_contaminate)
        self.combine_df = pd.DataFrame({'xg_object':xg_object,
                                        'yg_object':yg_object,
                                        'ww_object':ww_object,
                                        'yg_contaminate':yg_wrt_object,
                                        'ww_contaminate':ww_wrt_object,
                                       })
        if self.halfdy_object is None:
            return
        ##### make halfdy_contaminate_in and out #####
        halfdy_in,halfdy_out = self._halfdy_contaminate(yg_object,yg_wrt_object)
        self.combine_df['halfdy_contaminate_in'] = halfdy_in
        self.combine_df['halfdy_contaminate_out'] = halfdy_out
        if self.instrument is None:
            return
        ##### prepare apcorr and place apcorr_contaminate_in/out into combine_df #####
        self.apcorr = self._apcorr_table(halfdy_in,halfdy_out,ww_wrt_object)
        self.combine_df['apcorr_contaminate_in'] = self._pick_apcorr(halfdy_in)
        self.combine_df['apcorr_contaminate_out'] = self._pick_apcorr(halfdy_out)

    def save(self,container=None):
        """Write self.combine_df to ./savefolder/saveprefix_contaminateTrace.csv.

        Raises ValueError if container is None, and RuntimeError if compute() was not run.
        """
        if container is None:
            raise ValueError('container must be specified to save.')
        self._require_combine_df('save')
        string = './{0}/{1}_contaminateTrace.csv'.format(container.data['savefolder'],container.data['saveprefix'])
        self.combine_df.to_csv(string)
        print('Save {0}'.format(string))

    def show(self,figsize=(10,10),
             object_color='black',object_ls=':',object_marker='x',
             contaminate_color='red',contaminate_ls=':',contaminate_marker='x',
             fontsize=12,
             save=False,container=None,
            ):
        """
        Plot the Object trace, the Contaminate trace, and the available apertures.

        The Object aperture (trace +/- halfdy_object) is drawn when halfdy_object is set. The
        contamination region is drawn when compute() produced halfdy_contaminate_in/out. Rows
        without Contaminate coverage (None) become NaN, so matplotlib leaves gaps there.
        If save=True, the figure is written to ./savefolder/saveprefix_contamination.plotformat
        using container.data. Raises ValueError if save=True and container is None.
        """
        if save and container is None:
            raise ValueError('container must be specified to save')
        self._require_combine_df('show')
        df = self.combine_df
        x = df.xg_object.values
        y = df.yg_object.values.astype(float)
        yc = df.yg_contaminate.values.astype(float)
        plt.figure(figsize=figsize)
        plt.plot(x,y,ls=object_ls,color=object_color,label='Object trace')
        if self.halfdy_object is not None:
            plt.plot(x,y + self.halfdy_object,color=object_color,marker=object_marker)
            plt.plot(x,y - self.halfdy_object,color=object_color,marker=object_marker)
        plt.plot(x,yc,color=contaminate_color,ls=contaminate_ls,label='Contaminate trace')
        if 'halfdy_contaminate_in' in df:
            # halfdy_contaminate_in/out are distances from the Contaminate trace. Step toward the
            # Object trace so the region is correct whether the Contaminate lies above or below.
            side = np.where(yc >= y,1.,-1.)
            halfdy_in = df.halfdy_contaminate_in.values.astype(float)
            halfdy_out = df.halfdy_contaminate_out.values.astype(float)
            plt.plot(x,yc - side*halfdy_in,color=contaminate_color,marker=contaminate_marker)
            plt.plot(x,yc - side*halfdy_out,color=contaminate_color,marker=contaminate_marker)
        plt.xlabel('pixX',fontsize=fontsize)
        plt.ylabel('pixY',fontsize=fontsize)
        plt.title('Contamination',fontsize=fontsize)
        plt.legend(loc=(1.01,0.))
        plt.tight_layout()
        if save:
            string = './{0}/{1}_contamination.{2}'.format(container.data['savefolder'],container.data['saveprefix'],container.data['plotformat'])
            # matplotlib's keyword is `format`; `plotformat` is not a savefig argument.
            plt.savefig(string,format=container.data['plotformat'],bbox_inches='tight')
            print('Save {0}'.format(string))
            