#!/usr/bin/env python

"""

Example of using nemo selection function stuff for cosmology.

NOTE: relies on global variables:
- selFn
- tab (and related stuff on a grid)

Using Cobaya as the sampler.

"""

import os
import sys
os.environ["OMP_NUM_THREADS"] = "1"                 # Needed for multiprocessing / MPI to be reliable
#print("Running under python: %s" % (sys.version))
import numpy as np
import pylab as plt
import astropy.table as atpy
from astLib import *
from scipy import stats
from scipy import interpolate
from scipy.special import factorial
from nemo import completeness
from nemo import signals
import argparse
import time
from collections import OrderedDict as odict
from cobaya.run import run
import cobaya.yaml
from getdist.mcsamples import MCSamplesFromCobaya
import getdist.plots as gdplt

# Extreme debugging (better for multiprocessing stuff to crash where we can see it)
import warnings
#warnings.filterwarnings("error")

#-------------------------------------------------------------------------------------------------------------
def lnprob_yc(H0, Om0, sigma8, Ob0, ns, tenToA0, B0, Mpivot, sigma_int):
    """Log likelihood function for use with Cobaya, using counts binned on a (y0, z) grid.
    
    Reads the module-level globals selFn, obsGrid, obs_zBinEdges and obs_ycBinEdges, which must have
    been set before this is called.
    
    Args:
        H0, Om0, sigma8, Ob0, ns: Cosmological parameters, passed on to selFn.update.
        tenToA0, B0, Mpivot, sigma_int: Scaling relation parameters, passed on to selFn.update via
            its scalingRelationDict argument.
    
    Returns:
        The Poisson log likelihood summed over the grid cells that contain at least one observed 
        object. Returns -np.inf if selFn.update fails, or if no cell contains an observed object.
        
    """
        
    # This can fail if we wander outside the parameter space where the mass function is defined.
    # Only Exception subclasses are treated as such a failure: KeyboardInterrupt and SystemExit are
    # allowed to propagate, so that a run can still be stopped.
    # NOTE: selFn.update is called with (H0, Om0, Ob0, sigma8, ns), which is not the order used in 
    # the signature of this function.
    try:
        selFn.update(H0, Om0, Ob0, sigma8, ns, scalingRelationDict = {'tenToA0': tenToA0, 'B0': B0, 
                                                                      'Mpivot': Mpivot, 'sigma_int': sigma_int})
    except Exception:
        return -np.inf
    
    # Transform cluster counts to binned y0 grid (rows of predGrid are y0 bins, columns are z bins)
    corrGrid=selFn.mockSurvey.clusterCount*selFn.compMz
    predGrid=np.zeros(obsGrid.shape)
    for i in range(len(obs_zBinEdges)-1):
        zMin=obs_zBinEdges[i]
        zMax=obs_zBinEdges[i+1]
        zMask=np.logical_and(selFn.mockSurvey.z >= zMin, selFn.mockSurvey.z < zMax)
        # Repeat the 1d z mask along the second axis so it can be combined with the 2d y0 mask
        zMask=np.array([zMask]*selFn.mockSurvey.clusterCount.shape[1]).transpose() 
        for j in range(len(obs_ycBinEdges)-1):
            ycMin=obs_ycBinEdges[j]
            ycMax=obs_ycBinEdges[j+1]
            ycMask=np.logical_and(selFn.y0Grid >= ycMin, selFn.y0Grid < ycMax)
            predGrid[j, i]=corrGrid[zMask*ycMask].sum()

    # Poisson probability in (log10(y0~), z) grid
    # log(n!) is evaluated as gammaln(n+1): factorial(n) itself overflows to inf for n > 170, which
    # would force the log likelihood to -inf whenever a cell is well populated
    from scipy.special import gammaln
    mask=np.greater(obsGrid, 0)
    if mask.sum() > 0:
        lnlike=np.sum(obsGrid[mask]*np.log(predGrid[mask])-predGrid[mask]-gammaln(obsGrid[mask]+1))
    else:
        lnlike=-np.inf
    
    return lnlike

