#!/usr/bin/env python
import argparse
import numpy as np
import yaml
import sys
from cosymlib import file_io
from cosymlib.molecule.geometry import Geometry
from cosymlib import Cosymlib
from cosymlib.shape import tools
import os


# Command-line interface.  Every option defined here can also be set from the
# optional YAML command file (positional `yaml_input`): its keys are
# lower-cased and matched against the `dest` names defined below.
parser = argparse.ArgumentParser(description='Cosym')
parser.add_argument(type=str, dest='input_file', help='input file name(+extension)')
parser.add_argument(type=str, dest="yaml_input", nargs='?', default=None,
                    help='Perform the calculations with the command file')
parser.add_argument('-o', '--output_name', dest='output_name', default=None, help='save in file name')
parser.add_argument('-info', action='store_true', default=False, help='return information about the input geometries')

parser.add_argument('-c', '--central_atom', action='store', dest='central_atom',
                    type=int, default=0, help='position of the central atom if exist')
parser.add_argument('-custom_ref', action='store', dest='custom_ref', default=None,
                    help='take a given structure from the file and use it as reference')
parser.add_argument('-fix_perm', '--fix_permutation', dest='fix_permutation', action='store_true', default=False,
                    help='use the given permutation to perform a calculation')
parser.add_argument('-no_labels', '--ignore_atom_labels', dest='ignore_atom_labels', action='store_true', default=False,
                    help='ignore atom labels from given structures')
# Axes and center take three values each (nargs=3); no `type` is given, so
# command-line values are stored as strings.
parser.add_argument('-axis1', dest='axis1', action='store', default=None, nargs=3,
                    help='First reference axis for the symmetry calculation')
parser.add_argument('-axis2', dest='axis2', action='store', default=None, nargs=3,
                    help='Second reference axis for the symmetry calculation')
parser.add_argument('-center', dest='center', action='store', default=None, nargs=3,
                    help='Center for the symmetry calculation')

# Shape input flags
group_shape = parser.add_argument_group('Shape')
group_shape.add_argument('-shp_m', '--shp_measure',
                         dest='shp_measure',
                         action='store',
                         default=None,
                         help='Shape measure of input structure with reference polyhedra')
group_shape.add_argument('-shp_l', '--shp_labels', action='store_true',
                         dest='shp_labels',
                         default=False,
                         help='show the reference labels for a given structure')
group_shape.add_argument('-shp_s', '--shp_structure',
                         dest='shp_structure',
                         action='store_true',
                         default=False,
                         help='return the closest input structure to the reference shape')
group_shape.add_argument('-shp_r', '--shp_references',
                         dest='shp_references',
                         action='store_true',
                         default=False,
                         help='return a file with the coordinates of reference polyhedra')


# PointGroup input flags
group_pointgroup = parser.add_argument_group('PointGroup')
group_pointgroup.add_argument('-pointgroup',
                              dest='pointgroup',
                              action='store_true',
                              default=False,
                              help='Gives the point group of an input structure')

# Symgroup input flags
group_symgroup = parser.add_argument_group('Symgroup')
group_symgroup.add_argument('-sym_m', '--sym_measure',
                            dest='sym_measure',
                            action='store',
                            default=None,
                            help='Symgroup measure of input structure with reference group')
group_symgroup.add_argument('-sym_s', '--sym_structure',
                            dest='sym_structure',
                            action='store_true',
                            default=False,
                            help='return the closest input structure to the reference symmetry')
group_symgroup.add_argument('-sym_l', '--sym_labels',
                            dest='sym_labels',
                            action='store_true',
                            default=False,
                            help='return the possible symmetry labels that can be used in symgroup')
group_symgroup.add_argument('-sym_chirality',
                            dest='sym_chirality',
                            action='store_true',
                            default=False,
                            help='search for a possible chirality in molecule')

# Qsym input flags
group_qsym = parser.add_argument_group('Qsym')
group_qsym.add_argument('-qsym_wfn',
                        dest='qsym_wfn',
                        action='store',
                        default=False,
                        help='Wfnsym measure of input structure with reference group')
group_qsym.add_argument('-qsym_dens',
                        dest='qsym_dens',
                        action='store',
                        default=False,
                        help='Density measure of input structure with reference group')
# Needs its own dest: sharing 'qsym_wfn' would make this option silently
# overwrite the -qsym_wfn group label.
group_qsym.add_argument('-qsym_mos',
                        dest='qsym_mos',
                        action='store',
                        default=False,
                        help='Range of orbitals to analyze density or wavefunction symmetry')

# Utils
parser.add_argument('-mo_diagram', dest="mo_diagram", action='store_true', default=False,
                    help='Plot the molecular orbital diagram (used together with -qsym_wfn)')


args = parser.parse_args(sys.argv[1:])
print('Starting...')

# Optional YAML command file: every top-level key overrides the command-line
# option whose `dest` equals the lower-cased key; unknown keys are rejected.
if args.yaml_input:
    with open(args.yaml_input, 'r') as stream:
        # NOTE(review): FullLoader can construct some Python-specific objects.
        # yaml.safe_load is enough for a plain key/value command file and is
        # the safer choice if this file may come from an untrusted source.
        input_parameters = yaml.load(stream, Loader=yaml.FullLoader)

    # An empty command file parses to None: treat it as "no overrides"
    # instead of crashing on None.items().
    if input_parameters is None:
        input_parameters = {}

    for key, value in input_parameters.items():
        # str() so that non-string YAML keys (e.g. numbers) are reported as
        # invalid keys rather than failing on .lower().
        option = str(key).lower()
        if option in args:
            setattr(args, option, value)
        else:
            raise KeyError("Key %s is not valid" % key)

