#!/usr/bin/env python

"""

Quick check of one catalog against another - which clusters (in some .fits table) are missing against a 
cluster or candidate list? We can also check survey masks while we're at it. And spit out mass limit 
information at positions of missing clusters.

"""

import os
import sys
import nemo
from nemo import startUp
from nemo import completeness
from nemo import catalogs
import astropy.table as atpy
from astLib import *
import numpy as np
import time
import argparse
#import IPython

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Build the command-line argument parser for nemoCatalogCheck.

    The numeric options (match radius, fixed_SNR cut) are parsed as floats, so values given on the
    command line can be used directly in arithmetic.

    Returns:
        An argparse.ArgumentParser object.

    """
    parser=argparse.ArgumentParser("nemoCatalogCheck")
    parser.add_argument("configFileName", help="""A .yml configuration file. By default, the correspondng nemo 
                        output is assumed to be in a directory named configFileName, minus the .yml 
                        extension.""")
    parser.add_argument("catalogFileName", help = """Object catalog to check against nemo output (.fits format).
                        The catalog must contain at least the following columns (alternatives given in 
                        brackets): name, RADeg (or ra, RA), decDeg (or dec, DEC).""")
    parser.add_argument("-r", "--match-radius", dest="matchRadiusArcmin", help = """Cross-matching radius in 
                        arcmin.""", default = 2.5, type = float)
    parser.add_argument("-S", "--fixed-SNR-cut", dest="fixedSNRCut", help = """Cut in fixed_SNR used to select
                        nemo cluster candidates.""", default = 4.0, type = float)
    return parser

#------------------------------------------------------------------------------------------------------------
def makeOutFileName(catFileName, suffix):
    """Make the name of an output .fits table from the name of the input catalog.

    Args:
        catFileName (str): Path to the input catalog. Only the file name part is used, so the output
            is written relative to the current working directory.
        suffix (str): Text to insert before the ".fits" extension.

    Returns:
        Output file name (str). If the input file name contains ".fits", each occurrence is replaced
        by suffix + ".fits". Otherwise the final extension (if any) is dropped and suffix + ".fits" is
        appended, so that the result can never be the same as the input file name.

    """
    baseName=os.path.split(catFileName)[-1]
    if ".fits" in baseName:
        return baseName.replace(".fits", "%s.fits" % (suffix))
    return os.path.splitext(baseName)[0]+"%s.fits" % (suffix)

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    args=makeParser().parse_args()

    configFileName=args.configFileName
    catFileName=args.catalogFileName
    config=startUp.NemoConfig(configFileName, setUpMaps = False, MPIEnabled = False, verbose = False)
    outputLabel=os.path.split(configFileName)[-1].replace(".yml", "")

    print(">>> Checking catalog %s against nemo output:" % (catFileName))
    optimalCatalogFileName=config.rootOutDir+os.path.sep+"%s_optimalCatalog.fits" % (os.path.split(config.rootOutDir)[-1])
    checkAgainst=atpy.Table.read(optimalCatalogFileName)

    tab=atpy.Table.read(catFileName)
    RAKey, decKey=catalogs.getTableRADecKeys(tab)
    # Negative RA values are wrapped around into the range 0-360 deg
    negRAMask=np.less(tab[RAKey], 0)
    tab[RAKey][negRAMask]=360-abs(tab[RAKey][negRAMask])

    # Use selFn routines to determine if objects were in the valid area mask
    selFn=completeness.SelFn(config.selFnDir, args.fixedSNRCut, configFileName = args.configFileName,
                             setUpAreaMask = True, enableCompletenessCalc = False)
    inMask=selFn.checkCoordsInAreaMask(tab[RAKey], tab[decKey])
    maxPossibleMatches=inMask.sum()
    tab.add_column(atpy.Column(inMask, "inMask"))

    print("... %d/%d objects in %s are in the valid area mask for %s ..." % (maxPossibleMatches, len(tab), 
                                                                             catFileName, config.rootOutDir))

    # Objects that fall inside the valid area mask get written out, whether or not they are matched
    outFileName=makeOutFileName(catFileName, "_inMask_%s" % (outputLabel))
    withinMaskTab=tab[inMask]
    withinMaskTab.meta['NEMOVER']=nemo.__version__
    withinMaskTab.write(outFileName, overwrite = True)

    # Cross matching - an object is flagged as missing if nothing in the nemo catalog is within the
    # match radius of it
    xMatchRadiusDeg=args.matchRadiusArcmin/60.
    refRADeg=checkAgainst['RADeg'].data
    refDecDeg=checkAgainst['decDeg'].data
    if len(refRADeg) == 0:
        # Taking the minimum separation would fail on an empty nemo catalog - and in that case nothing
        # can be matched anyway
        missing=np.ones(len(tab), dtype = bool)
    else:
        missing=np.array([astCoords.calcAngSepDeg(row[RAKey], row[decKey], refRADeg, refDecDeg).min() > xMatchRadiusDeg
                          for row in tab], dtype = bool)
    # Only objects inside the valid area mask could possibly have been detected
    missTab=tab[np.logical_and(missing, np.asarray(tab['inMask'], dtype = bool))]
    print("... %d/%d maximum possible matches in %s are found within %.1f arcmin of an object in the %s catalog" % (maxPossibleMatches-len(missTab), maxPossibleMatches, catFileName, args.matchRadiusArcmin, config.rootOutDir))
    print("... %d/%d maximum possible matches in %s are NOT found within %.1f arcmin of an object in the %s catalog" % (len(missTab), maxPossibleMatches, catFileName, args.matchRadiusArcmin, config.rootOutDir))

    # We could add mass limits at each location here if we wanted...
    outFileName=makeOutFileName(catFileName, "_missed_in_%s" % (os.path.split(optimalCatalogFileName)[-1].replace(".fits", "")))
    missTab.meta['NEMOVER']=nemo.__version__
    missTab.write(outFileName, overwrite = True)
    print("... written missed objects table to %s" % (outFileName))

    # DS9 region file - labels come from the first ID-like column that is found in the table
    regFileName=outFileName.replace(".fits", ".reg")
    idKeyToUse=next((key for key in ['name', 'id', 'ID', 'Name', 'NAME', 'Cluster'] if key in missTab.keys()), None)
    if idKeyToUse is not None:
        catalogs.catalog2DS9(missTab, regFileName, idKeyToUse = idKeyToUse)
        print("... written missed objects DS9 region file to %s" % (regFileName))
    else:
        print("... didn't write DS9 region file as could not find an 'id' or 'name' column to use for labels")