#-------------------------------------------------------------------------------------------------------------
def lnprob_M500(H0, Om0, sigma8, Ob0, ns, tenToA0, B0, Mpivot, sigma_int):
    """Log likelihood function for use with Cobaya, using counts binned on a (log10 mass, z) grid.
    
    Reads the module-level globals selFn and tab, which must have been set before this is called.
    Masses for the objects in tab are recalculated on every call, using the scaling relation 
    parameters held by selFn after it has been updated.
    
    Args:
        H0, Om0, sigma8, Ob0, ns: Cosmological parameters, passed on to selFn.update.
        tenToA0, B0, Mpivot, sigma_int: Scaling relation parameters, passed on to selFn.update via
            its scalingRelationDict argument.
    
    Returns:
        The Poisson log likelihood summed over the grid cells in which both the observed and the 
        predicted counts are greater than zero. Returns -np.inf if selFn.update fails, or if there
        are no such cells.
        
    """
    
    # This can fail if we wander outside the parameter space where the mass function is defined.
    # Only Exception subclasses are treated as such a failure: KeyboardInterrupt and SystemExit are
    # allowed to propagate, so that a run can still be stopped.
    # NOTE: selFn.update is called with (H0, Om0, Ob0, sigma8, ns), which is not the order used in 
    # the signature of this function.
    try:
        selFn.update(H0, Om0, Ob0, sigma8, ns, scalingRelationDict = {'tenToA0': tenToA0, 'B0': B0, 
                                                                      'Mpivot': Mpivot, 'sigma_int': sigma_int})
    except Exception:
        return -np.inf
    
    # Apply completeness only to predicted counts (selection is already applied to observed)
    predMz=selFn.compMz*selFn.mockSurvey.clusterCount
    
    # Now 2d histogram with few cells - more stable?
    # NOTE(review): calcMass returns 0 for objects that fail its cuts, which would make the bin edges
    # below non-finite - this assumes every object in tab has a valid mass; confirm against catalog
    masses=calcMass(tab, selFn.scalingRelationDict, selFn.tckQFitDict, selFn.fRelDict, selFn.mockSurvey, verbose = False)  
    log10MBinEdges=np.linspace(np.log10(masses.min()), np.log10(masses.max()), 6)
    obsGrid, obsGrid_log10MBinEdges, obsGrid_zBinEdges=np.histogram2d(np.log10(masses), tab['redshift'], 
                                                                      bins = [log10MBinEdges, selFn.mockSurvey.zBinEdges])
    predGrid=np.zeros(obsGrid.shape)
    for i in range(len(obsGrid_log10MBinEdges)-1):
        mBinMin=obsGrid_log10MBinEdges[i]
        mBinMax=obsGrid_log10MBinEdges[i+1]
        mMask=np.logical_and(selFn.mockSurvey.log10M >= mBinMin, selFn.mockSurvey.log10M < mBinMax)
        for j in range(len(obsGrid_zBinEdges)-1):
            zBinMin=obsGrid_zBinEdges[j]
            zBinMax=obsGrid_zBinEdges[j+1]
            zMask=np.logical_and(selFn.mockSurvey.z >= zBinMin, selFn.mockSurvey.z < zBinMax)
            # Only the first row of predMz that falls in this z bin is used
            predGrid[i, j]=predMz[zMask][0][mMask].sum()

    # Poisson probability in (log10M, z) grid
    # log(n!) is evaluated as gammaln(n+1): factorial(n) itself overflows to inf for n > 170, which
    # would force the log likelihood to -inf whenever a cell is well populated
    from scipy.special import gammaln
    mask=np.logical_and(np.greater(obsGrid, 0), np.greater(predGrid, 0))
    if mask.sum() > 0:
        lnlike=np.sum(obsGrid[mask]*np.log(predGrid[mask])-predGrid[mask]-gammaln(obsGrid[mask]+1))
    else:
        lnlike=-np.inf

    return lnlike