# Load every structure from the input file, build the Cosymlib driver, and
# count the atoms of the first structure.  When a central atom is given (-c),
# n_atoms excludes it; the shape reference/label commands below use this count.
if args.input_file is not None:
    structures = file_io.read_generic_structure_file(args.input_file, read_multiple=True)
    symobj = Cosymlib(structures)

    first_entry = structures[0]
    # The entry may expose the atom count through a `.geometry` attribute or
    # directly; fall back from the former to the latter.
    try:
        n_atoms = first_entry.geometry.get_n_atoms()
    except AttributeError:
        try:
            n_atoms = first_entry.get_n_atoms()
        except AttributeError:
            raise AttributeError('Molecule object not found')

    if args.central_atom != 0:
        n_atoms = n_atoms - 1

# -shp_r: write the reference structures returned by
# tools.get_structure_references(n_atoms) to an .xyz file placed in the same
# directory as the input file (L<n>.xyz, or ML<n>.xyz when -c is given).
if args.shp_references:
    # NOTE(review): an 'N' is appended after the n_atoms 'H' labels even when
    # no central atom is given -- confirm this matches the number of
    # positions tools.get_test_structure returns in that case.
    symbols = ['H'] * n_atoms + ['N']
    test_structure = [
        Geometry(symbols=symbols,
                 positions=tools.get_test_structure(label, central_atom=args.central_atom),
                 name=label)
        for label in tools.get_structure_references(n_atoms)
    ]

    prefix = 'L' if args.central_atom == 0 else 'ML'
    # os.path.join (rather than '{}/{}') keeps the file in the current
    # directory when the input path has no directory component, instead of
    # writing to the filesystem root.
    output_path = os.path.join(os.path.dirname(args.input_file),
                               '{}{}.xyz'.format(prefix, n_atoms))
    # Context manager guarantees the file is flushed and closed.
    with open(output_path, 'w') as output:
        output.write(file_io.get_file_xyz_txt(test_structure))

# -info: write a summary of the input geometries.
if args.info:
    file_io.write_input_info(structures, output_name=args.output_name)

# Shape's commands
if args.shp_labels:
    # n_atoms already excludes the central atom when -c is given, so the same
    # call covers structures with and without a central atom.
    print(tools.get_shape_label_info(n_atoms))

# Resolve the reference shape(s) requested with -shp_m:
#   'custom' -> the structures read from the -custom_ref file,
#   'all'    -> every reference from tools.get_structure_references(n_atoms),
#   other    -> a one-element list with the given label (None without -shp_m).
if args.shp_measure == 'custom':
    if args.custom_ref is None:
        parser.error("'-shp_m custom' requires a reference file given with -custom_ref")
    reference_polyhedron = file_io.get_geometry_from_file_xyz(args.custom_ref, read_multiple=True)
    # NOTE(review): without a central atom (-c 0) this passes -1; confirm
    # that set_positions treats that as "no central atom".
    for reference in reference_polyhedron:
        reference.set_positions(args.central_atom - 1)
elif args.shp_measure == 'all':
    # n_atoms was computed when the input was read and already excludes the
    # central atom, so there is no need to recount (or overwrite) it here.
    reference_polyhedron = tools.get_structure_references(n_atoms)
else:
    reference_polyhedron = [args.shp_measure]

# Run the shape calculation: -shp_s returns the closest input structure to
# each reference shape; otherwise -shp_m / -custom_ref print shape measures.
if args.shp_structure or args.shp_measure or args.custom_ref:
    if args.shp_structure:
        run_shape = symobj.print_shape_structure
    else:
        run_shape = symobj.print_shape_measure
    run_shape(reference_polyhedron,
              central_atom=args.central_atom,
              output=args.output_name,
              fix_permutation=args.fix_permutation)


# Symgroup commands: -sym_s prints the nearest structure with the symmetry
# given by -sym_m; otherwise -sym_m alone prints the symmetry measure.
if args.sym_structure or args.sym_measure:
    run_symgroup = (symobj.print_symmetry_nearest_structure if args.sym_structure
                    else symobj.print_geometric_symmetry_measure)
    run_symgroup(args.sym_measure,
                 central_atom=args.central_atom,
                 symbols=not args.ignore_atom_labels,
                 output=args.output_name)

# Wfnsym commands.  The optional axis1/axis2/center values arrive as
# three-element lists (strings from the command line, or whatever the YAML
# command file holds) and are converted to float arrays before use.
if args.qsym_wfn:
    for _option in ('axis1', 'axis2', 'center'):
        _raw = getattr(args, _option)
        if _raw is not None:
            setattr(args, _option, np.array(_raw).astype(float))

    symobj.OLD_print_wnfsym_measure_verbose(args.qsym_wfn,
                                            axis=args.axis1,
                                            axis2=args.axis2,
                                            center=args.center,
                                            output=args.output_name)

    # Utils: the MO diagram is only produced together with -qsym_wfn.
    if args.mo_diagram:
        symobj.plot_mo_diagram(args.qsym_wfn,
                               axis=args.axis1,
                               axis2=args.axis2,
                               center=args.center)

# Pointgroup command: report the point group of every structure that was read.
if args.pointgroup:
    point_groups = symobj.get_point_group()
    for index, group in enumerate(point_groups):
        print('The point group of molecule{} is: {}'.format(index, group))

print('\nEnd of cosym calculation')