#!/usr/bin/env python3
# -*- coding: UTF8 -*-

# Author: Guillaume Bouvier -- guillaume.bouvier@pasteur.fr
# https://research.pasteur.fr/en/member/guillaume-bouvier/
# 2020-09-21 11:20:44 (UTC+0200)

import sys
import pickle
import numpy
import matplotlib.pyplot as plt
import argparse

# Command-line interface of the script.
parser = argparse.ArgumentParser(
    description='Plot flow for time serie clustering.',
)
# Input files
parser.add_argument(
    "-s", "--som_name",
    help="name of the SOM pickle to load",
    default='som.p',
)
parser.add_argument(
    "-b", "--bmus",
    help="BMU file to plot",
    default='out_bmus.txt',
)
# Flow post-processing switches (both off unless given)
parser.add_argument(
    "-n", "--norm",
    help="Normalize flow as unit vectors",
    action="store_true",
    default=False,
)
parser.add_argument(
    "-m", "--mean",
    help="Average the flow by the number of structure per SOM cell",
    action="store_true",
    default=False,
)
# Sub-sampling step applied to the rows read from the BMU file
parser.add_argument(
    "--stride",
    help="Stride of the vectors field",
    type=int,
    default=1,
)
# parse_known_args: options this script does not declare are ignored
# instead of aborting the program.
args, _ = parser.parse_known_args()


def unfold_bmus(som, bmus):
    """Translate every BMU through ``som.mapping``.

    Each entry of ``bmus`` is converted to a tuple and used as a key into
    ``som.mapping``; the looked-up values are returned, in input order, as
    a single numpy array.

    Raises ``KeyError`` if an entry is not a key of ``som.mapping``.
    """
    lookup = som.mapping
    translated = []
    for cell in bmus:
        translated.append(lookup[tuple(cell)])
    return numpy.asarray(translated)


def topoflow(bmu1, bmu2, som):
    """Displacement from ``bmu1`` to ``bmu2`` on a wrap-around map.

    ``bmu1`` and ``bmu2`` are 2D arrays whose columns 0 and 1 hold the two
    map coordinates.  Column 0 wraps with period ``som.m`` and column 1
    with period ``som.n``: whenever the raw difference along an axis is
    larger in magnitude than half the period, it is replaced by the
    displacement going the other way around the map.

    Returns an array with one ``(d0, d1)`` row per input row.  The inputs
    are not modified.
    """

    def _shortest(delta, period):
        # Differences exceeding half the period are shorter the other way
        # round: flip their sign and keep the complementary length.
        too_far = numpy.abs(delta) > period / 2
        delta[too_far] = -numpy.sign(delta[too_far]) * (period - numpy.abs(delta[too_far]))
        return delta

    # The subtractions allocate fresh arrays, so the in-place correction in
    # the helper never touches the caller's data.
    step0 = _shortest(bmu2[:, 0] - bmu1[:, 0], som.m)
    step1 = _shortest(bmu2[:, 1] - bmu1[:, 1], som.n)
    return numpy.c_[step0, step1]


# --- Load the inputs -------------------------------------------------------
# NOTE(security): unpickling can execute arbitrary code; only load SOM
# pickles that come from a trusted source.
with open(args.som_name, 'rb') as som_file:
    som = pickle.load(som_file)
umat = som.umat
# Columns 0 and 1 of the BMU file are read as the two map coordinates of each
# row.  atleast_2d keeps a single-row file two-dimensional (genfromtxt would
# otherwise return a flat array); then every `stride`-th row is kept.
bmus = numpy.atleast_2d(numpy.genfromtxt(args.bmus, usecols=(0, 1)))[::args.stride]
npts = bmus.shape[0]
dim1, dim2 = som.m, som.n

# --- Displacement of every row to the following one -------------------------
# Computed once for the whole series instead of once per map cell.  All
# npts - 1 consecutive pairs are used, including the one ending on the very
# last row.
origins, targets = bmus[:-1], bmus[1:]
if som.periodic:
    # Wrap-around map: use the shortest displacement across the borders.
    steps = topoflow(origins, targets, som)
else:
    steps = targets - origins

# --- Accumulate the flow leaving each map cell ------------------------------
# uvals collects the U-matrix value of every cell that gets an arrow; it is
# not used by the plot below.
X, Y, U, V, uvals = [], [], [], [], []
for i in range(dim1):
    # Progress indicator, rewritten in place on the same terminal line.
    sys.stdout.write(f'{i+1}/{dim1}\r')
    sys.stdout.flush()
    for j in range(dim2):
        leaving = (origins == [i, j]).all(axis=1)
        if not leaving.any():
            # No transition starts from this cell: nothing to draw.  This
            # also covers a cell visited only by the last row, for which a
            # mean over zero transitions would be NaN.
            continue
        cell_steps = steps[leaving]
        if args.mean:
            flow_ = cell_steps.mean(axis=0)
        else:
            flow_ = cell_steps.sum(axis=0)
        if args.norm:
            length = numpy.linalg.norm(flow_)
            # A zero net flow has no direction; leave it as is rather than
            # dividing by zero.
            if length > 0:
                flow_ = flow_ / length
        X.append(i)
        Y.append(j)
        U.append(flow_[0])
        V.append(flow_[1])
        uvals.append(umat[i, j])

# --- Plot --------------------------------------------------------------------
plt.matshow(umat)
plt.colorbar()
if X:
    # Arrows are anchored at (x=column, y=row) to line up with matshow.
    # NOTE(review): both flow components are negated here.  Flipping the row
    # component is consistent with matshow drawing row 0 at the top, but
    # please confirm that negating the column component (V) is intended.
    plt.quiver(Y, X, -numpy.asarray(V), -numpy.asarray(U), color='w')
plt.show()
