#!/usr/bin/env python

"""

nemo driver script: filters maps and finds clusters (run with -h to list the command-line options).

"""

import sys
#print("Running under python: %s" % (sys.version))
import os
import datetime
from nemo import *
import nemo
from nemo import MockSurvey
import argparse
import astropy
import astropy.table as atpy
import astropy.io.fits as pyfits
from astLib import astWCS
import numpy as np
import pylab
import pickle
import types
import yaml
#import IPython
# This is a non-interactive command-line script: turn off matplotlib interactive mode
pylab.matplotlib.interactive(False)
# Apply nemo's matplotlib rcParams settings (project helper from nemo.plotSettings;
# presumably sets the package's default plot style - see that module)
plotSettings.update_rcParams()

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    parser=argparse.ArgumentParser("nemo")
    parser.add_argument("configFileName", help="""A .yml configuration file.""")
    parser.add_argument("-S", "--calc-selection-function", dest="calcSelFn", action="store_true",
                        help="""Calculate the completeness in terms of cluster mass, assuming the scaling
                        relation parameters given in the .yml config file. Output will be written under the
                        nemoOutput/selFn directory. This switch overrides the calcSelFn parameter in the
                        .yml config file.""", default = False)
    parser.add_argument("-I", "--run-source-injection-test", dest="sourceInjectionTest",
                        action="store_true", help="""Run a source injection test, using the settings given
                        in the .yml config file. Output will be written under the nemoOutput/diagnostics
                        (raw data) and nemoOutput/selFn directories (position recovery model fits).
                        This switch overrides the sourceInjectionTest parameter in the .yml config
                        file.""", default = False)
    parser.add_argument("-M", "--mpi", dest="MPIEnabled", action="store_true", help="""Enable MPI. If you 
                        want to use this, run with e.g., mpiexec -np 4 nemo configFile.yml -M""", 
                        default = False)
    parser.add_argument("-n", "--no-strict-errors", dest="noStrictMPIExceptions", action="store_true", 
                        help="""Disable strict exception handling (applies under MPI only, i.e., must be
                        used with the -M switch). If you use this option, you will get the full traceback
                        when a Python Exception is triggered, but the code may not terminate. This is due
                        to the Exception handling in mpi4py.""", 
                        default = False)
    args=parser.parse_args()
    
    if args.noStrictMPIExceptions == True:
        strictMPIExceptions=False
    else:
        strictMPIExceptions=True
    
    parDictFileName=args.configFileName
    config=startUp.NemoConfig(parDictFileName, calcSelFn = args.calcSelFn,
                              sourceInjectionTest = args.sourceInjectionTest, MPIEnabled = args.MPIEnabled,
                              strictMPIExceptions = strictMPIExceptions)
    
    optimalCatalogFileName=config.rootOutDir+os.path.sep+"%s_optimalCatalog.csv" % (os.path.split(config.rootOutDir)[-1])
    if os.path.exists(optimalCatalogFileName) == False:
        optimalCatalog=pipelines.filterMapsAndMakeCatalogs(config)
        #print("... rank = %d; catalog length = %d" % (config.rank, len(optimalCatalog)))
        if config.MPIEnabled == True:
            config.comm.barrier()
            optimalCatalogList=config.comm.gather(optimalCatalog, root = 0)
            if config.rank == 0:
                print("... gathered catalogs ...")
                toStack=[]  # We sometimes return [] if no objects found - we can't vstack those
                for collectedTab in optimalCatalogList:
                    if type(collectedTab) == astropy.table.table.Table:
                        toStack.append(collectedTab)
                optimalCatalog=atpy.vstack(toStack)
                # Strip out duplicates (this is necessary when run in tileDir mode under MPI)
                if len(optimalCatalog) > 0:
                    optimalCatalog, numDuplicatesFound, names=catalogs.removeDuplicates(optimalCatalog)
        if config.rank == 0:
            optimalCatalog=catalogs.flagTileBoundarySplits(optimalCatalog)
            optimalCatalog.sort('name')
            catalogs.writeCatalog(optimalCatalog, optimalCatalogFileName)
            catalogs.writeCatalog(optimalCatalog, optimalCatalogFileName.replace(".csv", ".fits"))
            addInfo=[{'key': 'SNR', 'fmt': '%.1f'}]
            catalogs.catalog2DS9(optimalCatalog, optimalCatalogFileName.replace(".csv", ".reg"),
                                 addInfo = addInfo, color = "cyan")
    else:
        if config.rank == 0: print("... already made catalog %s ..." % (optimalCatalogFileName))
                
    # Q function (filter mismatch) - if needed options have been given
    # We may as well do this here to save having to run nemoMass separately (though we still can...)
    # (it's the Q calc in nemoMass that takes time - but it's only a couple of min per field in parallel)
    if 'photFilter' in config.parDict.keys() and config.parDict['photFilter'] is not None and config.parDict['fitQ'] == True:
        if os.path.exists(config.selFnDir+os.path.sep+"QFit.fits") == False:
            signals.fitQ(config)

    # Source injection tests for quantifying position recovery accuracy and noise bias
    if 'sourceInjectionTest' in config.parDict.keys() and config.parDict['sourceInjectionTest'] == True:
        if config.MPIEnabled == True:
            config.comm.barrier()   # Otherwise, some processes can begin before catalog written to disk and then crash
        sourceInjTable=maps.sourceInjectionTest(config)
        if config.MPIEnabled == True:
            config.comm.barrier()
            sourceInjTableList=config.comm.gather(sourceInjTable, root = 0)
            if config.rank == 0:
                print("... gathered source injection test results ...")
                toStack=[]
                for sourceInjTable in sourceInjTableList:
                    if type(sourceInjTable) == astropy.table.table.Table and len(sourceInjTable) > 0:
                        toStack.append(sourceInjTable)
                if len(toStack) > 0:
                    sourceInjTable=atpy.vstack(toStack)
                else:
                    sourceInjTable=None
    else:
        sourceInjTable=None
            
    ## Estimate of contamination from running cluster finding over inverted map
    #if 'estimateContaminationFromInvertedMaps' in list(config.parDict.keys()) and config.parDict['estimateContaminationFromInvertedMaps'] == True:
        #conTabDict=maps.estimateContaminationFromInvertedMaps(config, imageDict)
    #else:
        #conTabDict={}
        
    ## Estimate of contamination by generating a fake sky with noise, and running detection algorithm over it
    ## Ultimately we want this and the above ^^^ to appear on the same plot for comparison
    #if 'estimateContaminationFromSkySim' in list(config.parDict.keys()) and config.parDict['estimateContaminationFromSkySim'] == True:
        #skySimConTabDict=maps.estimateContaminationFromSkySim(config, imageDict) 
    #else:
        #skySimConTabDict={}
    
    ## This just combines inverted maps contamination results and skySim (under different keys)
    ## So we only feed one dictionary into the plotting routine (see below)
    #for k in list(skySimConTabDict.keys()):
        #conTabDict[k]=skySimConTabDict[k]
        
    # Moved calculation of selection function parts here as it's very quick in parallel
    if 'calcSelFn' in list(config.parDict.keys()) and config.parDict['calcSelFn'] == True:
        # Since a fiducial cosmology (OmegaM0 = 0.3, OmegaL0 = 0.7, H0 = 70 km/s/Mpc) was used in the object detection/filtering stage, we use the same one here      
        minMass=8e13
        areaDeg2=400.0  # Don't care what value this has, as we'll sim up an arbitrary number of clusters anyway
        zMin=0.0
        zMax=2.0
        H0=70.
        Om0=0.30
        Ob0=0.05
        sigma8=0.8
        ns=0.95
        mockSurvey=MockSurvey.MockSurvey(minMass, areaDeg2, zMin, zMax, H0, Om0, Ob0, sigma8, ns, enableDrawSample = True)
        selFnCollection=pipelines.makeSelFnCollection(config, mockSurvey)
        if config.MPIEnabled == True:
            config.comm.barrier()
            gathered_selFnCollections=config.comm.gather(selFnCollection, root = 0)
            if config.rank == 0:
                print("... gathered selection function results ...")
                all_selFnCollection={'full': []}
                for key in selFnCollection.keys():
                    if key not in all_selFnCollection.keys():
                        all_selFnCollection[key]=[]
                for selFnCollection in gathered_selFnCollections:
                    for key in all_selFnCollection.keys():
                        all_selFnCollection[key]=all_selFnCollection[key]+selFnCollection[key]
                selFnCollection=all_selFnCollection
    else:
        selFnCollection={}
        
    # Tidying up etc.
    if config.rank == 0:

        # Stitch together map tiles - full fat versions (this will only work 'saveFilteredMaps: True')
        if 'stitchTiles' in config.parDict.keys() and config.parDict['stitchTiles'] == True:
            maps.stitchTiles(config)

        # Stitch together map tiles - 'quicklook' images (downsampled by factor 4 in resolution)
        # The stitchTilesQuickLook routine will only write output if there are multiple maps matching the file pattern
        if 'makeQuickLookMaps' in config.parDict.keys() and config.parDict['makeQuickLookMaps'] == True:
            # RMS (we don't save quicklook to selFnDir - users should use full-fat maps to be sure)
            if os.path.exists(config.diagnosticsDir+os.path.sep+"quicklook_RMSMap.fits") == False:
                maps.stitchTilesQuickLook(config.selFnDir+os.path.sep+"RMSMap*.fits",
                                          config.diagnosticsDir+os.path.sep+"quicklook_RMSMap.fits",
                                          config.quicklookWCS, config.quicklookShape, fluxRescale = config.quicklookScale)
            else:
                print("... already made %s ..." % (config.diagnosticsDir+os.path.sep+"quicklook_RMSMap.fits"))
            # S/N maps at reference filter scale
            quicklookSNMapPath=config.filteredMapsDir+os.path.sep+"quicklook_%s_SNMap.fits" % (config.parDict['photFilter'])
            if config.parDict['photFilter'] is not None and os.path.exists(quicklookSNMapPath) == False:
                maps.stitchTilesQuickLook(config.filteredMapsDir+os.path.sep+"*"+os.path.sep+"%s*SNMap.fits" % (config.parDict['photFilter']),
                                        quicklookSNMapPath, config.quicklookWCS, config.quicklookShape,
                                        fluxRescale = config.quicklookScale)
        
        # Plot tile-averaged position recovery test
        if sourceInjTable is not None:
            sourceInjTable.meta['NEMOVER']=nemo.__version__
            sourceInjTable.write(config.diagnosticsDir+os.path.sep+"sourceInjectionData.fits", overwrite = True) 
            maps.positionRecoveryAnalysis(sourceInjTable, config.diagnosticsDir+os.path.sep+"positionRecovery.pdf",
                                          percentiles = [50, 95, 99.7], plotRawData = True,
                                          pickleFileName = config.diagnosticsDir+os.path.sep+'positionRecovery.pkl',
                                          selFnDir = config.selFnDir)
            #maps.noiseBiasAnalysis(sourceInjTable, config.diagnosticsDir+os.path.sep+"noiseBias.pdf")
            
        ## Plot contamination together
        #if conTabDict != {}:
            #maps.plotContamination(conTabDict, config.diagnosticsDir)           

        # Cache file containing weights for relativistic corrections
        # Saves doing this later (e.g., when nemoMass or nemoSelFn run) and it's quick to do
        signals.getFRelWeights(config)

        if selFnCollection != {}:
            # Survey completeness stats now all lumped together in one routine
            # This also make survey-averaged (M, z) grid(s) as used by e.g. HSC lensing analysis
            completeness.completenessByFootprint(selFnCollection, mockSurvey, config.diagnosticsDir, 
                                                additionalLabel = "_"+config.parDict['selFnOptions']['method'].replace(" ", "_"))
            # If we made mass limit maps...
            # ... make cumulative area versus mass limit plot(s)
            # ... and downsampled full area (untiled) map(s) and plot(s) of the mass limit
            if 'massLimitMaps' in config.parDict['selFnOptions'].keys():
                print(">>> Making cumulative area plots and full survey mass limit plots ...")
                for massLimitDict in config.parDict['selFnOptions']['massLimitMaps']:
                    completeness.cumulativeAreaMassLimitPlot(massLimitDict['z'], config.diagnosticsDir, config.selFnDir, config.allTileNames) 
                    completeness.makeFullSurveyMassLimitMapPlot(massLimitDict['z'], config)
            
        # Tidy up by making MEF files and deleting the (potentially 100s) of per-tile files made
        # This also puts Q fits into one file - so we need to run this regardless
        completeness.tidyUp(config)
        
    
