#!/usr/bin/env python
from cosymlib import Cosymlib, __version__
from cosymlib import file_io
from cosymlib.file_io.tools import print_header, print_footer, print_input_info
from cosymlib.shape import tools
from cosymlib.symmetry.tools import orthogonal_c4
import argparse
import sys
import yaml
import os
import numpy as np


def write_reference_structures(vertices, central_atom, directory):
    """Write all reference structures with ``vertices`` vertices to an xyz file.

    The file is created inside ``directory`` and is named
    ``L{vertices}_refs.xyz`` when ``central_atom`` is 0 and
    ``ML{vertices}_refs.xyz`` otherwise.

    Args:
        vertices: number of vertices used to look up the reference labels.
        central_atom: central atom position forwarded to
            ``tools.get_reference_structure`` (0 selects the ``L`` file name).
        directory: directory in which the file is created. An empty string
            means the current working directory.

    Raises:
        FileExistsError: if the target file already exists (it is opened in
            exclusive-creation mode, so nothing is ever overwritten).
    """
    prefix = 'L' if central_atom == 0 else 'ML'
    # os.path.join keeps an empty ``directory`` relative instead of turning it
    # into a path at the filesystem root, as plain '/' concatenation would.
    file_path = os.path.join(directory, '{}{}_refs.xyz'.format(prefix, vertices))

    # The context manager guarantees the file is flushed and closed, even if
    # one of the reference structures fails to be generated.
    with open(file_path, 'x') as output_references:
        print("\nReference structures in file {}\n".format(output_references.name))
        for label in tools.get_structure_references(vertices):
            ref_structure = tools.get_reference_structure(label, central_atom=central_atom)
            output_references.write(file_io.get_file_xyz_txt(ref_structure))


# Command line definition. Two optional positionals come first (structure
# file, then YAML command file); the flags are grouped by the tool they drive.
parser = argparse.ArgumentParser(description='Cosym')
parser.add_argument(type=str, dest='input_file', nargs='?', default=None, help='input file name(+extension)')
parser.add_argument(type=str, dest="yaml_input", nargs='?', default=None,
                    help='Perform the calculations with the command file')
parser.add_argument('-o', '--output_name', dest='output_name', default=None, help='save in file name')
parser.add_argument('--info', action='store_true', default=False, help='return information about the input geometries')
parser.add_argument('-v', '--version', dest='version', action='store_true', default=False,
                    help='print the Cosymlib version and exit')

# General settings shared by several tools
parser.add_argument('-c', '--central_atom', action='store', dest='central_atom',
                    type=int, default=0, help='position of the central atom if exist')
parser.add_argument('-shp_m_custom', '--shp_measure_custom', action='store', dest='shp_measure_custom', default=None,
                    help='define filename containing the structure/s to be used as reference')
parser.add_argument('-fix_perm', '--fix_permutation', dest='fix_permutation', action='store_true', default=False,
                    help='use the given permutation to perform a calculation')
parser.add_argument('-no_labels', '--ignore_atom_labels', dest='ignore_atom_labels', action='store_true', default=False,
                    help='ignore atom labels from given structures')
# parser.add_argument('-connectivity', dest='connectivity', action='store', default=None,
#                     help='Connect a set of atoms by...')
# The three vector options are stored as lists of three strings
parser.add_argument('-axis', dest='axis', action='store', default=None, nargs=3,
                    help='First reference axis for the symmetry calculation')
parser.add_argument('-axis2', dest='axis2', action='store', default=None, nargs=3,
                    help='Second reference axis for the symmetry calculation')
parser.add_argument('-center', dest='center', action='store', default=None, nargs=3,
                    help='Center for the symmetry calculation')
parser.add_argument('-c3_c4', dest='c3_c4', action='store_true', default=False,
                    help='axis is a c3 rotational axis instead of a c4')

# Shape input flags
group_shape = parser.add_argument_group('Shape')
group_shape.add_argument('-shp_m', '--shp_measure',
                         dest='shp_measure',
                         action='store',
                         default=None,
                         help='Shape measure of input structure with reference polyhedra')
group_shape.add_argument('-shp_l', '--shp_labels', action='store_true',
                         dest='shp_labels',
                         default=False,
                         help='show the reference labels for a given structure')
group_shape.add_argument('-shp_s', '--shp_structure',
                         dest='shp_structure',
                         action='store_true',
                         default=False,
                         help='return the closest input structure to the reference shape')
group_shape.add_argument('-shp_r', '--shp_references',
                         dest='shp_references',
                         action='store_true',
                         default=False,
                         help='return a file with the coordinates of reference polyhedra')
group_shape.add_argument('--shp_labels_n',
                         dest='shp_labels_n',
                         default=False,
                         help='show the reference labels of n vertices')
group_shape.add_argument('--shp_references_n',
                         dest='shp_references_n',
                         default=False,
                         help='store the coordinates of the reference polyhedra of n vertices in a file')

