import numpy as np
import pandas as pd
import sys
from itertools import chain
from scipy.cluster.hierarchy import linkage
import newick as nw

def get_kmers_per_col(mat):
    """Build the square "K-mers per column" matrix of an adjacency matrix.

    The diagonal of the adjacency matrix is taken as the per-genome K-mer
    total, and that diagonal is repeated as every row of the result. Element
    ``[i, j]`` of the result is therefore the total for the genome indexed
    at column ``j``.

    This helper exists so that Jaccard and overlap similarities can be
    computed with vectorized array arithmetic.

    Parameters
    ----------
    mat
        a numpy array containing the adjacency matrix

    Returns
    -------
    array
        a numpy array containing the "K-mers per column" matrix
    """

    genome_totals = np.diagonal(mat)
    n_genomes = genome_totals.shape[0]
    # One copy of the diagonal per row; np.array stacks them into a 2-D array.
    return np.array([genome_totals] * n_genomes)


def adj_to_jaccard(adj_matrix):
    """Convert a raw adjacency matrix to Jaccard similarity values

    Each entry ``[i, j]`` is divided by
    ``diagonal[i] + diagonal[j] - entry[i, j]``, i.e. shared count over the
    size of the union when the diagonal holds per-genome totals.

    Parameters
    ----------
    adj_matrix
        pandas data frame containing the adjacency matrix

    Returns
    -------
    DataFrame
        pandas data frame containing the jaccard matrix, with the same index
        and columns as `adj_matrix`
    """

    shared = adj_matrix.to_numpy()
    col_totals = get_kmers_per_col(shared)
    # Transposing gives the per-row totals, so the sum pairs genome i with j.
    row_totals = col_totals.T
    union = col_totals + row_totals - shared
    return pd.DataFrame(
        shared / union,
        index=adj_matrix.index,
        columns=adj_matrix.columns,
    )


def adj_to_overlap(adj_matrix):
    """Convert a raw adjacency matrix to overlap similarity values

    Each entry ``[i, j]`` is divided by the smaller of ``diagonal[i]`` and
    ``diagonal[j]``, i.e. shared count over the size of the smaller genome
    when the diagonal holds per-genome totals.

    Parameters
    ----------
    adj_matrix
        pandas data frame containing the adjacency matrix

    Returns
    -------
    DataFrame
        pandas data frame containing the overlap matrix, with the same index
        and columns as `adj_matrix`
    """

    shared = adj_matrix.to_numpy()
    col_totals = get_kmers_per_col(shared)
    # Element-wise minimum of the column totals and the row totals.
    smaller_total = np.minimum(col_totals, col_totals.T)
    return pd.DataFrame(
        shared / smaller_total,
        index=adj_matrix.index,
        columns=adj_matrix.columns,
    )


def compute_linkage_matrix(adj_matrix, method='complete', optimal_ordering=True):
    """Compute a linkage matrix defining a hierarchical clustering

    The strict upper triangle of the matrix is read as pairwise similarities
    and turned into distances by subtracting each value from the largest
    entry of the whole matrix.

    Parameters
    ---------
    adj_matrix
        pandas data frame containing the adjacency matrix
    method : str
        linkage algorithm for hierarchical clustering. See
        `scipy.cluster.hierarchy.linkage` for details.
    optimal_ordering : bool
        If True, the linkage matrix will be reordered so that the distance
        between successive leaves is minimal. See
        `scipy.cluster.hierarchy.linkage`

    Returns
    -------
    array
        the linkage matrix returned by `scipy.cluster.hierarchy.linkage`
    """

    square = adj_matrix.to_numpy()
    n_genomes = len(adj_matrix.index)
    # Indices above the main diagonal (k=1), used as the condensed
    # pairwise vector handed to `linkage`.
    upper_rows, upper_cols = np.triu_indices(n_genomes, k=1)
    pair_similarity = square[upper_rows, upper_cols]
    pair_distance = square.max() - pair_similarity
    return linkage(
        pair_distance,
        method=method,
        optimal_ordering=optimal_ordering,
    )


def linkage_df(linkage_matrix):
    """Convert a linkage matrix to linkage data frame.
    This allows different dtypes and column headers.

    The first, second and fourth columns are cast to integers and the four
    columns are labelled ``clust_0``, ``clust_1``, ``dist`` and ``n_obs``.

    Parameters
    ----------
    linkage_matrix
        linkage matrix generated by `scipy.cluster.hierarchy.linkage`

    Returns
    -------
    DataFrame
        data frame containing the linkage matrix
    """

    headers = ['clust_0', 'clust_1', 'dist', 'n_obs']
    frame = pd.DataFrame(linkage_matrix)
    # Columns are still positionally labelled 0..3 here; only the distance
    # column (2) keeps its original dtype.
    frame = frame.astype({0: int, 1: int, 3: int})
    frame.columns = headers
    return frame


