#!/usr/bin/env python3
# -*- coding: UTF8 -*-

# Author: Guillaume Bouvier -- guillaume.bouvier@pasteur.fr
# https://research.pasteur.fr/en/member/guillaume-bouvier/
# 2020-09-21 11:20:44 (UTC+0200)

import sys
import os
import pickle
import numpy
import matplotlib.pyplot as plt
import argparse

# Command-line interface of the script.
parser = argparse.ArgumentParser(
    description='Plot flow for time serie clustering.',
)
# In/Out
parser.add_argument(
    '-s', '--som_name',
    default='som.p',
    help='name of the SOM pickle to load',
)
parser.add_argument(
    '-b', '--bmus',
    default='out_bmus.txt',
    help='BMU file to plot',
)
parser.add_argument(
    '-d', '--data',
    required=True,
    help='Data file to project',
)
parser.add_argument(
    '--no_gui',
    action='store_true',
    default=False,
    help='Do not open graphical window',
)
# parse_known_args tolerates unrecognised options: they are returned in the
# second element of the tuple, which is discarded here.
args, _ = parser.parse_known_args()

# Load the pickled SOM object; only its `m` and `n` attributes are read below.
# NOTE(review): unpickling can execute arbitrary code -- only point --som_name
# at pickle files coming from a trusted source.
with open(args.som_name, 'rb') as som_file:
    som = pickle.load(som_file)

# The first two columns of the BMU file are read as one pair of map
# coordinates per line.  genfromtxt returns a 1-D array when the file holds a
# single line, so force an (npts, 2) layout to keep the column indexing valid.
bmus = numpy.genfromtxt(args.bmus, usecols=(0, 1)).reshape(-1, 2)
# Values to project; presumably line k of the data file matches line k of the
# BMU file -- TODO confirm against the tool that wrote the BMU file.
data = numpy.genfromtxt(args.data)

# Shape of the projection grid (taken from the SOM object).
dim1, dim2 = som.m, som.n
# +inf marks the cells that received no value; they are turned into NaN below.
projection = numpy.full((dim1, dim2), numpy.inf)
for i in range(dim1):
    # Progress indicator, rewritten in place on the same terminal line.
    sys.stdout.write(f'{i+1}/{dim1}\r')
    sys.stdout.flush()
    # The test on the first coordinate does not depend on j: do it once per i.
    in_row = bmus[:, 0] == i
    for j in range(dim2):
        # Indices of the BMU lines equal to (i, j).
        inds = numpy.flatnonzero(in_row & (bmus[:, 1] == j))
        if inds.size > 0:
            # Mean over every selected element of `data`.
            projection[i, j] = data[inds].mean()
# Every infinite entry (untouched cells, and also cells whose mean is itself
# infinite) becomes NaN.
projection[numpy.isinf(projection)] = numpy.nan

# Outputs are written next to the data file, reusing its name without extension.
outbasename = os.path.splitext(args.data)[0]
numpy.save(f'{outbasename}_project.npy', projection)
plt.matshow(projection)
plt.colorbar()
if args.no_gui:
    # No window requested: write the figure to a PNG file instead.
    plt.savefig(f'{outbasename}_project.png')
else:
    plt.show()