#------------------------------------------------------------------------------------------------------------
def calcMass(tab, massOptions, tckQFitDict, fRelWeightsDict, mockSurvey, verbose = False):
    """Calculates a mass for each row of the given table, using its fixed_y_c and redshift columns.
    
    Each mass is the 'M500' entry of the dictionary returned by signals.calcMass, called with the 
    mass function debiasing and relativistic corrections switched on, and without error estimation.
    
    Args:
        tab: Table with columns name, tileName, redshift, redshiftErr, fixed_y_c, fixed_err_y_c.
        massOptions: Dictionary with keys tenToA0, B0, Mpivot and sigma_int.
        tckQFitDict: Dictionary, indexed by tileName, passed to signals.calcMass as tckQFit.
        fRelWeightsDict: Dictionary, indexed by tileName, passed to signals.calcMass as 
            fRelWeightsDict.
        mockSurvey: Passed through unchanged to signals.calcMass.
        verbose: If True, print a progress line for every row.
    
    Returns:
        Array with one element per row of tab, holding the 'M500' value multiplied by 1e14
        (presumably converting from units of 1e14 MSun to MSun - confirm against signals.calcMass). 
        Rows with fixed_y_c <= 0 or a NaN redshift are given a mass of zero.
    
    """

    masses=np.zeros(len(tab))
    for i in range(len(tab)):
        row=tab[i]
        # Progress is reported using the 1-based row number
        if verbose: print("... %d/%d; %s (%.3f +/- %.3f) ..." % (i+1, len(tab), row['name'], 
                                                              row['redshift'], row['redshiftErr']))

        tileName=row['tileName']
        
        # Cuts on z, fixed_y_c for forced photometry mode (invalid objects will be listed but without a mass)
        if row['fixed_y_c'] > 0 and np.isnan(row['redshift']) == False:
            # Corrected for mass function steepness
            # (the 1e-4 factors rescale the catalog fixed_y_c columns before they are passed on)
            massDict=signals.calcMass(row['fixed_y_c']*1e-4, row['fixed_err_y_c']*1e-4, 
                                            row['redshift'], row['redshiftErr'],
                                            tenToA0 = massOptions['tenToA0'],
                                            B0 = massOptions['B0'], 
                                            Mpivot = massOptions['Mpivot'], 
                                            sigma_int = massOptions['sigma_int'],
                                            tckQFit = tckQFitDict[tileName], mockSurvey = mockSurvey, 
                                            applyMFDebiasCorrection = True,
                                            applyRelativisticCorrection = True,
                                            fRelWeightsDict = fRelWeightsDict[tileName],
                                            calcErrors = False)
            masses[i]=massDict['M500']
    
    return masses*1e14

#------------------------------------------------------------------------------------------------------------
def makeGetDistPlot(cosmoOutDir):
    """Makes a corner (triangle) plot of H0, Om0 and sigma8 using GetDist.
    
    The chains with root name 'chain' are read from cosmoOutDir, and the plot is written into the 
    same directory as both cornerplot.pdf and cornerplot.png.
    
    """
    
    import getdist.plots as gplot
    
    plotter=gplot.getSubplotPlotter(chain_dir = cosmoOutDir)
    chainRoots=['chain']
    paramNames=['H0', 'Om0', 'sigma8']
    # No third parameter is used for colour-coding the points
    plotter.triangle_plot(chainRoots, paramNames, plot_3d_with_param=None, filled=True, shaded=False)
    
    # Same figure, exported once per output format
    for plotFileName in ("cornerplot.pdf", "cornerplot.png"):
        plotter.export(fname=cosmoOutDir+os.path.sep+plotFileName)
    