# PointGroup input flags
group_pointgroup = parser.add_argument_group('PointGroup')
group_pointgroup.add_argument('-pointgroup',
                              dest='pointgroup',
                              action='store_true',
                              default=False,
                              help='Gives the point group of an input structure')

# Symgroup input flags
group_symgroup = parser.add_argument_group('Symgroup')
group_symgroup.add_argument('-sym_m', '--sym_measure',
                            dest='sym_measure',
                            action='store',
                            default=None,
                            help='Symgroup measure of input structure with reference group')
group_symgroup.add_argument('-sym_s', '--sym_structure',
                            dest='sym_structure',
                            action='store_true',
                            default=False,
                            help='return the closest input structure to the reference symmetry')
group_symgroup.add_argument('-sym_l', '--sym_labels',
                            dest='sym_labels',
                            action='store_true',
                            default=False,
                            help='return the possible symmetry labels that can be used in symgroup')
group_symgroup.add_argument('-sym_chirality',
                            dest='sym_chirality',
                            action='store_true',
                            default=False,
                            help='search for a possible chirality in molecule')

# Qsym input flags
group_qsym = parser.add_argument_group('Qsym')
group_qsym.add_argument('-qsym_wfn',
                        dest='qsym_wfn',
                        action='store',
                        default=False,
                        help='wavefunction symmetry measure of input structure with reference group')
group_qsym.add_argument('-qsym_dens',
                        dest='qsym_dens',
                        action='store',
                        default=False,
                        help='Density symmetry measure of input structure with reference group')
group_qsym.add_argument('-qsym_mos',
                        dest='qsym_mos',
                        action='store',
                        default=False,
                        help='Range of orbitals to analyze density or wavefunction symmetry')

# Additional print options
parser.add_argument('-qsym_mo', '--qsym_molecular_orbitals',
                    dest='qsym_molecular_orbitals',
                    default=False,
                    action='store_true',
                    help='print molecular orbitals symmetry elements')
parser.add_argument('--qsym_trans_matrices',
                    dest='qsym_trans_matrices',
                    default=False,
                    action='store_true',
                    help='print the transformation matrices')
parser.add_argument('--qsym_mo_overlaps',
                    dest='qsym_mo_overlaps',
                    default=False,
                    action='store_true',
                    help='print the molecular orbitals overlap')
parser.add_argument('--esym_add_axes',
                    dest='esym_add_axes',
                    action='store_true',
                    default=False,
                    help='create an xyz file of the molecule with axes orientation as dummies')
parser.add_argument('-mo_diagram', dest="mo_diagram", action='store_true', default=False,
                    help='plot the molecular orbitals diagram (requires -qsym_wfn)')

args = parser.parse_args(sys.argv[1:])

# Values from the optional YAML command file override the parsed options.
if args.yaml_input:
    # NOTE(review): FullLoader can build more than plain scalars, lists and
    # dicts. If command files may come from an untrusted source, consider
    # switching to yaml.safe_load.
    with open(args.yaml_input, 'r') as stream:
        input_parameters = yaml.load(stream, Loader=yaml.FullLoader)

    # An empty command file is loaded as None: treat it as "no overrides"
    # instead of failing on None.items().
    if input_parameters is None:
        input_parameters = {}

    for key, value in input_parameters.items():
        # Keys are matched case-insensitively against the option names
        if key.lower() in args:
            setattr(args, key.lower(), value)
        else:
            raise KeyError("Key %s is not valid" % key)

# sys.exit() is used rather than the bare exit() helper: the latter is
# injected by the ``site`` module for interactive sessions and is not
# guaranteed to exist when a script runs.
if args.version:
    print('Cosymlib version = {}'.format(__version__))
    sys.exit()

# Everything is reported to the file given with -o, or to stdout by default
common_output = open(args.output_name, 'w') if args.output_name is not None else sys.stdout
print_header(common_output)

# The two "_n" requests do not need an input structure: serve them and stop.
if args.shp_labels_n:
    common_output.write(tools.get_shape_label_info(int(args.shp_labels_n), with_central_atom=args.central_atom))
    if common_output is not sys.stdout:
        common_output.close()
    sys.exit()

if args.shp_references_n:
    input_dir = os.getcwd()
    write_reference_structures(int(args.shp_references_n), args.central_atom, input_dir)
    if common_output is not sys.stdout:
        common_output.close()
    sys.exit()

# Every remaining action works on the structures of the input file
if args.input_file is None:
    parser.error('No input file selected! An existing file must be provide')

structures = file_io.read_generic_structure_file(args.input_file, read_multiple=True)
symobj = Cosymlib(structures)

# When a central atom is declared it is not counted as a vertex
n_atoms = symobj.get_n_atoms()
vertices = n_atoms if args.central_atom == 0 else n_atoms - 1

if args.shp_references:
    # The reference file is written next to the input file. dirname() is ''
    # for a bare file name, which would otherwise end up as a path at the
    # filesystem root, so fall back to the current working directory.
    input_dir = os.path.dirname(args.input_file) or os.getcwd()
    write_reference_structures(vertices, args.central_atom, input_dir)

