#!/usr/bin/env python
from cosymlib import Cosymlib, __version__
from cosymlib import file_io
from cosymlib.file_io.tools import print_header, print_footer, print_input_info
from cosymlib.shape import tools
import argparse
import sys
import yaml

# Command line interface of the shape_map script. Every option below can also be
# overridden from the optional YAML keywords file, using the option's `dest` as key.
parser = argparse.ArgumentParser(description='shape_map ')

# positional arguments
# Both positionals are optional (nargs='?') so that flags such as --version can be
# used without giving an input file.
parser.add_argument(type=str,
                    dest='input_file',
                    nargs='?', default=None,
                    help='input file with structures')
parser.add_argument(type=str,
                    dest="yaml_input",
                    nargs='?',
                    default=None,
                    help='Input file with keywords')

# Main options
# Each axis of the map takes either a shape label (m_N) or a file holding a custom
# reference structure (m_custom_N).
parser.add_argument('-m_1', '--measure_1',
                    dest='m_1',
                    metavar='SH1',
                    default=None,
                    help='shape label for x axis')
parser.add_argument('-m_2', '--measure_2',
                    dest='m_2',
                    metavar='SH2',
                    default=None,
                    help='shape label for y axis')
parser.add_argument('-m_custom_1', '--measure_custom_1',
                    dest='m_custom_1',
                    metavar='SH1',
                    default=None,
                    help='shape reference structure for x axis')
parser.add_argument('-m_custom_2', '--measure_custom_2',
                    dest='m_custom_2',
                    metavar='SH2',
                    default=None,
                    help='shape reference structure for y axis')
parser.add_argument('-o', '--output_name',
                    dest='output_name',
                    metavar='filename',
                    default=None,
                    help='store the output into a file')
parser.add_argument('-c', '--central_atom',
                    dest='central_atom',
                    metavar='N',
                    type=int,
                    default=0,
                    help='define central atom as the atom in position N in the input structure')

# Extra options
parser.add_argument('-l', '--labels', action='store_true',
                    dest='labels',
                    default=False,
                    help='show the reference shape labels')
parser.add_argument('--info',
                    action='store_true',
                    default=False,
                    help='print information about the input structures')
# The help text used to be a copy of the --info one; this flag prints the version.
parser.add_argument('-v', '--version',
                    dest='version',
                    action='store_true',
                    default=False,
                    help='print the cosymlib version and exit')

# Modifiers
# NOTE(review): fix_permutation is parsed here but is not read anywhere else in the
# visible part of this script -- confirm whether it should be forwarded.
parser.add_argument('--fix_permutation',
                    dest='fix_permutation',
                    action='store_true',
                    help='do not permute atoms')
parser.add_argument('--min_dev',
                    dest='min_dev',
                    default=None,
                    type=float,
                    help='minimum deviation')
parser.add_argument('--max_dev',
                    dest='max_dev',
                    default=None,
                    type=float,
                    help='maximum deviation')
parser.add_argument('--min_gco',
                    dest='min_gco',
                    default=None,
                    type=float,
                    help='minimum coordinates gradient')
parser.add_argument('--max_gco',
                    dest='max_gco',
                    default=None,
                    type=float,
                    help='maximum coordinates gradient')
parser.add_argument('--n_points',
                    dest='n_points',
                    default=20,
                    type=int,
                    help='number of path structures to calculate')

# Parse the command line, then let the optional YAML keywords file override the
# parsed values. YAML keys are matched case-insensitively against the argparse
# destinations (e.g. `m_1`, `output_name`); unknown keys are rejected.
args = parser.parse_args()

if args.yaml_input:
    # SECURITY NOTE: yaml.safe_load is the loader recommended for untrusted input.
    # FullLoader is kept so the set of accepted YAML documents does not change;
    # only pass keyword files you trust.
    with open(args.yaml_input, 'r') as stream:
        input_parameters = yaml.load(stream, Loader=yaml.FullLoader)

    # An empty YAML document loads as None: treat it as "no overrides" instead of
    # crashing on None.items().
    if input_parameters is None:
        input_parameters = {}
    if not isinstance(input_parameters, dict):
        raise TypeError("Keywords file %s must contain a mapping of option names to values"
                        % args.yaml_input)

    for key, value in input_parameters.items():
        # str() so that a non-string key (e.g. an integer) is reported as an invalid
        # key rather than failing on the missing .lower() method.
        option = str(key).lower()
        if option in args:
            setattr(args, option, value)
        else:
            raise KeyError("Key %s is not valid" % key)

