# AUTOGENERATED! DO NOT EDIT! File to edit: ../workflow/notebooks/core/00_core.ipynb.

# %% auto 0
__all__ = ['plot_gate_zarr', 'plot_gate_zarr_channels', 'plot_gate_czi']

# %% ../workflow/notebooks/core/00_core.ipynb 4
from .common import *


# %% ../workflow/notebooks/core/00_core.ipynb 5
import matplotlib.gridspec as gridspec
import yaml
import zarr
from aicsimageio import AICSImage
from matplotlib.colors import LinearSegmentedColormap, Normalize

from scip.masking import remove_regions_touching_border


# %% ../workflow/notebooks/core/00_core.ipynb 7
def plot_gate_zarr(
    sel,
    df,
    mask,
    maxn=200,
    sort=None,
    channels=None,
    bbox=True,
    ncols=5,
    cmaps=None,
    qq=(0, 1),
):
    """Plot a grid of cell images, read from zarr stores, for one selection of ``df``.

    Rows are selected with ``df.loc[sel]``, randomly down-sampled to ``maxn``
    rows and optionally sorted. For every row, the element
    ``z[row["meta_zarr_idx"]]`` of the zarr store at ``row["meta_path"]`` is
    reshaped with the store's ``shape`` attribute, restricted to ``channels``
    and optionally cropped to the ``meta_{mask}_bbox_*`` bounding box. The
    channels of one cell are colour-mapped and placed side by side. All cells
    share one intensity range per channel.

    Parameters
    ----------
    sel
        Indexer passed to ``df.loc`` selecting the rows to plot.
    df : pandas.DataFrame
        Needs ``meta_path`` and ``meta_zarr_idx`` columns. When ``bbox`` is
        true, it also needs the ``meta_{mask}_bbox_minr/minc/maxr/maxc`` columns.
    mask : str
        Name used to build the bounding-box column names.
    maxn : int
        Maximum number of cells; larger selections are randomly sampled.
    sort
        Optional ``by`` argument for ``DataFrame.sort_values``; it defines the
        order of the cells in the grid.
    channels : list of int, optional
        Channel indices to show. Defaults to ``[0]``.
    bbox : bool
        Crop each image to its bounding box.
    ncols : int
        Maximum number of cells per grid row.
    cmaps : list of colormaps, optional
        One colormap per channel. Defaults to viridis for every channel.
    qq : tuple of float
        Quantiles ``(low, high)`` computed per cell and channel. The smallest
        low quantile and the largest high quantile over all cells set the
        display range of each channel.

    Raises
    ------
    ValueError
        If the selection contains no rows.
    """
    if channels is None:
        channels = [0]

    df = df.loc[sel]
    if len(df) > maxn:
        df = df.sample(n=maxn)
    if sort is not None:
        df = df.sort_values(by=sort)
    if len(df) == 0:
        raise ValueError("plot_gate_zarr: selection is empty, nothing to plot")

    # Work on a positional (0..n-1) index. Images then land in plot order even
    # though they are read grouped by zarr store, and duplicate labels in
    # df.index cannot break the placement.
    rows = df.reset_index(drop=True)
    n = len(rows)
    pixels = [None] * n
    extent = numpy.full((n, 2, len(channels)), dtype=float, fill_value=numpy.nan)

    for path, gdf in rows.groupby("meta_path"):
        z = zarr.open(path, mode="r")  # open each store only once
        for pos, r in gdf.iterrows():
            zarr_idx = r["meta_zarr_idx"]
            # Elements are stored flat; restore the per-element shape first.
            cell = z[zarr_idx].reshape(z.attrs["shape"][zarr_idx])[channels]

            if bbox:
                minr = int(r[f"meta_{mask}_bbox_minr"])
                minc = int(r[f"meta_{mask}_bbox_minc"])
                maxr = int(r[f"meta_{mask}_bbox_maxr"])
                maxc = int(r[f"meta_{mask}_bbox_maxc"])
                cell = cell[:, minr:maxr, minc:maxc]

            flat = cell.reshape(cell.shape[0], -1)
            extent[pos, 0] = numpy.quantile(flat, q=qq[0], axis=1)
            extent[pos, 1] = numpy.quantile(flat, q=qq[1], axis=1)
            pixels[pos] = cell

    # Shared per-channel display range over all plotted cells.
    min_ = extent[:, 0].min(axis=0)
    max_ = extent[:, 1].max(axis=0)
    # A constant channel would otherwise divide by zero and render as NaN/inf.
    span = numpy.where(max_ > min_, max_ - min_, 1.0)

    ncols = min(n, ncols)
    nrows = math.ceil(n / ncols)
    fig, axes = plt.subplots(
        ncols=ncols,
        nrows=nrows,
        dpi=50,
        figsize=(ncols * 2 * len(channels), nrows * 2),
        squeeze=False,
    )
    axes = axes.ravel()

    if cmaps is None:
        cmaps = [plt.get_cmap("viridis")] * len(channels)

    low = min_[:, numpy.newaxis, numpy.newaxis]
    scale = span[:, numpy.newaxis, numpy.newaxis]
    # Colour-map each channel separately (RGBA) and place channels side by side.
    rgba_images = [
        numpy.hstack([cm(p) for cm, p in zip(cmaps, (cell - low) / scale)])
        for cell in pixels
    ]

    for ax, image in zip(axes, rgba_images):
        ax.imshow(image)
    for ax in axes:
        ax.set_axis_off()


