import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from natsort import natsorted

from typing import Tuple, List
from pegasusio import UnimodalData



def plot_hto_hist(hashing_data: UnimodalData, attr: str, out_file: str, alpha: float = 0.5, dpi: int = 500, figsize: Tuple[float, float] = None) -> None:
    """Plot overlaid log-log histograms of hashtag UMI counts for signal vs. background barcodes.

    Parameters
    ----------
    hashing_data : UnimodalData
        Hashtag data; ``hashing_data.obs`` must contain a ``"counts"`` column
        and the column named by ``attr``.
    attr : str
        Column of ``hashing_data.obs``. Barcodes whose value is ``"signal"`` are
        drawn as the signal histogram; every other barcode is background.
    out_file : str
        Path the figure is written to via ``plt.savefig``.
    alpha : float
        Transparency applied to both histograms.
    dpi : int
        Resolution passed to ``plt.savefig``.
    figsize : Tuple[float, float], optional
        Figure size in inches; the current figure size is kept when ``None``.
    """
    obs = hashing_data.obs
    idx_signal = np.isin(obs[attr], "signal")
    signal = obs.loc[idx_signal, "counts"]
    background = obs.loc[~idx_signal, "counts"]

    # Take the maximum over all barcodes in one call. Series.max() of an empty
    # subset is NaN, and the builtin max() keeps a leading NaN, which would
    # turn every bin edge into NaN and make the histogram call fail.
    max_count = obs["counts"].max()
    # 500 log-spaced bins from 1 UMI up to the largest observed count.
    bins = np.logspace(0, np.log10(max_count), 501)

    plt.hist(background, bins, alpha=alpha, label="background", log=True)
    plt.hist(signal, bins, alpha=alpha, label="signal", log=True)
    plt.legend(loc="upper right")
    ax = plt.gca()
    ax.set_xscale("log")
    ax.set_xlabel("Number of hashtag UMIs (log10 scale)")
    ax.set_ylabel("Number of cellular barcodes (log10 scale)")
    if figsize is not None:
        plt.gcf().set_size_inches(*figsize)
    plt.savefig(out_file, dpi=dpi)
    plt.close()


def plot_rna_hist(
    rna_data: UnimodalData, out_file: str, plot_attr: str = "n_counts", cat_attr: str = "demux_type", dpi: int = 500, figsize: Tuple[float, float] = None
) -> None:
    """Plot overlaid histograms (log-scaled x axis) of a per-barcode RNA statistic, one per category.

    Parameters
    ----------
    rna_data : UnimodalData
        RNA data; ``rna_data.obs`` must contain the ``plot_attr`` and
        ``cat_attr`` columns.
    out_file : str
        Path the figure is written to via ``plt.savefig``.
    plot_attr : str
        Numeric column of ``rna_data.obs`` to histogram. Only positive values
        are binned, because a log-scaled axis cannot display zero or negatives.
    cat_attr : str
        Column of ``rna_data.obs`` used to split barcodes. For ``"demux_type"``
        the categories are drawn in the order singlet, doublet, unknown.
        For any other column, every non-missing category is drawn in natural
        sort order.
    dpi : int
        Resolution passed to ``plt.savefig``.
    figsize : Tuple[float, float], optional
        Figure size in inches; the current figure size is kept when ``None``.

    Raises
    ------
    ValueError
        If ``plot_attr`` has no positive values, so no log-scale bins exist.
    """
    values = rna_data.obs[plot_attr]
    # log10(0) is -inf, which would make every bin edge NaN; a log axis cannot
    # show non-positive values anyway, so take the bin range from positive values only.
    positive = values[values > 0]
    if positive.size == 0:
        raise ValueError(f"Column '{plot_attr}' has no positive values to plot on a log scale.")
    bins = np.logspace(np.log10(positive.min()), np.log10(positive.max()), 101)

    cat_vec = rna_data.obs[cat_attr]
    if cat_attr == "demux_type":
        categories = ["singlet", "doublet", "unknown"]
    else:
        categories = natsorted(cat_vec.dropna().unique())

    ax = plt.gca()
    for category in categories:
        ax.hist(
            values.loc[np.isin(cat_vec, category)],
            bins,
            alpha=0.5,
            label=str(category),
        )
    ax.legend(loc="upper right")
    ax.set_xscale("log")
    # Keep the historical label for the default attribute; otherwise name the column.
    if plot_attr == "n_counts":
        ax.set_xlabel("Number of RNA UMIs (log10 scale)")
    else:
        ax.set_xlabel(f"{plot_attr} (log10 scale)")
    ax.set_ylabel("Number of cellular barcodes")
    if figsize is not None:
        plt.gcf().set_size_inches(*figsize)
    plt.savefig(out_file, dpi=dpi)
    plt.close()


