#!/usr/bin/env python
from cosymlib import Cosymlib, __version__
from cosymlib import file_io
from cosymlib.file_io.tools import print_header, print_footer, print_input_info
from cosymlib.shape import tools
import argparse
import os
import sys
import yaml


def write_reference_structures(vertices, central_atom, directory):
    """Write the reference structures for ``vertices`` vertices to an xyz file.

    The file is created inside ``directory`` and is named
    ``L{vertices}_refs.xyz`` when ``central_atom`` is 0 and
    ``ML{vertices}_refs.xyz`` otherwise. The path of the file is printed to
    standard output before the structures are written.

    :param vertices: number of vertices, used both to select the reference
        labels and to build the file name
    :param central_atom: central atom position forwarded to
        ``tools.get_reference_structure`` (0 selects the ``L`` file prefix)
    :param directory: directory in which the file is created; an empty
        string means the current working directory
    :raises FileExistsError: if the target file already exists (the file is
        opened in exclusive-creation mode ``'x'``)
    """
    prefix = 'L' if central_atom == 0 else 'ML'
    # os.path.join keeps the path relative when ``directory`` is empty;
    # plain concatenation with '/' would point at the filesystem root.
    file_path = os.path.join(directory, '{}{}_refs.xyz'.format(prefix, vertices))

    # The context manager guarantees the file is flushed and closed, even if
    # one of the reference structures fails to be generated.
    with open(file_path, 'x') as output_references:
        print("\nReference structures in file {}\n".format(output_references.name))
        for label in tools.get_structure_references(vertices):
            ref_structure = tools.get_reference_structure(label, central_atom=central_atom)
            output_references.write(file_io.get_file_xyz_txt(ref_structure))


# Command line interface definition.
# Abbreviated long options are disabled so that every option has to be
# written out in full.
parser = argparse.ArgumentParser(description='Shape', allow_abbrev=False)

# Positional arguments (both optional): the structure file and a YAML file
# whose keys override the command line options.
parser.add_argument(type=str,
                    dest='input_file',
                    nargs='?',
                    default=None,
                    help='input file with structures')
parser.add_argument(type=str,
                    dest="yaml_input",
                    nargs='?',
                    default=None,
                    help='Input file with keywords')

# Main options
parser.add_argument('-m', '--measure',
                    dest='measure',
                    metavar='SH',
                    default=None,
                    help='compute the SH measure of the input structures (use "custom" to use custom structure)')
parser.add_argument('-s', '--structure',
                    dest='structure',
                    action='store_true',
                    default=False,
                    help='return the nearest structure to the reference shape')
parser.add_argument('-o', '--output_name',
                    dest='output_name',
                    metavar='filename',
                    default=None,
                    help='store the output into a file')
parser.add_argument('-c', '--central_atom',
                    dest='central_atom',
                    metavar='N',
                    type=int,
                    default=0,
                    help='define central atom as the atom in position N in the input structure')
parser.add_argument('-r', '--references',
                    dest='references',
                    action='store_true',
                    default=False,
                    help='store the coordinates of the reference polyhedra in a file')
parser.add_argument('-m_custom', '--measure_custom',
                    dest='measure_custom',
                    metavar='filename',
                    default=None,
                    help='define filename containing the structure/s to be used as reference')

# Extra options
parser.add_argument('-l', '--labels', action='store_true',
                    dest='labels',
                    default=False,
                    help='show the reference shape labels')
parser.add_argument('--info',
                    action='store_true',
                    default=False,
                    help='print information about the input structures')
parser.add_argument('-v', '--version',
                    dest='version',
                    action='store_true',
                    default=False,
                    help='print the Cosymlib version and exit')
# The two *_n options are kept as raw strings (no ``type``) and default to
# False so they can be used directly as "was it given?" flags.
parser.add_argument('--labels_n',
                    dest='labels_n',
                    metavar='N',
                    default=False,
                    help='show the reference shape labels of n vertices')
parser.add_argument('--references_n',
                    dest='references_n',
                    metavar='N',
                    default=False,
                    help='store the coordinates of the reference polyhedra of n vertices in a file')

# Modifiers
parser.add_argument('--fix_permutation',
                    dest='fix_permutation',
                    action='store_true',
                    default=False,
                    help='do not permute atoms')

