#!/usr/bin/env python

"""

Calculate masses of clusters detected by nemo

Requires 'massOptions' in nemo config file

Run this after filtering / detecting objects in a map with nemo itself

Can be used to obtain 'forced photometry', i.e., mass estimates for objects in redshiftCatalog
(for, e.g., optical stacking)

"""

import os
import sys
import numpy as np
import pylab as plt
import astropy.table as atpy
from astLib import *
from scipy import stats
from scipy import interpolate
import nemo
from nemo import catalogs
from nemo import signals
from nemo import maps
from nemo import filters
from nemo import MockSurvey
from nemo import photometry
from nemo import startUp
from nemo import completeness
from nemo import pipelines
import argparse
import astropy.io.fits as pyfits
import time
import yaml
#import IPython

#------------------------------------------------------------------------------------------------------------
def addForcedPhotometry(pathToCatalog, config, zColumnName = None, zErrColumnName = None):
    """Given the path to a catalog that contains minimal columns name, RADeg, decDeg, redshift, 
    run forced photometry at the catalog positions and return a table with the redshift information
    attached, so that SZ masses can be estimated.
    
    The redshift columns of the input catalog are normalised to 'redshift' and 'redshiftErr' (either
    from the user-specified column names, or by looking for a few commonly used alternatives). If no
    redshift uncertainty column is found, redshiftErr = 0 is assumed for all objects.
    
    Args:
        pathToCatalog (str): Path to the input catalog (anything that astropy.table.Table.read
            understands).
        config: Nemo configuration object. NOTE: its parDict is modified in place - 
            'forcedPhotometryCatalog' is set to pathToCatalog, 'thresholdSigma' is set to -100, and 
            'mapFilters' is trimmed to contain only the filter whose label matches 'photFilter'.
        zColumnName (str, optional): Name of the redshift column in the input catalog. If given, 
            this takes priority over any existing 'redshift' column.
        zErrColumnName (str, optional): Name of the redshift uncertainty column in the input 
            catalog. If given, this takes priority over any existing 'redshiftErr' column.
    
    Returns:
        The forced photometry catalog (as cross matched against the input catalog by 
        catalogs.crossMatch), with 'redshift' and 'redshiftErr' columns added.
    
    Raises:
        KeyError: If zColumnName or zErrColumnName is given but is not a column in the catalog.
        Exception: If no redshift column can be found in the catalog.
    
    """
    
    print(">>> Doing forced photometry ...")
    
    zTab=atpy.Table().read(pathToCatalog)
    
    # User-specified columns take priority over anything already called redshift / redshiftErr.
    # If the user names the standard column itself there is nothing to do - in particular, we must
    # not remove it first, as that would throw away the very column we were asked to use.
    for userColName, stdColName in [(zColumnName, 'redshift'), (zErrColumnName, 'redshiftErr')]:
        if userColName is None:
            continue
        if userColName not in zTab.keys():
            raise KeyError("Column %s not found in %s" % (userColName, pathToCatalog))
        if userColName != stdColName:
            if stdColName in zTab.keys():
                zTab.remove_column(stdColName)
            zTab.rename_column(userColName, stdColName)
    
    # Fall back on guessing the redshift column. We stop at the first match: renaming a second 
    # candidate to 'redshift' would fail, because that column name would already be taken.
    if 'redshift' not in zTab.keys():
        for p in ['z', 'Z', 'REDSHIFT', 'Redshift', 'z_cl', 'Photz']:
            if p in zTab.keys():
                print("... assuming %s is the redshift column ..." % (p))
                zTab.rename_column(p, 'redshift')
                break
        else:
            # Loop finished without finding any candidate column
            raise Exception("Couldn't find a redshift column in %s" % (pathToCatalog))
    
    # Same again for the redshift uncertainty, except that here a missing column is not fatal
    if 'redshiftErr' not in zTab.keys():
        for p in ['zErr', 'dz']:
            if p in zTab.keys():
                print("... assuming %s is the redshiftErr column ..." % (p))
                zTab.rename_column(p, 'redshiftErr')
                break
    if 'redshiftErr' not in zTab.keys():
        print("... assuming redshiftErr = 0 for all objects (no suitable redshiftErr column found) ...")
        zTab.add_column(atpy.Column(np.zeros(len(zTab)), 'redshiftErr'))
        
    config.parDict['forcedPhotometryCatalog']=pathToCatalog
    
    # We need to disable any cut in SNR, otherwise we're not doing truly forced photometry
    # This will allow -ve fixed_y_c, but we'll still only get the +ve masses though
    config.parDict['thresholdSigma']=-100
    
    # Trim all filters except the reference one, then make catalog
    config.parDict['mapFilters']=[mapFilter for mapFilter in config.parDict['mapFilters']
                                  if mapFilter['label'] == config.parDict['photFilter']]
    forcedTab=pipelines.filterMapsAndMakeCatalogs(config)

    # Need to graft the redshifts back on
    # (item assignment adds the column, or replaces it if the forced photometry catalog already 
    # happens to carry a column with the same name)
    zMatched, forcedMatched, rDeg=catalogs.crossMatch(zTab, forcedTab)
    tab=forcedMatched
    tab['redshift']=zMatched['redshift']
    tab['redshiftErr']=zMatched['redshiftErr']
    
    return tab

