#!/usr/bin/env python

"""

Make a sky model image containing the objects listed in a nemo output catalog, using the same pixelization as a given mask image

"""

import os
import sys
import numpy as np
import astropy.table as atpy
from astLib import *
from nemo import startUp
from nemo import maps
import argparse
import astropy.io.fits as pyfits

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    parser=argparse.ArgumentParser("nemoModel")
    parser.add_argument("catalogFileName", help = """A catalog (FITS table format) as produced by nemo.""")
    parser.add_argument("maskFileName", help = """A FITS image file, containing a mask of the desired sky
                        area. Non-zero values in the mask are used to define tiles (typically 10 x 5 deg),
                        which are processed in parallel if MPI is enabled (see -M switch). The output sky
                        model image will have the same pixelization as the mask image.""")
    parser.add_argument("beamFileName", help = """A file containing the beam profile, in the standard format
                        used by ACT.""")
    parser.add_argument("outputFileName", help = """The name of the output file that will contain the sky
                        model image.""")
    parser.add_argument("-f", "--frequency-GHz", dest = "obsFreqGHz", type = float, default = 150.0,
                        help = """If the nemo catalog contains SZ-selected clusters, the SZ signal will be
                        evaluted at the given frequency, ignoring relativistic effects 
                        (default: 150.0).""")
    parser.add_argument("-s", "--scale-signals", dest = "scale", type = float, default = 0.0,
                        help = """Scale the input y_c values of clusters in the catalog by this factor.""")
    parser.add_argument("-M", "--mpi", dest="MPIEnabled", action="store_true", help="""Enable MPI. If used,
                        the image will be broken into a number of tiles, with one tile per process. If you
                        want to use this, run with e.g., mpiexec -np 4 nemoModel args -M""", 
                        default = False)
    parser.add_argument("-n", "--no-strict-errors", dest="noStrictMPIExceptions", action="store_true",
                        help="""Disable strict exception handling (applies under MPI only, i.e., must be
                        used with the -M switch). If you use this option, you will get the full traceback
                        when a Python Exception is triggered, but the code may not terminate. This is due
                        to the Exception handling in mpi4py.""",
                        default = False)
    args = parser.parse_args()
    
    if args.noStrictMPIExceptions == True:
        strictMPIExceptions=False
    else:
        strictMPIExceptions=True

    # Create a stub config (then we can use existing start-up stuff to dish out the tiles)
    print(">>> Setting up ...")
    parDict={}
    mapDict={}
    mapDict['mapFileName']=args.maskFileName
    mapDict['obsFreqGHz']=args.obsFreqGHz
    mapDict['beamFileName']=args.beamFileName
    parDict['unfilteredMaps']=[mapDict]
    parDict['mapFilters']=[]
    #if args.MPIEnabled == True:
    parDict['makeTileDir']=True
    parDict['tileOverlapDeg']=1.0
    parDict['tileDefLabel']='auto'
    parDict['tileDefinitions']={'mask': args.maskFileName,
                                'targetTileWidthDeg': 10.0, 'targetTileHeightDeg': 5.0}

    config=startUp.NemoConfig(parDict, MPIEnabled = args.MPIEnabled, divideTilesByProcesses = True,
                              makeOutputDirs = False, setUpMaps = True, writeTileDir = False, 
                              verbose = False, strictMPIExceptions = strictMPIExceptions)
    
    tab=atpy.Table().read(args.catalogFileName)
    
    # Optional signal scaling (useful for diff alpha sims)
    if 'y_c' in tab.keys():
        tab['y_c']=tab['y_c']*args.scale
    
    # Build a dictionary containing the model images which we'll stich together at the end
    print(">>> Building models in tiles ...")
    modelsDict={}
    for tileName in config.tileNames:
        print("... %s ..." % (tileName))
        shape=(config.tileCoordsDict[tileName]['clippedSection'][3]-config.tileCoordsDict[tileName]['clippedSection'][2], 
               config.tileCoordsDict[tileName]['clippedSection'][1]-config.tileCoordsDict[tileName]['clippedSection'][0])
        wcs=astWCS.WCS(config.tileCoordsDict[tileName]['header'], mode = 'pyfits')
        try:
            modelsDict[tileName]=maps.makeModelImage(shape, wcs, tab, args.beamFileName, 
                                                     obsFreqGHz = args.obsFreqGHz,
                                                     validAreaSection = config.tileCoordsDict[tileName]['areaMaskInClipSection'])
        except:
            raise Exception("makeModelImage failed on tile '%s'" % (tileName))

    # Gathering
    #if config.MPIEnabled == True:
        #config.comm.barrier()
        #gathered_modelsDicts=config.comm.gather(modelsDict, root = 0)
        #if config.rank == 0:
            #print("... gathered sky model tiles ...")
            #for tileModelDict in gathered_modelsDicts:
                #for tileName in tileModelDict.keys():
                    #modelsDict[tileName]=tileModelDict[tileName]
    
    # We can't just gather as that triggers a 32-bit overflow (?)
    # So, sending one object at a time
    if config.MPIEnabled == True:
        config.comm.barrier()
        if config.rank > 0:
            print("... rank = %d sending sky model tiles ..." % (config.rank))
            config.comm.send(modelsDict, dest = 0)
        elif config.rank == 0:
            print("... gathering sky model tiles ...")
            gathered_modelsDicts=[]
            gathered_modelsDicts.append(modelsDict)
            for source in range(1, config.size):
                gathered_modelsDicts.append(config.comm.recv(source = source))
            for tileModelDict in gathered_modelsDicts:
                for tileName in tileModelDict.keys():
                    modelsDict[tileName]=tileModelDict[tileName]

    # Stitching
    print(">>> Stitching tiles ...")
    if config.rank == 0:
        d=np.zeros([config.origWCS.header['NAXIS2'], config.origWCS.header['NAXIS1']])
        wcs=config.origWCS
        for tileName in modelsDict.keys():
            print("... %s ..." % (tileName))
            minX, maxX, minY, maxY=config.tileCoordsDict[tileName]['clippedSection']
            if modelsDict[tileName] is not None:
                d[minY:maxY, minX:maxX]=d[minY:maxY, minX:maxX]+modelsDict[tileName]
        astImages.saveFITS(args.outputFileName, d, wcs)