if args.info:
    print_input_info(structures, output=common_output)
    # Release the -o file (if any) before leaving; sys.exit() replaces the
    # interactive-only exit() helper.
    if common_output is not sys.stdout:
        common_output.close()
    sys.exit()

# Shape's commands
if args.shp_labels:
    # Labels are looked up by number of vertices (the central atom, if any,
    # is excluded), using the same call form as the --shp_labels_n option.
    print(tools.get_shape_label_info(vertices, with_central_atom=args.central_atom))

# Select the shape reference(s): structures from a custom file, every known
# reference for this number of vertices ('all'), or a single label.
reference = []
if args.shp_measure_custom:
    reference = file_io.read_generic_structure_file(args.shp_measure_custom, read_multiple=True)
    # NOTE(review): set_positions receives the zero-based central atom index
    # (-1 when no central atom is given) - confirm this is the intended call.
    for custom_structure in reference:
        custom_structure.set_positions(args.central_atom - 1)
    # Custom references also trigger the shape measure printed below
    args.shp_measure = 'custom'
elif args.shp_measure == 'all':
    reference = tools.get_structure_references(vertices)
elif args.shp_measure:
    reference = [args.shp_measure]

if args.shp_structure:
    if common_output is sys.stdout:
        # Without -o the structures go to "<input name>_near.xyz"; the
        # context manager closes that file once it has been written.
        file_name, file_extension = os.path.splitext(args.input_file)
        with open(file_name + '_near.xyz', 'w') as output_str:
            symobj.print_shape_structure(reference,
                                         central_atom=args.central_atom,
                                         fix_permutation=args.fix_permutation,
                                         output=output_str)
    else:
        symobj.print_shape_structure(reference,
                                     central_atom=args.central_atom,
                                     fix_permutation=args.fix_permutation,
                                     output=common_output)

if args.shp_measure:
    symobj.print_shape_measure(reference,
                               central_atom=args.central_atom,
                               fix_permutation=args.fix_permutation,
                               output=common_output)


# Symgroup commands
# Both printers take the same keyword arguments; asking for the nearest
# structure (-sym_s) takes precedence over the plain measure (-sym_m).
symgroup_options = dict(central_atom=args.central_atom,
                        symbols=not args.ignore_atom_labels,
                        output=common_output)
if args.sym_structure:
    symobj.print_symmetry_nearest_structure(args.sym_measure, **symgroup_options)
elif args.sym_measure:
    symobj.print_geometric_symmetry_measure(args.sym_measure, **symgroup_options)

# Wfnsym commands
if args.qsym_wfn:
    # The command line delivers the vector components as strings
    if args.axis is not None:
        args.axis = np.array(args.axis).astype(float)
    if args.axis2 is not None:
        args.axis2 = np.array(args.axis2).astype(float)
    if args.center is not None:
        args.center = np.array(args.center).astype(float)
    if args.c3_c4:
        # Both axes are required to build the pair returned by orthogonal_c4
        if args.axis is None:
            raise KeyError('Missing axis')
        elif args.axis2 is None:
            raise KeyError('Missing axis2')
        args.axis, args.axis2 = orthogonal_c4(args.axis, args.axis2)

    # Orientation keywords shared by every call below
    orientation = dict(axis=args.axis, axis2=args.axis2, center=args.center)

    if args.qsym_molecular_orbitals:
        symobj.print_esym_mo_irreducible_repr(args.qsym_wfn,
                                              show_axes=False,
                                              output=common_output,
                                              **orientation)

    symobj.print_esym_irreducible_repr(args.qsym_wfn,
                                       output=common_output,
                                       **orientation)

    if args.qsym_trans_matrices:
        symobj.print_esym_matrices(args.qsym_wfn,
                                   output=common_output,
                                   **orientation)

    if args.qsym_mo_overlaps:
        symobj.print_esym_sym_overlaps(args.qsym_wfn,
                                       output=common_output,
                                       **orientation)

    # Utils
    if args.mo_diagram:
        symobj.plot_mo_diagram(args.qsym_wfn, **orientation)

    if args.esym_add_axes:
        file_name, file_extension = os.path.splitext(args.input_file)
        # The parser defines no "measure" option, so the group given with
        # -qsym_wfn is passed here, as in every other call of this section.
        # The "<input name>_orient.xyz" file is closed by the with-block.
        with open(file_name + '_orient.xyz', 'w') as output_str:
            symobj.print_esym_orientation(args.qsym_wfn,
                                          output=output_str,
                                          **orientation)

# Pointgroup command
if args.pointgroup:
    # One line per input structure, numbered from 0; printed to stdout
    for idm, pg in enumerate(symobj.get_point_group()):
        print('The point group of molecule{} is: {}'.format(idm, pg))

print_footer(common_output)

# Release the file opened for -o so its content is flushed to disk;
# sys.stdout must be left open.
if common_output is not sys.stdout:
    common_output.close()