def plot_bar(heights: List[float], tick_labels: List[str], xlabel: str, ylabel: str, out_file: str, dpi: int = 500, figsize: Tuple[float, float] = None) -> None:
    """Draw a simple bar chart, one bar per entry of ``heights``, and save it.

    Parameters
    ----------
    heights : List[float]
        Bar heights; any array-like (list, ndarray, Series) is accepted.
    tick_labels : List[str]
        One x tick label per bar. Labels are rotated 90 degrees when the
        longest one exceeds 6 characters.
    xlabel, ylabel : str
        Axis labels.
    out_file : str
        Path the figure is written to via ``plt.savefig``.
    dpi : int
        Resolution passed to ``plt.savefig``.
    figsize : Tuple[float, float], optional
        Figure size in inches; the current figure size is kept when ``None``.
    """
    # Accept plain lists as the annotation promises; `.size` only exists on arrays.
    heights = np.asarray(heights)
    n_bars = heights.size
    plt.bar(
        x=np.linspace(0.5, n_bars - 0.5, n_bars),
        height=heights,
        tick_label=tick_labels,
    )
    ax = plt.gca()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if figsize is not None:
        plt.gcf().set_size_inches(*figsize)
    # str() tolerates non-string labels; default=0 tolerates an empty label list.
    longest_label = max((len(str(label)) for label in tick_labels), default=0)
    rotation = 90 if longest_label > 6 else 0
    plt.tick_params(axis="x", labelsize=7, labelrotation=rotation)
    plt.tight_layout()
    plt.savefig(out_file, dpi=dpi)
    plt.close()


def plot_gene_violin(
    data: UnimodalData,
    gene_name: str,
    out_file: str,
    title: str = None,
    dpi: int = 500,
    figsize: Tuple[float, float] = None,
    linewidth: float = None,
    inner: str = "box",
) -> None:
    """Draw violin plots of one gene's expression, grouped by demultiplexing assignment.

    Singlets are split by their ``obs["assignment"]`` value (natural sort
    order), followed by the ``"doublet"`` and ``"unknown"`` groups.

    Parameters
    ----------
    data : UnimodalData
        Data whose ``obs`` has ``"demux_type"`` and ``"assignment"`` columns.
    gene_name : str
        Feature to plot; selected with ``data[:, gene_name]``.
    out_file : str
        Path the figure is written to via ``plt.savefig``.
    title : str, optional
        Axes title; none is set when ``None``.
    dpi : int
        Resolution passed to ``plt.savefig``.
    figsize : Tuple[float, float], optional
        Figure size in inches; the current figure size is kept when ``None``.
    linewidth : float, optional
        Forwarded to ``sns.violinplot``.
    inner : str
        Forwarded to ``sns.violinplot`` (interior representation).
    """
    matrix = data[:, gene_name].X
    # X may be a scipy sparse matrix or an already-dense array; only sparse has toarray().
    expression = matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)
    df = pd.DataFrame(
        expression,
        index=data.obs_names,
        columns=[gene_name],
    )
    # Rows of X correspond positionally to rows of obs, so assign by position
    # rather than relying on index alignment.
    df["assignment"] = data.obs["demux_type"].astype(str).to_numpy()
    idx_singlet = np.isin(data.obs["demux_type"], "singlet")
    singlets = data.obs.loc[idx_singlet, "assignment"].astype(str)
    df.loc[idx_singlet, "assignment"] = singlets.to_numpy()
    categories = natsorted(singlets.unique())
    categories.extend(["doublet", "unknown"])
    df["assignment"] = pd.Categorical(df["assignment"], categories=categories)
    xlabel = "assignment"
    ylabel = gene_name

    sns.violinplot(
        x=xlabel, y=ylabel, data=df, linewidth=linewidth, cut=0, inner=inner
    )

    ax = plt.gca()
    ax.grid(False)
    # NOTE(review): label assumes X holds log(TP100K+1)-normalized values — confirm upstream.
    ax.set_ylabel("log(TP100K+1)")
    if title is not None:
        ax.set_title(title)

    if figsize is not None:
        plt.gcf().set_size_inches(*figsize)

    # Base rotation on the declared categories (the tick labels drawn for a
    # Categorical x) rather than the observed values, which may include NaN.
    longest_label = max((len(str(category)) for category in categories), default=0)
    rotation = 90 if longest_label > 6 else 0
    plt.tick_params(axis="x", labelsize=7, labelrotation=rotation)
    plt.tight_layout()
    plt.savefig(out_file, dpi=dpi)
    plt.close()


