#!/usr/bin/env python3

import argparse
import sys
import logging

from natorbs import version as natorbs_version
from natorbs import read_molecule, NaturalOrbitals, NOCV, FragmentMullikenAnalysis, DiffDensOrbitals, mulliken_contribs, mutual_overlaps


#### main ####
# Command-line interface. The analysis mode is stored in args.target
# ('natural' unless one of --spin/--nocv/--diff/--ddo/--misc is given);
# the log level is stored in args.verbose_level (WARNING unless -v/-g).
parser = argparse.ArgumentParser(
    description='Computes natural orbitals (NOs) and related stuff',
    epilog="""'INFILE' is either a Molden format file or anything parsable by
cclib (tested with Gaussian). One INFILE is required for the default usage
(natural orbitals), --spin, and --misc with --mulliken. At least two INFILEs
are required for --nocv, --ddo, and --diff (the first one describes the whole
molecule, the remaining ones its fragments). Two INFILEs are required for
--misc with --overlaps.""")

# Input files and general options.
parser.add_argument('orbfile', metavar='INFILE', type=str, nargs='+',
                    help='file(s) with input orbitals')
parser.add_argument('-o', '--output', metavar='OUTFILE', type=str,
                    help='output file for saving orbitals in Molden format '
                         '(default: append \'.natorbs.out\' to the (first) INFILE name)')
parser.add_argument('-t', '--thresh', metavar='THRESHOLD', type=float,
                    default=0.1,
                    help='threshold on eigenvalues used to select the orbitals '
                         'written to the output file (default: %(default)s)')

# Analysis mode: at most one may be given; natural orbitals when none is.
gp_target = parser.add_mutually_exclusive_group()
gp_target.add_argument('-s', '--spin', action='store_const',
                       dest='target', const='spin', default='natural',
                       help='generate spin natural orbitals (default: NOs)')
gp_target.add_argument('--nocv', action='store_const',
                       dest='target', const='nocv', default='natural',
                       help='generate NOCVs (default: NOs)')
gp_target.add_argument('--diff', action='store_const',
                       dest='target', const='diff', default='natural',
                       help='differential Mulliken analysis (exptl)')
gp_target.add_argument('--ddo', action='store_const',
                       dest='target', const='ddo', default='natural',
                       help='generate DDOs (default: NOs)')
gp_target.add_argument('--misc', action='store_const',
                       dest='target', const='misc', default='natural',
                       help='do some other analysis (default: generate NOs)')

# Sub-analyses for --misc (mutually exclusive).
gp_misc = parser.add_mutually_exclusive_group()
gp_misc.add_argument('-p', '--mulliken', metavar='list_of_orbitals', type=str,
                     help='(with --misc) compute Mulliken atomic shares for selected orbitals. Argument \'list_of_orbitals\' should be comma separated list of orbitals, e.g., \'136,145\' or \'213\'')
gp_misc.add_argument('--overlaps', metavar='range_of_orbitals', type=str,
                     help='(with --misc) compute overlaps between a given range of orbitals in two sets for the same molecule. Argument \'range_of_orbitals\' should be two numbers separated with a dash, e.g. \'8-10\'')  # FIXME: later we must check consistency

# Orbital-handling switches.
parser.add_argument('-1', '--ignore-symmetry', action='store_true',
                    help='ignore symmetry, i.e., merge orbitals of all irreps into one set')
parser.add_argument('-r', '--rebase-coeffs', action='store_true',
                    help='rebase LCAO coeffs (needed for CRYSTAL)')

# Verbosity: -v -> INFO, -g -> DEBUG, neither -> WARNING.
gp_verbosity = parser.add_mutually_exclusive_group()
gp_verbosity.add_argument('-v', '--verbose', help='set loglevel verbose',
                          action='store_const', dest='verbose_level',
                          const=logging.INFO, default=logging.WARNING)
gp_verbosity.add_argument('-g', '--debug', help='set loglevel debug',
                          action='store_const', dest='verbose_level',
                          const=logging.DEBUG, default=logging.WARNING)
parser.add_argument('--version', action='version', version='%(prog)s v' + str(natorbs_version))

args = parser.parse_args()

# Logger setup: records of the "natorbs" logger go to stderr as
# "[natorbs LEVEL] message", filtered at the level chosen with -v / -g.
log_format = logging.Formatter("[%(name)s %(levelname)s] %(message)s")
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(log_format)

logger = logging.getLogger("natorbs")
logger.setLevel(args.verbose_level)
logger.addHandler(handler)

def _usage_error(message):
    """Print the usage line and an error *message* to stderr, then exit with status 1.

    Never returns: ``parser.exit`` raises ``SystemExit``.
    """
    parser.print_usage(sys.stderr)
    parser.exit(1, '\nERROR: %s\n' % message)


def _output_path():
    """Return the Molden output path: --output if given, else '<first INFILE>.natorbs.out'."""
    if args.output:
        return args.output
    return args.orbfile[0] + '.natorbs.out'


# --mulliken / --overlaps are only consulted in the 'misc' branch below.
if args.target != 'misc' and (args.mulliken is not None or args.overlaps is not None):
    logger.warning('--mulliken/--overlaps have an effect only together with --misc; ignoring them')

