#!/usr/bin/env python
import time

from mpi4py import MPI
from pyfaidx import Fasta
from termcolor import colored

from A2G.__version__ import version
from A2G.align2consensus import *


def split_n_yield(fasta, chunks):
    """Yield the records of a FASTA file as ``chunks`` FASTA-formatted strings.

    The records are split into ``chunks`` contiguous groups whose sizes
    differ by at most one (the first ``len(records) % chunks`` groups get the
    extra record, the same sizing rule ``numpy.array_split`` uses). Exactly
    ``chunks`` strings are always yielded; when there are fewer records than
    chunks the trailing strings are empty.

    :param fasta: FASTA file name, passed straight to ``pyfaidx.Fasta``.
    :param chunks: number of strings to produce (must be at least 1).
    :return: generator of strings, each holding ``>header`` and sequence
        lines separated by newlines.
    :raises ValueError: if ``chunks`` is smaller than 1.
    """
    if chunks < 1:
        raise ValueError('number sections must be larger than 0.')
    # Slice a plain list instead of handing the records to numpy: the record
    # objects are sequence-like, so numpy may try to unpack them into an
    # array of characters rather than keep them whole.
    records = list(Fasta(fasta).values())
    base, extra = divmod(len(records), chunks)
    start = 0
    for idx in range(chunks):
        stop = start + base + (1 if idx < extra else 0)
        # Header and sequence must sit on separate lines to be valid FASTA
        # (the previous '/n' was a typo for the newline escape).
        # NOTE(review): relies on str(record) being the sequence text and
        # record.long_name being the full header -- confirm against pyfaidx.
        yield '\n'.join('>%s\n%s' % (rec.long_name, rec)
                        for rec in records[start:stop])
        start = stop


def sanity_checks(size, rank):
    """Validate the MPI world size before any work is distributed.

    :param size: number of processes in the communicator.
    :param rank: rank of the calling process.
    :raises SystemExit: with exit status 1 when ``size`` is 1, i.e. the
        program was not started with at least two MPI tasks.
    """
    # Imported here because the module never imports `sys` explicitly; it was
    # only reachable through a star import.
    import sys

    # Do we only have one process?  If yes, then exit.
    if size == 1:
        print('You are running an MPI program with only one slot/task!')
        print('Are you using `mpirun` (or `srun` when in SLURM)?')
        print('If you are, then please use an `-n` of at least 2! (Or, when in'
              ' SLURM, use an `--ntasks` of at least 2.)')
        print('If you did all that, then your MPI setup may be bad.')
        # Non-zero status so schedulers and shell scripts see the failure.
        sys.exit(1)

    # Is our world size over 999?  The output will be a bit weird.
    # NOTE: Only rank zero does anything, so we don't get duplicate output.
    if size >= 1000 and rank == 0:
        print('WARNING:  Your world size {} is over 999!'.format(size))
        print("The output formatting will be a little weird, but that's it.")

    # Sanity checks complete!

# Command-line interface: three required positional inputs, then the
# optional switches.
parser = argparse.ArgumentParser()
parser.add_argument(
    'global_consensus',
    help='Sequence consensus of the global region, e.g. full COI',
)
parser.add_argument(
    'local_consensus',
    help='Sequence consensus of the local region, e.g. Leray fragment',
)
parser.add_argument(
    'fasta',
    help='fasta file with the focal sequences',
)
parser.add_argument(
    '--out_prefix',
    default='A2G_aln',
    help='Prefix of outputs',
)
# store_false with a True default: passing the flag sets
# args.remove_duplicates to False.
# NOTE(review): the value is not read anywhere in the code shown here --
# confirm whether it is meant to be forwarded to the aligner.
parser.add_argument(
    '--remove_duplicates',
    default=True,
    action='store_false',
    help='Keep or remove duplicated sequences',
)

args = parser.parse_args()

# MPI bookkeeping: each process looks up the world size and its own rank.
comm = MPI.COMM_WORLD
size, rank = comm.Get_size(), comm.Get_rank()
curr_time = time.time()  # reference point for the elapsed-time report
sanity_checks(size, rank)

# Only the root process prints the banner, so it appears once.
if rank == 0:
    print('\nA2G_mpi version:',
          colored(version, None, attrs=["bold", "blink"]))
    print(colored('Copyright 2020', 'red', attrs=["bold"]),
          'Jose Sergio Hleap\n')
    print("\n".join((
        "-" * 78,
        " Running on {:d} cores".format(size),
        "-" * 78,
    )))

# Every rank builds its own aligner, restricted to a single CPU and with
# file writing and console output switched off.
aln = Align(
    gene_consensus=args.global_consensus,
    amplicon_consensus=args.local_consensus,
    out_prefix=args.out_prefix,
    cpus=1,
    no_write=True,
    quiet=True,
)

# The root process reads the FASTA and cuts it into exactly one chunk per
# rank; every other rank passes None and receives its chunk from scatter.
if rank == 0:
    data = list(split_n_yield(args.fasta, size))
else:
    data = None

data = comm.scatter(data, root=0)
# Report header lines only. A rank can receive an empty chunk when there are
# more ranks than sequences, and indexing the split of an empty line would
# raise IndexError.
s_names = ','.join(line.split()[0] for line in data.split('\n')
                   if line.startswith('>'))
print('rank', rank, 'has sequences:', s_names)
if data:
    aln.query = data
    results = aln.run()
else:
    # Nothing to align on this rank. It still has to take part in the
    # collective gather below, so contribute a placeholder.
    results = None
newData = comm.gather(results, root=0)
if rank == 0:
    # Drop the placeholders of idle ranks before separating the pairs.
    # NOTE(review): assumes each result is a pair of strings
    # (aligned fasta, alignment with outliers) -- confirm against Align.run.
    aligned = [res for res in newData if res is not None]
    fasta, subset = zip(*aligned) if aligned else ((), ())
    print('MASTER LENGH', len(newData))
    print('FASTA LENGH', len(fasta))
    print('SUBSET LENGH', len(subset))
    with open('%s_aligned.fasta' % args.out_prefix, 'w') as o, open(
            '%s_aligned.withoutliers' % args.out_prefix, 'w') as w:
        o.write('\n'.join(fasta))
        w.write('\n'.join(subset))
    print('Elapsed time:', time.time() - curr_time)