#------------------------------------------------------------------------------------------------------------
def calcMass(tab, massOptions, tckQFitDict, fRelWeightsDict, mockSurvey, rank = None):
    """Calculates masses for cluster data in table.
    
    For every row with fixed_y_c > 0 and a redshift that is not NaN, signals.calcMass is called 
    twice - with and without the mass function de-bias correction - and the results are written
    into the table. Rows that fail this cut are kept, but their mass columns are left at zero.
    
    Args:
        tab: Table of objects (modified in place). The columns name, tileName, redshift, 
            redshiftErr, fixed_y_c and fixed_err_y_c are read.
        massOptions (dict): Must contain 'tenToA0', 'B0', 'Mpivot' and 'sigma_int'. Optional keys:
            'relativisticCorrection' (default: True), and 'rescaleFactor' (which, if present, 
            must be accompanied by 'rescaleFactorErr').
        tckQFitDict (dict): Indexed by tileName; the entry for each object's tile is passed to 
            signals.calcMass as tckQFit.
        fRelWeightsDict (dict): Indexed by tileName; the entry for each object's tile is passed to
            signals.calcMass as fRelWeightsDict.
        mockSurvey: Object passed through to signals.calcMass. Its mdefLabel attribute sets the 
            names of the output mass columns.
        rank (int, optional): Process rank, used only in the progress messages. If not given, 
            the rank of the module-level config object (which exists when this file is run as a
            script) is used if available, otherwise 0.
    
    Returns:
        The input table, with columns <label>, <label>Uncorr, M200m, M200mUncorr (and <label>Cal if
        'rescaleFactor' is in massOptions) added, each with _errPlus and _errMinus counterparts,
        where <label> is mockSurvey.mdefLabel.
    
    """

    # Read (rather than write) the default, so that the caller's massOptions is left untouched
    applyRelCorr=massOptions.get('relativisticCorrection', True)
    doRescale='rescaleFactor' in massOptions.keys()
    
    if rank is None:
        try:
            rank=config.rank
        except NameError:
            # Not running as a script, so there is no module-level config to ask
            rank=0
        
    label=mockSurvey.mdefLabel
    labelKey='%s' % (label)
    labelUncorrKey='%sUncorr' % (label)
    labelCalKey='%sCal' % (label)
    
    colNames=[labelKey, labelUncorrKey]
    if doRescale:
        colNames.append(labelCalKey)
    colNames=colNames+["M200m", "M200mUncorr"] # We will generalize fully later so user can choose mass defs
    for c in colNames:
        for suffix in ['', '_errPlus', '_errMinus']:
            tab['%s%s' % (c, suffix)]=np.zeros(len(tab))
    
    numRows=len(tab)
    for count, row in enumerate(tab, start = 1):
        print("... rank %d; %d/%d; %s (%.3f +/- %.3f) ..." % (rank, count, numRows, row['name'], 
                                                              row['redshift'], row['redshiftErr']))

        tileName=row['tileName']
        
        # Cuts on z, fixed_y_c for forced photometry mode (invalid objects will be listed but without a mass)
        isValid=row['fixed_y_c'] > 0 and np.isnan(row['redshift']) == False
        if not isValid:
            continue
        
        # Everything except the de-bias switch is common to the two mass calculations below.
        # In particular, both get the Q fit and the relativistic correction weights for this 
        # object's own tile.
        sharedArgs={'tenToA0': massOptions['tenToA0'],
                    'B0': massOptions['B0'], 
                    'Mpivot': massOptions['Mpivot'], 
                    'sigma_int': massOptions['sigma_int'],
                    'tckQFit': tckQFitDict[tileName], 
                    'mockSurvey': mockSurvey, 
                    'applyRelativisticCorrection': applyRelCorr,
                    'fRelWeightsDict': fRelWeightsDict[tileName]}
        
        # Corrected for mass function steepness
        # (the catalog fixed_y_c values are multiplied by 1e-4 before being handed over)
        massDict=signals.calcMass(row['fixed_y_c']*1e-4, row['fixed_err_y_c']*1e-4, 
                                  row['redshift'], row['redshiftErr'],
                                  applyMFDebiasCorrection = True, **sharedArgs)
        row[labelKey]=massDict[labelKey]
        row['%s_errPlus' % (label)]=massDict['%s_errPlus' % (label)]
        row['%s_errMinus' % (label)]=massDict['%s_errMinus' % (label)]
        
        # Uncorrected for mass function steepness
        unCorrMassDict=signals.calcMass(row['fixed_y_c']*1e-4, row['fixed_err_y_c']*1e-4, 
                                        row['redshift'], row['redshiftErr'],
                                        applyMFDebiasCorrection = False, **sharedArgs)
        row[labelUncorrKey]=unCorrMassDict[labelKey]
        row['%sUncorr_errPlus' % (label)]=unCorrMassDict['%s_errPlus' % (label)]
        row['%sUncorr_errMinus' % (label)]=unCorrMassDict['%s_errMinus' % (label)]
        
        # Mass conversion -  remove later when fully generalized
        # NOTE: the M500c names are hard-coded here, i.e., this part only works when the mass
        # definition label is M500c. The fractional errors on M500c are carried over to M200m.
        row['M200m']=signals.convertM500cToM200m(massDict['M500c']*1e14, row['redshift'])/1e14
        row['M200m_errPlus']=(row['M500c_errPlus']/row['M500c'])*row['M200m']
        row['M200m_errMinus']=(row['M500c_errMinus']/row['M500c'])*row['M200m']
        row['M200mUncorr']=signals.convertM500cToM200m(unCorrMassDict['M500c']*1e14, row['redshift'])/1e14
        row['M200mUncorr_errPlus']=(row['M500cUncorr_errPlus']/row['M500cUncorr'])*row['M200mUncorr']
        row['M200mUncorr_errMinus']=(row['M500cUncorr_errMinus']/row['M500cUncorr'])*row['M200mUncorr']
        
        # Re-scaling (e.g., using richness-based weak-lensing mass calibration)
        # Fractional errors on the mass and on the rescale factor are added in quadrature
        if doRescale:
            fracCalErr=massOptions['rescaleFactorErr']/massOptions['rescaleFactor']
            row[labelCalKey]=massDict[labelKey]/massOptions['rescaleFactor']
            for suffix in ['_errPlus', '_errMinus']:
                fracMassErr=row['%s%s' % (label, suffix)]/row[labelKey]
                row['%sCal%s' % (label, suffix)]=np.sqrt(np.power(fracMassErr, 2) + \
                                                         np.power(fracCalErr, 2))*row[labelCalKey]
        #row['M200mCal']=signals.convertM500cToM200m(row['M500Cal']*1e14, row['redshift'])/1e14
        #row['M200mCal_errPlus']=(row['M500Cal_errPlus']/row['M500Cal'])*row['M200mCal']
        #row['M200mCal_errMinus']=(row['M500Cal_errMinus']/row['M500Cal'])*row['M200mCal']

    return tab

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Script entry point: parse options, load (or make, via forced photometry) the catalog with 
    # redshifts, estimate masses for every object (optionally split over MPI processes), write the
    # output catalog, and - for mock catalogs - print some mass recovery statistics.
    # NOTE: this is deliberately kept as a flat script, as calcMass picks up the MPI rank from the
    # module-level config object defined below.

    parser=argparse.ArgumentParser("nemoMass")
    parser.add_argument("configFileName", help="""A .yml configuration file.""")
    parser.add_argument("-c", "--catalog", dest="catFileName", help = """Catalog file name (.fits format).
                        The catalog must contain at least the following columns: name, RADeg, decDeg, 
                        redshift, redshiftErr. If the catalog contains fixed_y_c, fixed_err_y_c columns, 
                        then these will be used to infer mass estimates. If not, "forced photometry" mode 
                        will be enabled, and the fixed_y_c, fixed_err_y_c values will be extracted from the
                        filtered maps.""", default = None)
    parser.add_argument("-o", "--output", dest="outFileName", help = """Output catalog file name 
                        (.fits format). If not given, the name of the output catalog file will be based on
                        either configFileName or catFileName.""", default = None)
    parser.add_argument("-z", "--z-column", dest="zColumnName", help = """Specify the name of the redshift
                        column in the input catalog.""", default = None)
    parser.add_argument("-e", "--z-error-column", dest="zErrColumnName", help = """Specify the name of the 
                        redshift uncertainty column in the input catalog.""", default = None)
    parser.add_argument("-F", "--forced-photometry", dest="forcedPhotometry", help = """Do forced photometry.
                        This is automatically enabled if the catalog does not contain the fixed_y_c, 
                        fixed_err_y_c columns. Use this switch to force using this mode even if the 
                        catalog already contains fixed_y_c, fixed_err_y_c  columns (e.g., for doing forced
                        photometry on one ACT map using positions of clusters found in another, deeper 
                        map).""", default = False, action="store_true")
    parser.add_argument("-M", "--mpi", dest="MPIEnabled", action="store_true", help="""Enable MPI. If you 
                        want to use this, run with e.g., mpiexec -np 4 nemoMass configFile.yml -M""", 
                        default = False)
    parser.add_argument("-n", "--no-strict-errors", dest="noStrictMPIExceptions", action="store_true",
                        help="""Disable strict exception handling (applies under MPI only, i.e., must be
                        used with the -M switch). If you use this option, you will get the full traceback
                        when a Python Exception is triggered, but the code may not terminate. This is due
                        to the Exception handling in mpi4py.""",
                        default = False)
    args = parser.parse_args()
    
    parDictFileName=args.configFileName
    catFileName=args.catFileName
    outFileName=args.outFileName
    forcedPhotometry=args.forcedPhotometry
    strictMPIExceptions=not args.noStrictMPIExceptions
    
    # Forced photometry is done at the positions listed in the given catalog, so without one there
    # is nothing to measure - better to say so now than to fail later when trying to read it
    if forcedPhotometry == True and catFileName is None:
        parser.error("forced photometry (-F) requires a catalog to be given using -c")
    
    # Load the nemo catalog and match against the z catalog
    # NOTE: This is now done using coord matching (nearest within some maximum tolerance), rather than names
    if catFileName is None:
        config=startUp.NemoConfig(parDictFileName, MPIEnabled = args.MPIEnabled, divideTilesByProcesses = False,
                                  setUpMaps = False, verbose = False, strictMPIExceptions = strictMPIExceptions)
        optimalCatalogFileName=config.rootOutDir+os.path.sep+"%s_optimalCatalog.fits" % (os.path.split(config.rootOutDir)[-1])
        nemoTab=atpy.Table().read(optimalCatalogFileName)
        zTab=atpy.Table().read(config.parDict['massOptions']['redshiftCatalog'])
        nemoTab, zTab, rDeg=catalogs.crossMatch(nemoTab, zTab, radiusArcmin = 2.5)
        nemoTab['redshift']=zTab['redshift']
        nemoTab['redshiftErr']=zTab['redshiftErr']
        tab=nemoTab
        if outFileName is None:
            outFileName=optimalCatalogFileName.replace("_optimalCatalog.fits", "_mass.fits")

    else:
        
        # Load another catalog (e.g., a mock, for testing)
        optimalCatalogFileName=catFileName
        tab=atpy.Table().read(optimalCatalogFileName)
        if outFileName is None:
            outFileName=catFileName.replace(".fits", "_mass.fits")
            if outFileName == catFileName:
                # No '.fits' in the catalog file name, so nothing was replaced - make sure that
                # we do not end up overwriting the input catalog with the output
                outFileName=os.path.splitext(catFileName)[0]+"_mass.fits"
        
        # Enter forced photometry mode if we can't find the columns we need
        # If we're doing forced photometry, we're going to need to set-up so we can find the filtered maps
        keysNeeded=['fixed_y_c', 'fixed_err_y_c']
        for key in keysNeeded:
            if key not in tab.keys():
                forcedPhotometry=True
        config=startUp.NemoConfig(parDictFileName, MPIEnabled = args.MPIEnabled, divideTilesByProcesses = False,
                                  setUpMaps = forcedPhotometry, writeTileDir = False, verbose = False,
                                  strictMPIExceptions = strictMPIExceptions)
    
    # Remaining set-up
    massOptions=config.parDict['massOptions']
    tckQFitDict=signals.loadQ(config)    
    fRelWeightsDict=signals.getFRelWeights(config)
    
    # Forced photometry (if enabled) - this replaces the table read above
    # NOTE: Move this up if/when we make it run under MPI
    if forcedPhotometry == True:
        tab=addForcedPhotometry(catFileName, config, args.zColumnName, args.zErrColumnName)
    
    # Set cosmological parameters for e.g. E(z) calc if these are set in .par file
    # We set them after the Q calc, because the Q calc needs to be for the fiducial cosmology
    # (OmegaM0 = 0.3, OmegaL0 = 0.7, H0 = 70 km/s/Mpc) used in object detection/filtering stage
    # Set-up the mass function stuff also
    # This is for correction of mass bias due to steep cluster mass function
    # Hence minMass here needs to be set well below the survey mass limit
    # areaDeg2 we don't care about here
    minMass=1e13
    areaDeg2=700.
    zMin=0.0
    zMax=2.0
    # H0, Om0, Ol0 used for E(z), theta500 calcs in Q - these are updated when we call create mockSurvey
    H0=massOptions.get('H0', 70.)
    Om0=massOptions.get('Om0', 0.3)
    # OmegaB0, sigma8 only used for mass function Eddington bias correction
    Ob0=massOptions.get('Ob0', 0.05)
    sigma8=massOptions.get('sigma8', 0.8)
    ns=massOptions.get('ns', 0.95)
    mockSurvey=MockSurvey.MockSurvey(minMass, areaDeg2, zMin, zMax, H0, Om0, Ob0, sigma8, ns)
        
    # Seems like multiprocessing doesn't work under slurm...   
    # So use MPI instead: just divide up catalog among processes
    # sortIndex lets us put the rows back into their original order after gathering
    tab.add_column(atpy.Column(np.arange(len(tab)), "sortIndex"))
    if config.MPIEnabled == True:
        numRowsPerProcess=int(np.ceil(len(tab)/config.size))
        startIndex=config.rank*numRowsPerProcess
        endIndex=startIndex+numRowsPerProcess
        if config.rank == config.size-1:
            endIndex=len(tab)
        tab=tab[startIndex:endIndex]
        
    tab=calcMass(tab, massOptions, tckQFitDict, fRelWeightsDict, mockSurvey)
    
    if config.MPIEnabled == True:
        tabList=config.comm.gather(tab, root = 0)
        if config.rank != 0:
            # Only the root process gets the gathered tables and writes the output
            assert tabList is None
            print("... MPI rank %d finished ..." % (config.rank))
            sys.exit()
        else:
            print("... gathering catalogs ...")
            tab=atpy.vstack(tabList)
    
    tab.sort('sortIndex')
    tab.remove_column('sortIndex')
    
    # If the output file name has no directory part, it goes in the current directory - and there
    # is nothing to create (os.makedirs fails if given an empty path)
    outDir=os.path.split(outFileName)[0]
    if outDir != '':
        os.makedirs(outDir, exist_ok = True)
    tab.meta['NEMOVER']=nemo.__version__
    tab.write(outFileName, overwrite = True)
    
    # Detect if testing a mock catalog, and write some stats on recovered masses
    if 'true_M500' in tab.keys():
        # Noise sources in mocks
        applyPoissonScatter=config.parDict.get('applyPoissonScatter', True)
        applyIntrinsicScatter=config.parDict.get('applyIntrinsicScatter', True)
        applyNoiseScatter=config.parDict.get('applyNoiseScatter', True)
        print(">>> Mock noise sources (Poisson, intrinsic, measurement noise) = (%s, %s, %s) ..." % (applyPoissonScatter, applyIntrinsicScatter, applyNoiseScatter))
        if applyNoiseScatter == True and applyIntrinsicScatter == True:
            print("... for these options, median M500 / true_M500 = 1.000 if mass recovery is unbiased ...")
        elif applyNoiseScatter == False and applyIntrinsicScatter == False:
            print("... for these options, median M500Unc / true_M500 = 1.000 if mass recovery is unbiased ...")
        else:
            print("... for these options, both median M500 / true_M500 and median M500Unc / true_M500 will be biased ...")
        # Compare the masses that were just calculated: calcMass names its output columns after
        # the mass definition label of mockSurvey
        massLabel=mockSurvey.mdefLabel
        ratio=tab['%s' % (massLabel)]/tab['true_M500']
        ratioUnc=tab['%sUncorr' % (massLabel)]/tab['true_M500']
        print(">>> Mock catalog mass recovery stats:")
        SNRCuts=[4, 5, 7, 10]
        for s in SNRCuts:
            mask=np.greater(tab['fixed_SNR'], s)
            numSelected=np.sum(mask)
            print("--> SNR > %.1f (N = %d):" % (s, numSelected)) 
            for ratioLabel, r in [("M500", ratio[mask]), ("M500Unc", ratioUnc[mask])]:
                print("... mean %s / true_M500 = %.3f +/- %.3f (stdev) +/- %.3f (sterr)" % (ratioLabel, 
                                                                                          np.mean(r), 
                                                                                          np.std(r), 
                                                                                          np.std(r)/np.sqrt(numSelected)))
                print("... median %s / true_M500 = %.3f" % (ratioLabel, np.median(r)))


