# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03_vis.core.ipynb (unless otherwise specified).

__all__ = ['PRIMARY_COLOR', 'ACCENT_COLOR', 'DEFAULT_STYLE', 'centimeter', 'set_style', 'plot_spectra', 'summary_plot',
           'plot_validation_curve', 'plot_learning_curve', 'plot_capacity']

# Cell
#nbdev_comment from __future__ import annotations
from ..data.loading import load_kssl
from ..data.selection import get_y_by_order
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.ticker import EngFormatter
from matplotlib import ticker
from fastcore.test import *

# Cell
# Dark grey used for most plot elements (bars, boxes, count labels).
PRIMARY_COLOR = '#333'
# Colour reserved for highlighted elements.
ACCENT_COLOR = 'firebrick'

# `DEFAULT_STYLE` is assembled from one sub-dict per rcParams family; the
# merge below keeps exactly this key order.
_AXES_STYLE = {
    'axes.linewidth': 0.5,
    'axes.facecolor': 'white',
    'axes.ymargin': 0.11,
    'font.size': 8,
    # Only the bottom spine is drawn.
    'axes.spines.bottom': True,
    'axes.spines.left': False,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.grid': True,
}
_GRID_STYLE = {
    'grid.color': 'black',
    'grid.linewidth': 0.2,
    'grid.linestyle': '-',
}
# x ticks point outwards, on the bottom axis only, with minor ticks shown.
_XTICK_STYLE = {
    'xtick.bottom': True,
    'xtick.top': False,
    'xtick.direction': 'out',
    'xtick.major.size': 5,
    'xtick.major.width': 1,
    'xtick.minor.size': 3,
    'xtick.minor.width': 0.5,
    'xtick.minor.visible': True,
}
# y ticks point inwards, on the left axis only, with minor ticks shown.
_YTICK_STYLE = {
    'ytick.left': True,
    'ytick.right': False,
    'ytick.direction': 'in',
    'ytick.major.size': 5,
    'ytick.major.width': 1,
    'ytick.minor.size': 3,
    'ytick.minor.width': 0.5,
    'ytick.minor.visible': True,
}
# rcParams overrides intended to be applied with `set_style`.
DEFAULT_STYLE = {**_AXES_STYLE, **_GRID_STYLE, **_XTICK_STYLE, **_YTICK_STYLE}

# One centimetre expressed in inches (1 in = 2.54 cm). Multiply a length in
# centimetres by this to get inches, e.g. for matplotlib's `figsize`.
centimeter = 1/2.54

# Cell
def set_style(style:dict # Dictionary of plt.rcParams
             ):
    """Apply every key/value pair in `style` to the global `plt.rcParams`.

    Keys are assigned in the iteration order of `style`. Nothing is returned
    and the previous values are not saved.
    """
    plt.rcParams.update(style)

# Cell
def plot_spectra(X:np.ndarray, # Spectra (n_samples, n_wavenumbers)
                 X_names:np.ndarray, # Wavenumbers (n_wavenumbers)
                 figsize=(18, 5), # Figure size (width, height) in inches
                 sample=20): # Size of random subset
    """Plot a random subset of Mid-infrared spectra against wavenumber.

    Up to `sample` distinct rows of `X` are drawn at random (without
    replacement) and each is drawn as one line over `X_names`. The x-axis
    runs from the highest to the lowest wavenumber.
    """
    fig, ax = plt.subplots(figsize=figsize)
    n_samples = X.shape[0]
    # Draw distinct rows so no spectrum is plotted twice; cap at the number of
    # available rows so a large `sample` does not raise.
    idx = np.random.choice(n_samples, size=min(sample, n_samples), replace=False)
    # Reversed limits: the x-axis goes from max(X_names) down to min(X_names).
    ax.set_xlim(np.max(X_names), np.min(X_names))
    ax.set(xlabel='Wavenumber', ylabel='Absorbance')
    ax.set_axisbelow(True)
    ax.grid(True, which='both')
    # Transpose so each selected spectrum (row of X) becomes one line.
    ax.plot(X_names, X[idx, :].T)