# Parse the command line; the values may then be overridden by a YAML file.
args = parser.parse_args()

if args.yaml_input:
    # NOTE(review): yaml.load with FullLoader is applied to a user-supplied
    # file. The file is only expected to hold plain key/value pairs, so
    # yaml.safe_load would be the more defensive choice -- confirm before
    # switching loaders.
    with open(args.yaml_input, 'r') as stream:
        # An empty YAML document is loaded as None: treat it as "no overrides"
        # instead of failing on ``None.items()``.
        input_parameters = yaml.load(stream, Loader=yaml.FullLoader) or {}

    # Every key must match (case-insensitively) an existing option
    # destination; its value replaces the one parsed from the command line.
    for key, value in input_parameters.items():
        option = key.lower()
        if option in args:
            setattr(args, option, value)
        else:
            raise KeyError("Key %s is not valid" % key)

if args.version:
    print('Cosymlib version = {}'.format(__version__))
    # sys.exit is used instead of the ``exit`` helper, which is injected by
    # the ``site`` module and is not guaranteed to exist in every interpreter.
    sys.exit()

# All regular output goes either to the file given with -o/--output_name or
# to standard output.
common_output = open(args.output_name, 'w') if args.output_name is not None else sys.stdout
print_header(common_output)

# Commands that do not need an input structure file.
# sys.exit is used throughout instead of the ``exit`` helper, which is
# injected by the ``site`` module and is not guaranteed to be available.
if args.labels_n:
    common_output.write(tools.get_shape_label_info(int(args.labels_n), with_central_atom=args.central_atom))
    sys.exit()

if args.references_n:
    input_dir = os.getcwd()
    write_reference_structures(int(args.references_n), args.central_atom, input_dir)
    sys.exit()

if args.input_file is None:
    parser.error('No input file selected! An existing file must be provide')

structures = file_io.read_generic_structure_file(args.input_file, read_multiple=True)
structure_set = Cosymlib(structures)

# When a central atom is defined it is not counted as a vertex.
n_atoms = structure_set.get_n_atoms()
vertices = n_atoms if args.central_atom == 0 else n_atoms - 1

if args.references:
    # os.path.dirname returns '' for a bare file name; fall back to the
    # current directory so the reference file is written next to the input
    # file and not at the filesystem root.
    input_dir = os.path.dirname(args.input_file) or os.curdir
    write_reference_structures(vertices, args.central_atom, input_dir)

if args.info:
    print_input_info(structure_set.get_geometries(), output=common_output)
    sys.exit()

# Shape's commands
if args.labels:
    common_output.write(tools.get_shape_label_info(n_atoms, with_central_atom=args.central_atom))
    sys.exit()

# Select the reference shape(s): structures read from a custom file, every
# reference available for this number of vertices, or the single label given
# with -m/--measure.
if args.measure_custom:
    reference = file_io.read_generic_structure_file(args.measure_custom, read_multiple=True)
    # NOTE(review): presumably this tells each custom reference where its
    # central atom is (0-based index) -- confirm against set_positions.
    for custom_reference in reference:
        custom_reference.set_positions(args.central_atom - 1)
    args.measure = 'custom'
elif args.measure == 'all':
    reference = tools.get_structure_references(vertices)
else:
    reference = [args.measure]

if args.structure:
    # Without an explicit output file the structures are written to
    # "<input name>_near.xyz" instead of standard output.
    if common_output is sys.stdout:
        file_name, _ = os.path.splitext(args.input_file)
        output_str = open(file_name + '_near.xyz', 'w')
    else:
        output_str = common_output
    try:
        structure_set.print_shape_structure(reference,
                                            central_atom=args.central_atom,
                                            fix_permutation=args.fix_permutation,
                                            output=output_str)
    finally:
        # Only close the file opened here; the shared output is still needed
        # for the measures and the footer below.
        if output_str is not common_output:
            output_str.close()

if args.measure:
    structure_set.print_shape_measure(reference,
                                      central_atom=args.central_atom,
                                      fix_permutation=args.fix_permutation,
                                      output=common_output)

print_footer(common_output)

# Release the output file once everything has been written (never close
# standard output).
if common_output is not sys.stdout:
    common_output.close()
