#!/usr/bin/env python3

import argparse
import functools
import os
import numpy

import torch

# Command-line interface of the prediction script. Unknown arguments are
# tolerated (parse_known_args) and silently ignored.
parser = argparse.ArgumentParser(description='All the indices are starting from 1.')
# In/Out
parser.add_argument("-i", "--in_name", default=None, help="Can be either a .npy file or a .dcd molecular dynamics "
                                                          "trajectory. If you are providing a .dcd file, you "
                                                          "should also provide a PDB and optional selections.")
parser.add_argument("--pdb", help="If using directly a dcd file, we need to add a PDB for selection")
parser.add_argument('--select', default='polymer', type=str, help='Atoms to select')
parser.add_argument('--select_align', type=str, help='Atoms to select for structural alignment')
parser.add_argument("-o", "--out_name", help="name of txt to dump")
parser.add_argument("-s", "--som_name", default='som.p', help="name of pickle to load")
# The previous help text described a periodic-topology option; this flag
# actually resets the clustering attached to the loaded SOM.
parser.add_argument("--recompute_cluster", default=False, action='store_true',
                    help="if set, reset the clustering stored in the SOM so that it is recomputed")
parser.add_argument("--batch_size", type=int, default=100, help="The batch size to use")
# os.cpu_count() returns None when the CPU count cannot be determined; treat
# that as a single CPU and never ask for a negative number of workers.
parser.add_argument("--num_workers", type=int, default=max((os.cpu_count() or 1) - 1, 0),
                    help="The number of workers to use")
parser.add_argument("--subset", default=False, action='store_true',
                    help="Use the user defined clusters instead of the expanded partition.")
parser.add_argument('-j', '--jax', help='To use the jax version', action='store_true')
args, _ = parser.parse_known_args()
if args.out_name is None:
    # Derive the output name from the input file, or from the SOM pickle when
    # the script is run without input data. The name is kept with its
    # extension on purpose: the code writing the reports strips the extension
    # itself, and stripping it here as well would drop one dotted component
    # too many (e.g. "run.1.dcd" would end up as "run" instead of "run.1").
    args.out_name = args.in_name if args.in_name is not None else args.som_name

# Choose the SOM backend and the device handed to the loader: the torch
# implementation by default, the JAX one when --jax is given.
if not args.jax:
    from quicksom.som import SOM

    # Use the GPU when torch can see one, otherwise stay on the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
    from quicksom.somax import SOM
    import jax

    # First device reported by JAX.
    device = jax.devices()[0]

# NOTE(review): load_pickle presumably unpickles the file -- only load SOM
# files coming from a trusted source (confirm against the quicksom code).
som = SOM.load_pickle(inname=args.som_name, device=device)

# --recompute_cluster: clear the clustering attribute carried by the loaded SOM.
if args.recompute_cluster:
    som.cluster_att = None

# Predict a cluster for every sample. The samples come from a .dcd trajectory
# (streamed through a DataLoader), from a .npy array, or -- when no input file
# is given -- predict_cluster is called without data (presumably it then works
# on what is stored in the SOM; confirm against quicksom).
if args.in_name is not None:
    if args.in_name.endswith('.dcd'):
        from quicksom.dcd_reader import DCDataset

        dataset = DCDataset(pdbfilename=args.pdb, dcdfilename=args.in_name, selection=args.select,
                            selection_alignment=args.select_align)
        dim = dataset.dim
        # Inference must not shuffle: the reports written afterwards number
        # the results by position, so samples have to be fed in frame order.
        dataset = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=args.num_workers)
    else:
        dataset = numpy.load(args.in_name)
    predicted_clusts, errors = som.predict_cluster(dataset,
                                                   batch_size=args.batch_size,
                                                   user=args.subset)
else:
    predicted_clusts, errors = som.predict_cluster(batch_size=args.batch_size,
                                                   user=args.subset)

# Write the three text reports: <prefix>_bmus.txt, <prefix>_codebook.txt and
# <prefix>_clusters.txt.
codebook = numpy.asarray(som.codebook)
# Flatten the (row, column) BMU coordinates into a single index on the m x n map.
bmu_rows = numpy.int_(som.bmus[:, 0])
bmu_cols = numpy.int_(som.bmus[:, 1])
bmus_flatten = numpy.ravel_multi_index((bmu_rows, bmu_cols), (som.m, som.n))
# Codebook entry of each sample's BMU cell (presumably the index of the sample
# closest to that cell -- confirm against quicksom).
index_min_error = codebook[bmus_flatten]
index = numpy.arange(len(som.bmus))

# One row per sample; the column order matches out_header below.
out_arr = numpy.c_[som.bmus, som.error, index, index_min_error, predicted_clusts, bmus_flatten]
out_fmt = ['%d', '%d', '%.4g', '%d', '%d', '%d', '%d']
out_header = f'#bmu1 #bmu2 #error #index #index_min_error #cluster #bmu_flatten_{som.m}x{som.n}'

# All reports share the output name stripped of its extension.
out_prefix = os.path.splitext(args.out_name)[0]
numpy.savetxt(f"{out_prefix}_bmus.txt", out_arr,
              fmt=out_fmt, header=out_header, comments='')

# Keep only the rows whose own index equals their index_min_error entry.
is_codebook_row = index == index_min_error
out_arr_codebook = out_arr[is_codebook_row]
numpy.savetxt(f"{out_prefix}_codebook.txt", out_arr_codebook,
              fmt=out_fmt, header=out_header, comments='')

# One line per cluster id greater than 0, listing its members as 1-based
# sample positions separated by spaces.
with open(f"{out_prefix}_clusters.txt", 'w') as outfile:
    for cid in numpy.unique(predicted_clusts):
        if not cid > 0:
            continue
        inds = numpy.where(predicted_clusts == cid)[0] + 1
        members = ' '.join('%d' % e for e in inds)
        outfile.write(f"{members}\n")