# Cell
def summary_plot(y:np.ndarray, # Target variable (n_samples)
                 depth_order:np.ndarray, # Soil and Depth (n_samples, 2)
                 tax_lookup:dict, # {'alfisols': 0,'mollisols': 1, ...}
                 xlabel:str='Exchangeable Potassium ($cmol(+)kg^{-1}$) →', # x-axis label of panel (b)
                ):
    """Plot sample counts (a) and target distribution (b) per taxonomic order.

    Panel (a) is a horizontal bar chart of the number of samples in each
    order. Its x-axis is inverted, so the bars grow towards the left.
    Panel (b) shows the distribution of `y` in each order as horizontal box
    plots on a logarithmic x-axis. Both panels share the y-axis.

    Grouping is delegated to `get_y_by_order`, which is called with `y`, the
    second column of `depth_order`, and `tax_lookup`.

    Side effect: this overwrites several global `plt.rcParams` entries (font
    size, axes, grid, y ticks) and does not restore them afterwards.
    """
    # Global style tweaks for this figure (left in place after the call).
    plt.rcParams.update({
        'font.size': 8,

        'axes.linewidth': 1,
        'axes.facecolor': 'white',
        'axes.ymargin': 0.1,
        'axes.spines.bottom': True,
        'axes.spines.left': False,
        'axes.spines.right': False,
        'axes.spines.top': False,

        'axes.grid': True,
        'grid.color': 'black',
        'grid.linewidth': 0.2,
        'grid.linestyle': '--',

        'ytick.left': True,
        'ytick.right': True,
        'ytick.major.size': 0,
        'ytick.major.width': 1,
        'ytick.minor.size': 0,
        'ytick.minor.width': 0.5,
        'ytick.minor.visible': False,
    })

    _, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [2, 2]},
                                 sharey=True, figsize=(15*centimeter, 8*centimeter), dpi=600)

    y_by_order, count_by_order, idx_order = get_y_by_order(y, depth_order[:, 1], tax_lookup)
    # NOTE(review): labels are taken in `tax_lookup` insertion order and then
    # reordered by `idx_order`. This assumes the insertion order matches the
    # code values used by `get_y_by_order` — confirm.
    y_labels = np.array([name.capitalize() for name in tax_lookup])[idx_order]
    # One box per group; derived from the data instead of a hard-coded count.
    n_orders = len(y_by_order)

    # --- Panel (a): number of samples per order ---
    ax1.barh(y_labels, count_by_order,
             align='center',
             height=0.65,
             color=PRIMARY_COLOR)

    # Annotate each bar with its count. The offsets are in data units (samples)
    # and tuned for this dataset. Because the x-axis is inverted, +100 puts the
    # label just beyond the bar tip. The last bar's label is shifted back
    # inside the bar and drawn in white. That bar is presumably the longest,
    # depending on `idx_order` — verify.
    last = len(count_by_order) - 1
    for i, v in enumerate(count_by_order):
        is_last = i == last
        ax1.text(v + (-4000 if is_last else 100), i - 0.01, str(v),
                 verticalalignment='center',
                 horizontalalignment='right',
                 color='white' if is_last else PRIMARY_COLOR,
                 fontweight='normal', size=6)

    for ax in (ax1, ax2):
        ax.xaxis.set_major_locator(ticker.MaxNLocator(4))
        ax.xaxis.set_minor_locator(ticker.MaxNLocator(20))

    ax1.tick_params(axis='y', which='major', pad=30)
    ax1.set_xlabel('← Number of samples', loc='left')
    ax1.set_ylabel('Taxonomic order')
    # Engineering-notation tick labels (k, M, ...) with a thin space (U+2009).
    ax1.xaxis.set_major_formatter(EngFormatter(places=0, sep="\N{THIN SPACE}"))
    ax1.set_yticklabels(y_labels, fontdict={'horizontalalignment': 'center'})
    ax1.yaxis.tick_right()
    ax1.invert_xaxis()
    ax1.set_title('(a)', loc='left')

    # --- Panel (b): distribution of y per order ---
    boxplot = ax2.boxplot(y_by_order, sym='.', positions=range(n_orders), vert=False,
                          patch_artist=True)
    ax2.set_xlabel(xlabel, loc='right')

    for median in boxplot['medians']:
        median.set(color='white', linewidth=1)

    for box in boxplot['boxes']:
        box.set(facecolor=PRIMARY_COLOR)

    # Outliers: tiny, faint red markers with a negative zorder (drawn behind).
    for flier in boxplot['fliers']:
        flier.set(markersize=1.5, markeredgecolor="tab:red", alpha=0.3, zorder=-1)

    ax2.set_xscale('log')
    ax2.set_title('(b)', loc='left')
    ax2.yaxis.set_ticks_position('none')

    plt.tight_layout()

