#!/usr/bin/env python

"""

Calculate mass completeness limits, assuming the RMS map(s) are correct

"""

import os
import sys
import resource
import glob
import numpy as np
import pylab as plt
import astropy.table as atpy
from astLib import *
from scipy import stats
from scipy import interpolate
from scipy import ndimage
from scipy import optimize
from nemo import signals
from nemo import maps
from nemo import MockSurvey
from nemo import completeness
from nemo import plotSettings
from nemo import startUp
from nemo import pipelines
import argparse
import types
import pickle
import astropy.io.fits as pyfits
import time
import yaml
#import IPython
plt.matplotlib.interactive(False)

# If want to catch warnings as errors...
#import warnings
#warnings.filterwarnings('error')
    
#------------------------------------------------------------------------------------------------------------
def mergeSelFnCollections(gatheredSelFnCollections):
    """Combine the selection function collections gathered from every MPI rank into one.

    Args:
        gatheredSelFnCollections (list): One selFnCollection dictionary per rank, as returned on the
            root rank by ``comm.gather`` (i.e., in rank order). Each maps a label (e.g., 'full') to a
            list of results.

    Returns:
        A new dictionary holding the union of all keys found in the inputs ('full' is always present
        and comes first, followed by the remaining keys in order of first appearance). Each value is
        the concatenation, in rank order, of the lists stored under that key. A rank that lacks a key
        simply contributes nothing to it. The input dictionaries and their lists are not modified.

    """
    merged={'full': []}
    # Take the union of keys over all ranks, rather than only the root rank's keys, so that results
    # held only by other ranks are not silently dropped (and missing keys do not raise KeyError)
    for collection in gatheredSelFnCollections:
        for key in collection:
            merged.setdefault(key, [])
    for collection in gatheredSelFnCollections:
        for key, results in merged.items():
            results.extend(collection.get(key, []))
    return merged

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    parser=argparse.ArgumentParser("nemoSelFn")
    parser.add_argument("configFileName", help="""A .yml configuration file.""")
    parser.add_argument("-M", "--mpi", dest="MPIEnabled", action="store_true", help="""Enable MPI. If you 
                        want to use this, run with e.g., mpiexec -np 4 nemoSelFn configFile.yml -M""",
                        default = False)
    args = parser.parse_args()

    parDictFileName=args.configFileName
    config=startUp.NemoConfig(parDictFileName, MPIEnabled = args.MPIEnabled)

    # The object detection/filtering stage used a fiducial cosmology (OmegaM0 = 0.3, OmegaL0 = 0.7,
    # H0 = 70 km/s/Mpc), so the same one is used here for consistency
    minMass=8e13
    areaDeg2=400.0  # The value is arbitrary: we simulate as many clusters as we need anyway
    zMin=0.0
    zMax=2.0
    H0=70.
    Om0=0.30
    Ob0=0.05
    sigma8=0.8
    ns=0.95
    mockSurvey=MockSurvey.MockSurvey(minMass, areaDeg2, zMin, zMax, H0, Om0, Ob0, sigma8, ns, enableDrawSample = True)

    selFnCollection=pipelines.makeSelFnCollection(config, mockSurvey)

    # MPI: non-root ranks hand their results to rank 0 and exit; rank 0 merges everything so that
    # survey-wide averages can be computed below
    if config.MPIEnabled:
        gathered_selFnCollections=config.comm.gather(selFnCollection, root = 0)
        if config.rank != 0:
            assert gathered_selFnCollections is None
            print("... MPI rank %d finished ..." % (config.rank))
            sys.exit()
        print("... gathering selection function results ...")
        selFnCollection=mergeSelFnCollections(gathered_selFnCollections)

    selFnOptions=config.parDict['selFnOptions']

    # Survey completeness stats are now all lumped together in one routine.
    # This also makes survey-averaged (M, z) grid(s) as used by e.g. HSC lensing analysis
    completeness.completenessByFootprint(selFnCollection, mockSurvey, config.diagnosticsDir,
                                         additionalLabel = "_"+selFnOptions['method'].replace(" ", "_"))

    # If we made mass limit maps...
    # ... make cumulative area versus mass limit plot(s)
    # ... and downsampled full area (untiled) map(s) and plot(s) of the mass limit
    if 'massLimitMaps' in selFnOptions:
        print(">>> Making cumulative area plots and full survey mass limit plots ...")
        for massLimitDict in selFnOptions['massLimitMaps']:
            completeness.cumulativeAreaMassLimitPlot(massLimitDict['z'], config.diagnosticsDir, config.selFnDir, config.allTileNames)
            completeness.makeFullSurveyMassLimitMapPlot(massLimitDict['z'], config)

    # Tidy up by making MEF files and deleting the (potentially 100s of) per-tile files made
    completeness.tidyUp(config)