def plot_heatmap(vec1: List[int], vec2: List[int], out_file: str, dpi: int = 500, xlabel: str = "", ylabel: str = "", figsize: Tuple[float, float] = None) -> None:
    """Plot an annotated heatmap of the contingency table of two label vectors.

    Parameters
    ----------
    vec1 : List[int]
        Labels forming the rows of the table (``pd.crosstab`` index).
    vec2 : List[int]
        Labels forming the columns of the table; must have the same length as ``vec1``.
    out_file : str
        Path the figure is written to via ``plt.savefig``.
    dpi : int
        Resolution passed to ``plt.savefig``.
    xlabel, ylabel : str
        Axis labels.
    figsize : Tuple[float, float], optional
        Figure size in inches; the current figure size is kept when ``None``.
    """
    df = pd.crosstab(vec1, vec2)
    # Blank out the index/column names so crosstab's default names are not drawn.
    df.columns.name = df.index.name = ""

    ax = plt.gca()
    # Place the column tick labels along the top edge of the heatmap.
    ax.xaxis.tick_top()
    ax = sns.heatmap(df, annot=True, fmt="d", cmap="inferno", ax=ax)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if figsize is not None:
        plt.gcf().set_size_inches(*figsize)
    plt.tight_layout()
    # Honor the caller's dpi (previously hard-coded to 500).
    plt.savefig(out_file, dpi=dpi)
    plt.close()


def plot_dataframe_bar(df: pd.DataFrame, ylabel: str, out_file: str, dpi: int = 500, figsize: Tuple[float, float] = None) -> None:
    """Render ``df`` with pandas ``DataFrame.plot.bar`` and save the figure.

    Parameters
    ----------
    df : pd.DataFrame
        Data to plot. The legend is suppressed when ``df`` has exactly one
        column; otherwise pandas' default legend behavior applies.
    ylabel : str
        Label for the y axis.
    out_file : str
        Path the figure is written to via ``plt.savefig``.
    dpi : int
        Resolution passed to ``plt.savefig``.
    figsize : Tuple[float, float], optional
        Figure size in inches; the current figure size is kept when ``None``.
    """
    # A single series needs no legend; with several columns, leave pandas' default alone.
    bar_options = {"legend": False} if df.shape[1] == 1 else {}
    df.plot.bar(**bar_options)

    if figsize is not None:
        plt.gcf().set_size_inches(*figsize)
    plt.gca().set_ylabel(ylabel)

    plt.savefig(out_file, dpi=dpi)
    plt.close()