# --version needs neither an input file nor an output stream.
if args.version:
    print('Cosymlib version = {}'.format(__version__))
    # sys.exit() rather than the site-provided exit(), which is not guaranteed to exist
    sys.exit()

# Validate the invocation before touching the output: opening the output file in 'w'
# mode truncates it, which should not happen when the run cannot proceed.
if args.input_file is None:
    parser.error('No input file selected! An existing file must be provide')

# All regular output goes either to the requested file or to standard output.
common_output = open(args.output_name, 'w') if args.output_name is not None else sys.stdout
print_header(common_output)

# Read every structure of the input file and wrap them in a single Cosymlib object.
structures = file_io.read_generic_structure_file(args.input_file, read_multiple=True)
structure_set = Cosymlib(structures)
n_atoms = structure_set.get_n_atoms()

# --info: only describe the input structures, then stop.
if args.info:
    print_input_info(structure_set.get_geometries(), output=common_output)
    common_output.flush()
    sys.exit()

# Shape's commands
# --labels: only list the reference shape labels for this number of atoms, then stop.
if args.labels:
    common_output.write(tools.get_shape_label_info(n_atoms, with_central_atom=args.central_atom))
    common_output.flush()
    sys.exit()

# Build the two references (x axis, y axis) of the shape map, compute the minimum
# distortion path between them for the input structures, and close the output.


def _select_reference(label, custom_file, option_suffix, axis_name):
    """Return the reference for one axis of the map, or None if none was requested.

    :param label: shape label given with -m_<N> (returned unchanged when set)
    :param custom_file: file given with -m_custom_<N>, read as a custom reference geometry
    :param option_suffix: option number (1 or 2), only used in the error message
    :param axis_name: 'first' or 'second', only used in the error message
    :raises ReferenceError: if both a label and a custom file are given for the axis
    """
    if label and custom_file:
        raise ReferenceError('Either m_{0} or m_custom_{0} should be pointing to the {1} axis'
                             .format(option_suffix, axis_name))
    if label:
        return label
    if custom_file:
        polyhedron = file_io.get_geometry_from_file_xyz(custom_file, read_multiple=False)
        # NOTE(review): call kept exactly as in the original script; confirm that
        # set_positions is the intended way to apply the central atom index.
        polyhedron.set_positions(args.central_atom - 1)
        return polyhedron
    return None


# m_1 / m_custom_1 define the x (first) axis, m_2 / m_custom_2 the y (second) axis.
# The y axis custom reference is read from m_custom_2 (it used to re-read m_custom_1).
reference_x = _select_reference(args.m_1, args.m_custom_1, 1, 'first')
reference_y = _select_reference(args.m_2, args.m_custom_2, 2, 'second')

# A map needs both axes: report a half-specified request clearly instead of failing
# later with a NameError/IndexError.
if (reference_x is None) != (reference_y is None):
    parser.error('both a x axis reference (m_1 or m_custom_1) and a y axis reference '
                 '(m_2 or m_custom_2) are required')

# With no reference at all nothing is computed; only the header and footer are written.
if reference_x is not None:
    references = [reference_x, reference_y]
    # NOTE(review): `output` receives the output file name (or None), not the already
    # open `common_output` stream -- confirm which one the method expects.
    structure_set.print_minimum_distortion_path_shape(references[0],
                                                      references[1],
                                                      central_atom=args.central_atom,
                                                      min_dev=args.min_dev,
                                                      max_dev=args.max_dev,
                                                      min_gco=args.min_gco,
                                                      max_gco=args.max_gco,
                                                      num_points=args.n_points,
                                                      output=args.output_name)

print_footer(common_output)

# Release the output file if one was opened; never close the shared standard output.
if common_output is not sys.stdout:
    common_output.close()