#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    
    # Script entry point. In order, this:
    # - parses the command line
    # - builds the Cobaya configuration (or loads it from file / dumps the default and exits)
    # - sets up the selection function and cuts the catalog to match
    # - bins the catalog on a (log10(fixed_y_c), redshift) grid
    # - then either runs the likelihood tests (-t) or runs Cobaya
    # The names selFn, tab, obsGrid, obs_zBinEdges and obs_ycBinEdges defined here are the globals
    # that the likelihood functions read.
    
    defaultMaxSamples=3000
    defaultNumToBurn=50
    parser=argparse.ArgumentParser("nemoCosmo")
    parser.add_argument("catalogFileName", help = """Catalog file name, in .fits format, as produced by nemo
                        or nemoMock.""")
    parser.add_argument("selFnDir", help = """Directory containing files needed for computing the selection 
                        function.""")
    parser.add_argument("-c", "--cobaya-config", dest = "cobayaConfig", help = """A Cobaya .yml configuration 
                        file. Use the -d switch to dump the default settings to a file.""", default = None)
    parser.add_argument("-d", "--dump-default-config", dest = "dumpConfig", help = """Dumps the default Cobaya
                        configuration to a file called default.yml in the current working directory.""",
                        action = "store_true", default = False)
    parser.add_argument("-o", "--output-dir", dest = "cosmoOutDir", help = """Name of directory in which
                        to store output chains (default: cosmo_catalogFileName_SNRCut, where
                        catalogFileName is stripped of the .fits extension).""", default = None)
    parser.add_argument("-S", "--SNR-cut", dest = "SNRCut", help = """Use only clusters with fixed_SNR > 
                        this value.""", default = 5.0, type = float)
    parser.add_argument("-f", "--footprint", dest = "footprint", help="""Footprint to use, e.g., DES,
                        HSC, KiDS (default: full). Note that the catalog will not be trimmed to match 
                        the footprint.""", default = None)
    parser.add_argument("-R", "--relativistic-correction", dest = "relativisticCorrection", 
                        action = "store_true", default = False, help = """Apply relativistic correction
                        to the signal modelling and completeness calculation.""")
    parser.add_argument("-D", "--mass-def", dest = "massDef", default = "M500c", help = """Mass 
                        definition to use (e.g., M500c or M200m).""")
    parser.add_argument("-m", "--max-samples", dest="maxSamples", help="""Maximum number of samples. If 
                        given, overrides the value given in the Cobaya configuration file.""", type = int,
                        default = defaultMaxSamples)
    parser.add_argument("-b", "--burn", dest = "numToBurn", help = """Number of samples to burn. If given,
                        overrides the value given in the Cobaya configuration file.""", 
                        default = defaultNumToBurn, type = int)
    parser.add_argument("-t", "--test-likelihood", dest = "testLikelihood", help = """Run a test of the
                        likelihood, varying one parameter at a time, instead of running Cobaya,
                        producing diagnostic plots called 'testLikelihood_paramName.png' in the current
                        working directory. This test assumes you are running on a mock catalog with 
                        cosmology and scaling relation parameters that match the WebSky simulations.""",
                        action = "store_true", default = False)
    parser.add_argument("-L", "--likelihood", dest = "likelihoodType", help = """The likelihood function
                        to use. Options are 'mass', or 'yc' (latter is even more experimental).""", 
                        default = "mass")
    args = parser.parse_args()
    
    tabFileName=args.catalogFileName
    selFnDir=args.selFnDir
    maxSamples=args.maxSamples
    SNRCut=args.SNRCut
    footprintLabel=args.footprint
    cosmoOutDir=args.cosmoOutDir
    numToBurn=args.numToBurn
    massDef=args.massDef
    if massDef == "M500c":
        rhoType="critical"
        delta=500
    elif massDef == "M200m":
        rhoType="matter"
        delta=200
    else:
        raise Exception("massDef should be either M500c or M200m (use -D M500c or -D M200m)")
    
    if args.likelihoodType == "mass":
        lnprob=lnprob_M500
    elif args.likelihoodType == "yc":
        lnprob=lnprob_yc
    else:
        raise Exception("Didn't understand likelihoodType - use either 'mass' or 'yc'")
        
    tab=atpy.Table().read(tabFileName)
    if 'redshift' not in tab.keys():
        raise Exception("no 'redshift' column in catalog")
        
    # We'll label output according to catalog file name (in case we want to run same settings on real and mocks)
    if cosmoOutDir is None:
        cosmoOutDir="cosmo_"+os.path.split(tabFileName)[-1].replace(".fits", "")+"_%.2f" % (SNRCut)
        
    # Cobaya set-up
    # In the default configuration, parameters given as plain numbers are held fixed, while those 
    # given as dictionaries (H0, Om0, sigma8) are sampled
    if args.cobayaConfig is None:
        info={'sampler': {'mcmc': {'burn_in': defaultNumToBurn, 
                                   'max_samples': defaultMaxSamples, 
                                   'max_tries': np.inf}}}
        # The LaTeX labels are raw strings, so that the backslashes are not read as escape sequences
        info['params']=odict([['Ob0', 0.05],
                              ['tenToA0',4.95e-5],
                              ['B0', 0.08],
                              ['Mpivot', 3.0e+14],
                              ['sigma_int', 0.2],
                              ['H0', {'prior': {'dist': 'norm', 'loc': 70.0, 'scale': 4.0},
                                      'proposal': 5.0,
                                      'latex': 'H_0'}],
                              ['Om0', {'prior': {'min': 0.1, 'max': 0.5}, 
                                       'ref': {'dist': 'norm', 'loc': 0.3, 'scale': 0.1},
                                       'proposal': 0.05,
                                       'latex': r'\Omega_{\rm m0}'}],
                              ['sigma8', {'prior': {'min': 0.6, 'max': 0.9}, 
                                          'ref': {'dist': 'norm', 'loc': 0.8, 'scale': 0.1},
                                           'proposal': 0.02,
                                           'latex': r'\sigma_8'}],
                              ['ns', 0.95]])
        if args.dumpConfig:
            cobaya.yaml.yaml_dump_file("default.yml", info, error_if_exists = False)
            print("Dumped default Cobaya configuration to `default.yml`")
            sys.exit()
    else:
        info=cobaya.yaml.yaml_load_file(args.cobayaConfig)
    
    info['likelihood']={'tsz-n': lnprob}
    info['output']=cosmoOutDir+os.path.sep+"chains"
    info['timing']=True
    info['debug']=True
    info['resume']=True
    
    # Command line values only replace those in the configuration when they differ from the defaults
    if numToBurn != defaultNumToBurn:
        info['sampler']['mcmc']['burn_in']=numToBurn
    if maxSamples != defaultMaxSamples:
        info['sampler']['mcmc']['max_samples']=maxSamples

    print("Setting up SNR > %.2f selection function (footprint: %s)" % (SNRCut, footprintLabel))
    if args.relativisticCorrection:
        print("Relativistic correction will be applied")
    else:
        print("Relativistic correction neglected")
    selFn=completeness.SelFn(selFnDir, SNRCut, footprintLabel = footprintLabel, zStep = 0.1,
                             enableDrawSample = True, mockOversampleFactor = 1, downsampleRMS = True, 
                             applyRelativisticCorrection = args.relativisticCorrection,
                             delta = delta, rhoType = rhoType)
    
    # Cut catalog according to given SNR cut and survey footprint
    tab=tab[np.where(tab['fixed_SNR'] > SNRCut)]
    inMask=selFn.checkCoordsInAreaMask(tab['RADeg'], tab['decDeg'])
    tab=tab[inMask]
    print("%d clusters with fixed_SNR > %.1f in footprint %s" % (len(tab), SNRCut, str(footprintLabel)))
    # Nothing can be binned or fitted without objects - stop here with a clear message
    if len(tab) == 0:
        raise Exception("no clusters left in catalog after applying the SNR cut and footprint mask")

    # NOTE: Since we use such coarse z binning, we completely ignore photo-z errors here
    # Otherwise, signals.calcPMass throws exception when z grid is not fine enough
    tab['redshiftErr'][:]=0.0
   
    # Define binning on (log10(fixed_y_c), redshift) grid
    log10ycBinEdges=np.linspace(np.log10(1e-4*tab['fixed_y_c'].min()*0.8), np.log10(1e-4*tab['fixed_y_c'].max()*1.1), 6)
    obsGrid, obs_log10ycBinEdges, obs_zBinEdges=np.histogram2d(np.log10(1e-4*tab['fixed_y_c']), tab['redshift'], 
                                                               bins = [log10ycBinEdges, selFn.mockSurvey.zBinEdges])
    obs_ycBinEdges=np.power(10, obs_log10ycBinEdges)
    
    # Testing
    if args.testLikelihood:

        def makeTestPlot(pRange, probs, var, label):
            # Plots lnprob against the scanned parameter, with a dashed line at the input value (var)
            plt.plot(pRange, probs)
            minP=min(probs)*1.1
            maxP=max(probs)*0.9
            plt.plot([var]*10, np.linspace(minP, maxP, 10), 'k--')
            plt.xlabel(label)
            plt.ylabel("lnprob")
            plt.ylim(minP, maxP)
            plt.savefig("testLikelihood_%s_%s.png" % (label, args.likelihoodType))
            plt.close()
        
        H0, Om0, Ob0, sigma8, ns = 68.0, 0.31, 0.049, 0.81, 0.965       # WebSky cosmology
        #tenToA0, B0, Mpivot, sigma_int = 2.65e-05, 0.0, 3.0e+14, 0.2    # WebSky scaling relation - assumed scatter
        if massDef == 'M500c':
            tenToA0, B0, Mpivot, sigma_int = 3.02e-05, 0.0, 3.0e+14, 0.0    # WebSky scaling relation - no scatter, M500c
        elif massDef == 'M200m':
            tenToA0, B0, Mpivot, sigma_int = 1.7e-05, 0.0, 3.0e+14, 0.0    # WebSky scaling relation - no scatter, M200m
        testsToRun=['H0', 'Om0', 'sigma8', 'tenToA0', 'B0', 'sigma_int']
        #testsToRun=['tenToA0', 'B0', 'sigma_int']
        # For checking if recovered parameters beat the theoretical max likelihood ones
        # If they do, we have bugs to fix...
        theoreticalMaxLogLike=lnprob(H0, Om0, sigma8, Ob0, ns, tenToA0, B0, Mpivot, sigma_int)
        print("Theoretical max log likelihood = %.3f" % (theoreticalMaxLogLike))
        print("Num clusters expected = %1f" % ((selFn.compMz*selFn.mockSurvey.clusterCount).sum()))
        print("Num clusters in catalog = %d" % (len(tab)))
        mode=input("Drop into interactive mode [y] or run likelihood tests for each parameter [n] ? ")
        if mode == "y":
            # Example of evaluating the likelihood at arbitrary parameters from within the session:
            # testLogLike=lnprob(70.5, 0.304, 0.825, Ob0, ns, 2.78e-05, 0.306, Mpivot, 0.0)
            import IPython
            IPython.embed()
            sys.exit()
        
        # Input (fiducial) values, keyed by the argument names of the likelihood function
        fiducial={'H0': H0, 'Om0': Om0, 'sigma8': sigma8, 'Ob0': Ob0, 'ns': ns, 
                  'tenToA0': tenToA0, 'B0': B0, 'Mpivot': Mpivot, 'sigma_int': sigma_int}
        # Range scanned for each parameter, in the order in which the tests are run
        scanRanges=odict([['H0', np.linspace(60, 80, 21)],
                          ['Om0', np.linspace(0.2, 0.4, 21)],
                          ['sigma8', np.linspace(0.7, 0.9, 21)],
                          ['tenToA0', np.linspace(2.0e-5, 6.0e-05, 21)],
                          ['B0', np.linspace(0.0, 0.5, 21)],
                          ['sigma_int', np.linspace(0.0, 0.5, 21)]])
        for label, pRange in scanRanges.items():
            if label not in testsToRun:
                continue
            var=fiducial[label]
            probs=[]
            t0=time.time()
            print("%s:" % (label))
            for count, p in enumerate(pRange, start = 1):
                print("    %d/%d" % (count, len(pRange)))
                # Vary only the parameter under test, holding all the others at their input values
                trial=dict(fiducial)
                trial[label]=p
                probs.append(lnprob(**trial))
            t1=time.time()
            print("time per step = %.3f" % ((t1-t0)/len(probs)))
            print("diff (input - max likelihood) = %.2f" % (var-pRange[np.argmax(probs)]))
            makeTestPlot(pRange, probs, var, label)
        sys.exit()
        
    # Run Cobaya...
    updated_info, products=run(info)