# Cell
def plot_validation_curve(x, losses, ax=None, plot_kwargs=None, fill_between_kwargs=None):
    """Plot the mean of `losses` with a band of ±1 standard deviation.

    `losses` is converted with `np.asarray` and reduced along axis 0, so it
    presumably holds one curve per run, each aligned with `x` — confirm with
    callers.

    Parameters:
        x: x-coordinates shared by all curves.
        losses: stacked loss curves; mean and std are taken over axis 0.
        ax: target axes; defaults to the current axes (`plt.gca()`).
        plot_kwargs: extra keyword arguments for the mean line (`ax.plot`).
        fill_between_kwargs: extra keyword arguments for the band
            (`ax.fill_between`).

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.gca()
    # `None` defaults avoid sharing one mutable dict across calls.
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    fill_between_kwargs = {} if fill_between_kwargs is None else fill_between_kwargs

    stacked = np.asarray(losses)  # convert once, reuse for mean and std
    mean = np.mean(stacked, axis=0)
    sd = np.std(stacked, axis=0)
    # Draw the band first so the mean line is drawn after it.
    ax.fill_between(x, mean + sd, mean - sd, **fill_between_kwargs)
    ax.plot(x, mean, **plot_kwargs)
    return ax

# Cell
def plot_learning_curve(x, losses_train, losses_valid, ax=None, train_kwargs=None, valid_kwargs=None):
    """Plot training and validation losses against `x` on log-log axes.

    Parameters:
        x: x-coordinates shared by both curves.
        losses_train: training losses, plotted with label 'Training'.
        losses_valid: validation losses, plotted with label 'Validation'.
        ax: target axes; defaults to the current axes (`plt.gca()`).
        train_kwargs: extra keyword arguments for the training `ax.plot`
            call. A `label` entry here replaces the default label.
        valid_kwargs: same as `train_kwargs`, for the validation curve.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.gca()
    # Merge caller kwargs over the default labels. A custom `label` then
    # overrides the default instead of raising "got multiple values for
    # keyword argument 'label'". `None` defaults avoid shared mutable dicts.
    train_opts = {'label': 'Training', **(train_kwargs or {})}
    valid_opts = {'label': 'Validation', **(valid_kwargs or {})}
    ax.plot(x, losses_train, **train_opts)
    ax.plot(x, losses_valid, **valid_opts)
    ax.set_yscale('log')
    ax.set_xscale('log')
    return ax

# Cell
def plot_capacity(x, capacity, ax=None, **kwargs):
    """Draw `capacity` as bars at positions `x` on log-log axes.

    Each bar's width is 15% of its x position. With centre alignment, every
    bar therefore spans the same interval on the logarithmic x-axis.

    Parameters:
        x: bar positions.
        capacity: bar heights.
        ax: target axes; defaults to the current axes (`plt.gca()`).
        **kwargs: forwarded to `ax.bar`. They may override the defaults
            `width`, `color` (PRIMARY_COLOR) and `zorder` (99).

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.gca()
    # Caller kwargs win over the defaults instead of raising a TypeError for a
    # duplicated keyword argument.
    bar_opts = {
        'width': 0.15 * np.asarray(x),
        'color': PRIMARY_COLOR,
        'zorder': 99,  # draw bars on top of the grid and other artists
        **kwargs,
    }
    ax.bar(x, capacity, **bar_opts)
    ax.set_yscale('log')
    ax.set_xscale('log')
    return ax