def linkage_to_newick(linkage_df, names):
    """Generate a NEWICK formatted tree from a linkage data frame

    Every linkage row contributes two child nodes, both initially given the
    row's merge distance as their length. Children are then attached to the
    node of the cluster they form, and that node's length is reduced by the
    merge distance of its own children. The root's name and length are
    cleared before serialization.

    Parameters
    ----------
    linkage_df
        pandas data frame containing the linkage matrix
    names
        iterable of genome names corresponding to clusters in linkage_df

    Returns
    -------
    str
        NEWICK formatted tree
    """

    n_genomes = len(names)
    leaf_names = dict(enumerate(names))

    def _label(cluster_id):
        # Ids with no entry in leaf_names are labelled with the id itself.
        return leaf_names.get(cluster_id, str(cluster_id))

    # nodes[2*i] and nodes[2*i + 1] are the two children merged by row i.
    nodes = []
    for _, clust_0, clust_1, dist, _ in linkage_df.itertuples():
        nodes.append(nw.Node(name=_label(clust_0), length=str(dist)))
        nodes.append(nw.Node(name=_label(clust_1), length=str(dist)))

    # The last merge creates cluster 2*n_genomes - 2, which is the root.
    root = nw.Node(name=str(2*n_genomes-2), length=nodes[-1]._length)
    nodes.append(root)

    # NOTE(review): the lookups below assume the cluster columns hold ints,
    # so that str(cluster_id) matches str(i + n_genomes) -- confirm upstream.
    by_name = {node.name: node for node in nodes}
    for i in range(n_genomes - 1):
        parent = by_name[str(i + n_genomes)]
        first_child = nodes[2*i]
        second_child = nodes[2*i + 1]
        parent.add_descendant(first_child)
        parent.add_descendant(second_child)
        # The parent's length so far is the distance at which it merges
        # upwards; subtract the distance at which it was itself formed.
        merged_above = float(parent._length)
        formed_at = float(first_child._length)
        parent._length = str(merged_above - formed_at)

    root.name = None
    root._length = None
    return nw.dumps(root)


def tree(adj_matrix, newick: bool = False, metric: str = 'intersection',
         method: str = 'complete', optimal_ordering: bool = True,
         transformed_matrix = None):
    """Generate a hierarchical clustering tree from an adjacency matrix

    The result is written to standard output, either as a NEWICK string or
    as a tab-separated linkage table.

    Parameters
    ----------
    adj_matrix
        a pandas data frame containing an adjacency matrix
    newick : bool
        if True, print a NEWICK-formatted tree. Otherwise, print a linkage
        matrix
    metric : str
        similarity metric to use for clustering. must be "intersection",
        "jaccard", or "overlap". Any value other than "jaccard" or
        "overlap" leaves the matrix untransformed.
    method : str
        linkage algorithm for hierarchical clustering. See
        `scipy.cluster.hierarchy.linkage` for details.
    optimal_ordering : bool
        If True, the linkage matrix will be reordered so that the distance
        between successive leaves is minimal. See
        `scipy.cluster.hierarchy.linkage`
    transformed_matrix
        optional path of a file to write the matrix used for clustering
        (after any metric transformation) to, in CSV format. Nothing is
        written when this is None or otherwise falsy.

    Returns
    -------
    None
    """

    if metric == 'jaccard':
        adj_matrix = adj_to_jaccard(adj_matrix)
    elif metric == 'overlap':
        adj_matrix = adj_to_overlap(adj_matrix)
    if transformed_matrix:
        # pandas asks for text handles passed to to_csv to be opened with
        # newline='' so that newline translation does not double the line
        # endings on platforms that translate '\n'.
        with open(transformed_matrix, 'w', newline='') as f:
            adj_matrix.to_csv(f)
    linkage_matrix = compute_linkage_matrix(adj_matrix, method=method,
                                     optimal_ordering=optimal_ordering)
    df = linkage_df(linkage_matrix)
    if newick:
        # Leaf labels come from the matrix index.
        newick_tree = linkage_to_newick(df, adj_matrix.index)
        print(newick_tree)
    else:
        df.to_csv(sys.stdout, index=False, sep='\t')
