#!/usr/bin/env python3

import argparse
import functools
import os
import numpy
import torch

# Command-line interface: input/output files, SOM geometry and optimisation settings.
parser = argparse.ArgumentParser()
# In/Out
parser.add_argument("-i", "--in_name", default=None, help="Can be either a .npy file or a .dcd molecular dynamics "
                                                          "trajectory. If you are providing a .dcd file, you "
                                                          "should also provide a PDB and optional selections.")
parser.add_argument("--pdb", help="If using directly a dcd file, we need to add a PDB for selection")
parser.add_argument('--select', default='polymer', type=str, help='Atoms to select')
parser.add_argument('--select_align', type=str, help='Atoms to select for structural alignment')
parser.add_argument("-o", "--out_name", default='som.p', help="name of pickle to dump")
# SOM
parser.add_argument("-m", "--m", type=int, default=50, help="The width of the som")
parser.add_argument("-n", "--n", type=int, default=50, help="The height of the som")
parser.add_argument("--periodic", default=False, action='store_true', help="if set, periodic topology is used")
parser.add_argument('-j', '--jax', help='To use the jax version', action='store_true')
# Optim
parser.add_argument("--n_epoch", type=int, default=5, help="The number of iterations")
parser.add_argument("--batch_size", type=int, default=100, help="The batch size to use")
# os.cpu_count() returns None when the core count cannot be determined, which
# would make a bare "- 1" raise TypeError while building the parser. Fall back
# to a single core and never let the default drop below zero workers.
parser.add_argument("--num_workers", type=int, default=max((os.cpu_count() or 1) - 1, 0),
                    help="The number of workers to use")
parser.add_argument("--alpha", type=float, default=None, help="The initial learning rate")
parser.add_argument("--sigma", type=float, default=None, help="The initial sigma for the convolution")
parser.add_argument("--scheduler", default='linear', help="Which scheduler to use, can be linear, exp or half")
# parse_known_args tolerates unrecognised options; the leftovers are discarded.
args, _ = parser.parse_known_args()

# Setup the data part
if args.in_name is None:
    # --in_name defaults to None; without this guard the .endswith() call below
    # fails with an opaque AttributeError instead of a usage message.
    parser.error("an input file is required: pass -i/--in_name with a .npy or .dcd file")

if args.in_name.endswith('.dcd'):
    # Trajectory input: imported lazily so the .npy path does not need the DCD reader.
    from quicksom.dcd_reader import DCDataset

    # NOTE(review): args.pdb may be None here; whether DCDataset accepts that is
    # not visible from this script - confirm against quicksom.dcd_reader.
    dataset = DCDataset(pdbfilename=args.pdb, dcdfilename=args.in_name, selection=args.select,
                        selection_alignment=args.select_align)
    # Read the feature dimension before `dataset` is rebound to the DataLoader.
    dim = dataset.dim
    dataset = torch.utils.data.DataLoader(dataset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          pin_memory=True,
                                          num_workers=args.num_workers)
else:
    # Any other file name is loaded with numpy.load; the second axis gives the
    # input dimension handed to the SOM.
    dataset = numpy.load(args.in_name)
    dim = dataset.shape[1]

# Choose the SOM implementation and the device it will run on.
if not args.jax:
    # Default backend: the torch implementation, on CUDA when it is available.
    from quicksom.som import SOM

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
else:
    # --jax requested: use the jax implementation on the first device jax reports.
    from quicksom.somax import SOM
    import jax

    device = jax.devices()[0]

# Build the map from the command-line settings, train it, then dump it to disk.
som_options = {
    'alpha': args.alpha,
    'sigma': args.sigma,
    'sched': args.scheduler,
    'precompute': True,
    'periodic': args.periodic,
    'n_epoch': args.n_epoch,
    'device': device,
}
som = SOM(args.m, args.n, dim, **som_options)

# The value returned by fit() is kept in `learning_error` but not used further here.
learning_error = som.fit(
    dataset=dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
)
som.save_pickle(args.out_name)