# %% ../workflow/notebooks/core/00_core.ipynb 12
def plot_gate_zarr_channels(
    selectors,
    df,
    mask,
    maxn=20,
    sort=None,
    show_mask=False,
    main_channel=3,
    smooth=0.75,
    channel_ind=None,
    channel_names=None,
):
    """Plot cells from several selections side by side, one image column per channel.

    Each selector is applied as ``df[sel]``. The result is randomly
    down-sampled to ``maxn`` rows and optionally sorted, and is tagged with
    the selector's position in a ``sel`` column. Images are read from the zarr
    store at ``row["meta_path"]``, restricted to ``channel_ind`` and cropped to
    the ``meta_{mask}_bbox_*`` bounding box. Every selection gets its own
    column of the figure; within it, each row is one cell and each sub-column
    one channel. Channels are colour-mapped (viridis) with a range shared over
    all plotted cells.

    Parameters
    ----------
    selectors : sequence
        Boolean indexers for ``df``, one per figure column.
    df : pandas.DataFrame
        Needs ``meta_path``, ``meta_zarr_idx`` and
        ``meta_{mask}_bbox_minr/minc/maxr/maxc`` columns.
    mask : str
        Name used to build the bounding-box column names.
    maxn : int
        Maximum number of cells per selector.
    sort
        Optional ``by`` argument for ``DataFrame.sort_values``; it sets the row
        order within each selector.
    show_mask : bool
        If true, compute a mask with ``li.get_mask`` and drop regions touching
        the border with ``remove_regions_touching_border``. The mask is then
        overlaid, and only masked pixels contribute to the intensity range.
    main_channel : int
        Passed to ``li.get_mask`` and as ``bbox_channel_index``.
        NOTE(review): the pixels have already been restricted to
        ``channel_ind`` at that point, so this is presumably an index into that
        subset — confirm.
    smooth : float
        Passed to ``li.get_mask``.
    channel_ind : list of int, optional
        Channels to show. Defaults to ``[0]``.
    channel_names : list of str, optional
        Column titles. Defaults to ``["c"]``; missing names fall back to the
        channel index.
    """
    if channel_ind is None:
        channel_ind = [0]
    if channel_names is None:
        channel_names = ["c"]

    dfs = []
    for i, sel in enumerate(selectors):
        tmp_df = df[sel].copy()
        if len(tmp_df) > maxn:
            tmp_df = tmp_df.sample(n=maxn)
        if sort is not None:
            tmp_df = tmp_df.sort_values(by=sort)
        tmp_df["sel"] = i
        dfs.append(tmp_df)
    # Positional index: rows are read grouped by zarr store, but must be shown
    # in the (possibly sorted) order of their selector.
    rows = pandas.concat(dfs).reset_index(drop=True)

    nchannels = len(channel_ind)
    crops = [None] * len(rows)
    overlays = [None] * len(rows)
    extent = numpy.empty(shape=(nchannels, 2), dtype=float)
    extent[:, 0] = numpy.inf
    extent[:, 1] = -numpy.inf

    for path, gdf in rows.groupby("meta_path"):
        z = zarr.open(path, mode="r")  # open each store only once
        for pos, r in gdf.iterrows():
            zarr_idx = r["meta_zarr_idx"]
            pixels = z[zarr_idx].reshape(z.attrs["shape"][zarr_idx])[channel_ind]

            minr = int(r[f"meta_{mask}_bbox_minr"])
            minc = int(r[f"meta_{mask}_bbox_minc"])
            maxr = int(r[f"meta_{mask}_bbox_maxr"])
            maxc = int(r[f"meta_{mask}_bbox_maxc"])
            crop = pixels[:, minr:maxr, minc:maxc]

            if show_mask:
                # The mask is computed on the full (uncropped) image, then cropped.
                m = li.get_mask(
                    dict(pixels=pixels), main_channel=main_channel, smooth=smooth
                )
                m = remove_regions_touching_border(m, bbox_channel_index=main_channel)
                arr = m["mask"][:, minr:maxr, minc:maxc]
            else:
                arr = numpy.ones(crop.shape, dtype=bool)

            crops[pos] = crop
            # Overlay: NaN (not drawn) inside the mask, 0 outside it.
            overlays[pos] = numpy.where(arr, numpy.nan, arr)

            # Running per-channel min/max over masked pixels only. fmin/fmax
            # ignore the NaN produced by a fully-masked channel.
            p = numpy.where(arr, crop, numpy.nan).reshape(nchannels, -1)
            extent[:, 0] = numpy.fmin(extent[:, 0], numpy.nanmin(p, axis=1))
            extent[:, 1] = numpy.fmax(extent[:, 1], numpy.nanmax(p, axis=1))

    # Group crops per selector, preserving row order.
    images = {}
    masks = {}
    for k, crop, overlay in zip(rows["sel"], crops, overlays):
        if crop is None:  # row was not read (e.g. missing meta_path)
            continue
        images.setdefault(int(k), []).append(crop)
        masks.setdefault(int(k), []).append(overlay)

    # Selectors are laid out side by side, so width scales with their number
    # and height with the longest selector.
    max_rows = max((len(v) for v in images.values()), default=0)
    fig = plt.figure(
        dpi=75, figsize=(nchannels * 2.5 * len(selectors), max_rows * 0.8)
    )
    grid = gridspec.GridSpec(1, len(selectors), figure=fig, wspace=0.1)
    cmap = plt.get_cmap("viridis")
    norms = [Normalize(vmin=a, vmax=b) for a, b in extent]

    for k, cells in images.items():
        sub = grid[0, k].subgridspec(len(cells), nchannels)
        for i, image in enumerate(cells):
            for j, (p, m, norm) in enumerate(zip(image, masks[k][i], norms)):
                ax = fig.add_subplot(sub[i, j])
                ax.imshow(cmap(norm(p)))
                if show_mask:
                    ax.imshow(m, alpha=0.3, cmap="Blues")
                ax.set_axis_off()
                if i == 0:
                    title = (
                        channel_names[j]
                        if j < len(channel_names)
                        else str(channel_ind[j])
                    )
                    ax.set_title(title)