if args.target in ('natural', 'spin'):
    # (Spin) natural orbitals of a single molecule, written in Molden format.
    if len(args.orbfile) != 1:
        _usage_error('There should be only one INFILE file for natural/spin orbitals')

    logger.info("Reading molecule from file: %s", args.orbfile[0])
    mol = read_molecule(args.orbfile[0],
                        loglevel=args.verbose_level, rebase_coeffs=args.rebase_coeffs)

    spin = args.target == 'spin'
    if spin:
        logger.info("Spin density orbitals will be computed")
    else:
        logger.info("Natural orbitals will be computed")
    result = NaturalOrbitals(
        mol, spin, ignore_symmetry=args.ignore_symmetry, loglevel=args.verbose_level)

    logger.debug("Writing results")
    if spin:
        # NOTE(review): negative eigenvalues are dropped here, unlike the abs()
        # filter used for NOCV/DDO output below -- confirm this is intended.
        filter_func = lambda o, e: o >= args.thresh
    else:
        # Skip orbitals whose occupation is within thresh of 0 or of 2
        # (presumably occupations lie in [0, 2]).
        filter_func = lambda o, e: o >= args.thresh and 2 - o >= args.thresh
    with open(_output_path(), 'w') as outfile:
        result.write_molden_format(outfile, filter_func)
    sys.exit(0)

elif args.target in ('nocv', 'ddo', 'diff'):
    # Fragment-based analyses: first INFILE is the molecule, the rest are fragments.
    if len(args.orbfile) < 2:
        _usage_error('At least two INFILEs are required for --ddo, --diff, or --nocv')

    mol = []   # 0 = molecule, [1:] = promolecule
    for i, inputfilename in enumerate(args.orbfile):
        logger.info("Reading molecule (%d) from file: %s", i, inputfilename)
        mol.append(read_molecule(
            inputfilename, loglevel=args.verbose_level, rebase_coeffs=args.rebase_coeffs))
    fragments = mol[1:]

    if args.target == 'nocv':
        logger.info("NOCV mode selected with %d fragments.", len(fragments))
        result = NOCV(mol[0], fragments, loglevel=args.verbose_level)

    elif args.target == 'diff':  # FIXME: what is this?
        logger.info(
            "Diff. Mulliken analysis (experimental) with %d fragments.", len(fragments))
        result = FragmentMullikenAnalysis(mol[0], fragments,
                                          args.verbose_level, "", spin=False)
        # FIXME: find a way to pass spin=True

        # FIXME: depends on where flows're written (as Ene or Occ)!
        # At the moment we chose to write them as ENERGIES,
        # but it might be changed later.
        print(sum(result.moenergies[0]))

    else:  # args.target == 'ddo'
        logger.info("DDO mode (experimental) with %d fragments.", len(fragments))
        result = DiffDensOrbitals(mol[0], fragments, loglevel=args.verbose_level)
        # Materialize the filtered values: a bare filter() object would print
        # as "<filter object ...>" in Python 3.
        significant = [x for x in result.mooccups[0] if abs(x) > 0.05]
        print(significant)
        print(sum(result.mooccups[0]))
        print(sum(significant))

    logger.debug('Writing results')
    # Use the same default file name as the NO branch when -o is not given.
    with open(_output_path(), 'w') as outfile:
        result.write_molden_format(outfile, lambda o, e: abs(o) >= args.thresh)
    sys.exit(0)

elif args.target == 'misc':
    if args.mulliken is None and args.overlaps is None:
        _usage_error('either --mulliken or --overlaps is required when --misc is used')

    if args.mulliken is not None:
        if len(args.orbfile) > 1:
            _usage_error('Option --mulliken can be used only with one INFILE')

        logger.debug("Determining list of MOs")
        try:
            # 1-based orbital numbers on the command line -> 0-based indices.
            list_of_mos = [int(x) - 1 for x in args.mulliken.split(',')]
        except ValueError:
            _usage_error("--mulliken expects a comma separated list of orbital numbers, got '%s'"
                         % args.mulliken)
        if any(idx < 0 for idx in list_of_mos):
            _usage_error('orbital numbers given to --mulliken must be >= 1')

        logger.info("Reading molecule from file: %s", args.orbfile[0])
        mol = read_molecule(args.orbfile[0],
                            loglevel=args.verbose_level, rebase_coeffs=args.rebase_coeffs)
        logger.info("Mulliken population analysis of selected MOs")
        r = mulliken_contribs(mol, list_of_mos)
        print(r)
        sys.exit(0)

    else:  # --overlaps
        if len(args.orbfile) != 2:
            _usage_error('Option --overlaps requires two INFILEs')

        logger.debug("Determining range of MOs")
        try:
            # 'first-last' (1-based, inclusive) -> 0-based (i1, i2); unpacking
            # raises ValueError unless there are exactly two numbers.
            i1, i2 = (int(x) - 1 for x in args.overlaps.split('-'))
        except ValueError:
            _usage_error("--overlaps expects two orbital numbers separated with a dash, got '%s'"
                         % args.overlaps)
        if i1 < 0 or i2 < i1:
            _usage_error("invalid orbital range for --overlaps: '%s'" % args.overlaps)
        range_of_mos = (i1, i2)

        logger.info("Reading first set of orbitals from file: %s", args.orbfile[0])
        set1 = read_molecule(args.orbfile[0],
                             loglevel=args.verbose_level, rebase_coeffs=args.rebase_coeffs)

        logger.info("Reading second set of orbitals from file: %s", args.orbfile[1])
        set2 = read_molecule(args.orbfile[1],
                             loglevel=args.verbose_level, rebase_coeffs=args.rebase_coeffs)

        r = mutual_overlaps(set1, set2, range_of_mos)
        k = i2 - i1 + 1
        # One line per pair: 1-based orbital number in set1, in set2, overlap.
        for i in range(k):
            for j in range(k):
                print(i1 + i + 1, i1 + j + 1, r[i, j])
        sys.exit(0)