# %% ../workflow/notebooks/core/00_core.ipynb 13
def plot_gate_czi(
    sel,
    df,
    maxn=200,
    sort=None,
    channels=None,
    masks_path_col=None,
    extent=None,
    ncols=5,
):
    """Plot a grid of cell crops read directly from CZI files with aicsimageio.

    Rows are selected with ``df.loc[sel]``, randomly down-sampled to ``maxn``
    rows and optionally sorted. Each tile (``meta_path`` / ``meta_scene`` /
    ``meta_tile``) is read once as a ``"CXY"`` array (Z=0, T=0, restricted to
    ``channels``). Every cell is cropped to its ``meta_bbox_*`` box. A
    ``"tile scene"`` pair is printed for every tile read, as progress output.

    Parameters
    ----------
    sel
        Indexer passed to ``df.loc`` selecting the rows to plot.
    df : pandas.DataFrame
        Needs ``meta_path``, ``meta_scene``, ``meta_tile`` and
        ``meta_bbox_minr/minc/maxr/maxc`` columns.
    maxn : int
        Maximum number of cells; larger selections are randomly sampled.
    sort
        Optional ``by`` argument for ``DataFrame.sort_values``; it defines the
        order of the cells in the grid.
    channels : list of int, optional
        Channels to load. Defaults to ``[0]``.
    masks_path_col : str, optional
        Column holding paths of ``.npy`` label masks. When given, the region
        of the crop whose label differs from the cell id is shaded in blue.
        The cell id is ``meta_id`` if that column exists, else the last
        element of the row's index label.
    extent : array-like, optional
        Per-channel display range of shape ``(nchannels, 2)``: column 0 is the
        minimum and column 1 the maximum. If omitted, the min/max over all
        plotted crops is used.
    ncols : int
        Maximum number of cells per grid row.

    Raises
    ------
    ValueError
        If the selection contains no rows.
    """
    if channels is None:
        channels = [0]

    df = df.loc[sel]
    if len(df) > maxn:
        df = df.sample(n=maxn)
    if sort is not None:
        df = df.sort_values(by=sort)
    if len(df) == 0:
        raise ValueError("plot_gate_czi: selection is empty, nothing to plot")

    n = len(df)
    compute_extent = extent is None
    if compute_extent:
        extent = numpy.full((n, 2, len(channels)), dtype=float, fill_value=numpy.nan)
    else:
        extent = numpy.asarray(extent, dtype=float)

    pixels = [None] * n
    ids = [None] * n
    masks = [None] * n if masks_path_col is not None else []

    # Positional index so crops land in plot (sorted) order, although they are
    # read grouped by file/scene/tile. Original labels stay in df.index.
    rows = df.reset_index(drop=True)
    # Scalar group keys: with a one-element list key, pandas >= 2 yields
    # 1-tuples, which AICSImage / set_scene / M= do not accept.
    for path, gdf in rows.groupby("meta_path"):
        ai = AICSImage(path, reconstruct_mosaic=False)
        for scene, gdf2 in gdf.groupby("meta_scene"):
            ai.set_scene(scene)
            for tile, gdf3 in gdf2.groupby("meta_tile"):
                print(tile, scene, end=" ")
                # Every row in this group lives on the same tile: load it once.
                tile_pixels = ai.get_image_data("CXY", Z=0, T=0, C=channels, M=tile)
                for pos, r in gdf3.iterrows():
                    minr = int(r["meta_bbox_minr"])
                    minc = int(r["meta_bbox_minc"])
                    maxr = int(r["meta_bbox_maxr"])
                    maxc = int(r["meta_bbox_maxc"])
                    crop = tile_pixels[:, minr:maxr, minc:maxc]

                    if compute_extent:
                        flat = crop.reshape(crop.shape[0], -1)
                        extent[pos, 0] = flat.min(axis=1)
                        extent[pos, 1] = flat.max(axis=1)
                    pixels[pos] = crop

                    if "meta_id" in r:
                        ids[pos] = r.meta_id
                    else:
                        ids[pos] = df.index[pos][-1]

                    if masks_path_col is not None:
                        masks[pos] = numpy.load(r[masks_path_col])[
                            :, minr:maxr, minc:maxc
                        ]

    if compute_extent:
        min_ = extent[:, 0].min(axis=0)
        max_ = extent[:, 1].max(axis=0)
    else:
        min_ = extent[:, 0]
        max_ = extent[:, 1]
    # A constant channel would otherwise divide by zero and render as NaN/inf.
    span = numpy.where(max_ > min_, max_ - min_, 1.0)
    low = min_[:, numpy.newaxis, numpy.newaxis]
    scale = span[:, numpy.newaxis, numpy.newaxis]

    ncols = min(n, ncols)
    nrows = math.ceil(n / ncols)
    # squeeze=False keeps `axes` a 2-D array even for a single cell.
    fig, axes = plt.subplots(
        ncols=ncols,
        nrows=nrows,
        dpi=50,
        figsize=(ncols * 2 * len(channels), nrows * 2),
        squeeze=False,
    )
    axes = axes.ravel()

    for i, (ax, crop, id_) in enumerate(zip(axes, pixels, ids)):
        # Channels side by side. vmin/vmax pin the shared extent instead of
        # letting imshow autoscale every image on its own.
        ax.imshow(numpy.hstack((crop - low) / scale), vmin=0, vmax=1)
        if masks_path_col is not None:
            ax.imshow(
                numpy.hstack(numpy.where(masks[i] == id_, numpy.nan, 1)),
                cmap="Blues",
                alpha=0.3,
            )

    for ax in axes:
        ax.set_axis_off()

