# AUTOGENERATED! DO NOT EDIT! File to edit: notebooks/cif2tiff.ipynb (unless otherwise specified).

__all__ = ['convert', 'convert_cmd']

# Cell
import math
import javabridge
import bioformats as bf
import numpy as np
import logging
from tqdm import tqdm
from pathlib import Path
import numpy
import click
import flowio
from PIL import Image
from multiprocessing import Pool

# Internal Cell

def _is_image(_series, _r):
    """Report whether series ``_series`` of reader ``_r`` has a pixel type other than 1.

    Side effect: the underlying reader (``_r.rdr``) is left positioned on
    ``_series``.

    NOTE(review): pixel type 1 is presumably the type used for mask planes;
    confirm against the reader's pixel-type constants.
    """
    _r.rdr.setSeries(_series)
    pixel_type = _r.rdr.getPixelType()
    return pixel_type != 1

# Internal Cell
def setup_directory_structure(out, fcs):
    """Create one sub-directory of ``out`` per label found in an FCS file.

    The FCS file is scanned for a channel whose ``PnN`` entry equals
    ``"label"``; if several match, the last one wins. The values of that
    channel are cast to ``int`` and one directory named after each distinct
    value is created below ``out`` (existing directories are left alone).

    Args:
        out: Root output directory (string or path-like).
        fcs: FCS file handed to ``flowio.FlowData``.

    Returns:
        Array holding the integer label of every event, in file order.

    Raises:
        AssertionError: If no channel is named ``"label"``.
    """
    flow = flowio.FlowData(fcs)

    matches = [
        int(number)
        for number, channel in flow.channels.items()
        if channel["PnN"] == "label"
    ]
    label_idx = matches[-1] if matches else None

    assert label_idx is not None, "No label column present in fcs file"

    # Events are unflattened to one row per event; the channel keys are
    # used as 1-based numbers, hence the -1 for the column index.
    events = numpy.reshape(flow.events, (-1, flow.channel_count))
    labels = events[:, label_idx-1].astype(int)

    root = Path(out)
    for label in numpy.unique(labels):
        (root / str(label)).mkdir(exist_ok=True, parents=True)

    return labels

# Internal Cell

def process_chunk(images, crange):
    """Compute per-channel minima and maxima over a list of images.

    Every image is reshaped to ``(len(crange), -1)`` so that each row holds
    all pixels of one channel; images may therefore differ in spatial size
    as long as their leading axis has ``len(crange)`` entries.

    Args:
        images: Non-empty sequence of arrays with the channel axis first.
        crange: Sequence of channels; only its length is used here.

    Returns:
        Tuple ``(lower_bound, upper_bound)`` of 1-D arrays with one entry
        per channel.
    """
    n_channels = len(crange)

    head = images[0].reshape(n_channels, -1)
    upper_bound = head.max(axis=1)
    lower_bound = head.min(axis=1)

    # Fold the remaining images into the running bounds one at a time.
    for image in images[1:]:
        flat = image.reshape(n_channels, -1)
        lows = flat.min(axis=1)
        highs = flat.max(axis=1)
        lower_bound = np.where(lows < lower_bound, lows, lower_bound)
        upper_bound = np.where(highs > upper_bound, highs, upper_bound)

    return lower_bound, upper_bound

def get_min_max_in_file(reader, r_length, crange, nprocs, do_tqdm=True):
    """Load all masked images of a file and find the per-channel value range.

    Series are visited in steps of two: series ``s`` is read as the image and
    series ``s + 1`` as its mask, and the two are multiplied. The series are
    split into ``nprocs`` chunks; each chunk is read in this process and then
    handed to a worker pool that reduces it with ``process_chunk``.

    Args:
        reader: Image reader exposing ``read(c=..., series=...)``.
        r_length: Number of series to consider (images and masks together).
        crange: Zero-based channel indices to extract.
        nprocs: Number of chunks and of worker processes.
        do_tqdm: Show progress bars when true (the default).

    Returns:
        Tuple ``(lower_bound, upper_bound, image_chunks)``. The bounds are
        1-D arrays with one entry per channel in ``crange``, or ``None`` when
        no image was read. ``image_chunks`` is a list (one entry per chunk)
        of lists of masked images with shape ``(len(crange), ...)``.
    """
    chunks = numpy.array_split(numpy.arange(0, r_length, step=2), nprocs)
    image_chunks = [None]*len(chunks)
    n_channels = len(crange)
    hide_bars = not do_tqdm

    with Pool(processes=nprocs) as pool:
        results = []
        for i, chunk in tqdm(enumerate(chunks), position=0, leave=False, total=len(chunks), disable=hide_bars):

            images = [None]*len(chunk)
            for j, series in tqdm(enumerate(chunk), total=len(chunk), position=1, leave=False, mininterval=10, disable=hide_bars):
                series = int(series)

                # The first requested channel fixes the plane shape and dtype.
                first = reader.read(c=int(crange[0]), series=series)
                im = numpy.empty(shape=(n_channels,) + first.shape, dtype=first.dtype)
                mask = numpy.empty(shape=(n_channels,) + first.shape, dtype=first.dtype)

                # Buffers are indexed by position within crange, not by the
                # channel number itself, so arbitrary channel subsets work.
                im[0] = first
                for slot, c in enumerate(crange[1:], start=1):
                    im[slot] = reader.read(c=int(c), series=series)
                for slot, c in enumerate(crange):
                    mask[slot] = reader.read(c=int(c), series=series+1)
                images[j] = im*mask

            image_chunks[i] = images
            if not images:
                # array_split yields empty chunks when there are fewer images
                # than processes; process_chunk needs at least one image.
                continue
            results.append(pool.apply_async(process_chunk, args=(images, crange)))
            print(f"Submitted chunk {i}")

        lower_bound, upper_bound = None, None
        for i, result in enumerate(results):
            print(f"Waiting for result {i}")
            a, b = result.get()

            if lower_bound is None:
                lower_bound = a
            else:
                lower_bound = np.where(a < lower_bound, a, lower_bound)
            if upper_bound is None:
                upper_bound = b
            else:
                upper_bound = np.where(b > upper_bound, b, upper_bound)

    return lower_bound, upper_bound, image_chunks

# Cell

def convert(cif_files, fcs_files, output, channels, debug, nproc=1, external_jvm_control=False):
    """Convert CIF files into per-label, multi-page 16-bit TIFF files.

    For every ``(cif, fcs)`` pair the labels are read from the FCS file, the
    masked images are loaded from the CIF file, each channel is rescaled to
    the full ``uint16`` range using the per-file channel minimum and maximum,
    and every image is written to ``output/<label>/<counter>.tiff`` with one
    page per channel. ``counter`` keeps increasing across files.

    Args:
        cif_files: Paths of the CIF files to convert.
        fcs_files: Paths of the FCS files, paired with ``cif_files`` by position.
        output: Root directory for the TIFF files.
        channels: 1-based channel numbers to extract; empty means all channels.
        debug: Enable INFO logging and process at most the first 100 series.
        nproc: Number of worker processes for the min/max computation.
        external_jvm_control: When true, the caller starts and stops the Java
            VM; otherwise it is started here and killed on exit.
    """
    if debug:
        logging.basicConfig()
        logging.getLogger().setLevel(logging.INFO)

    output = Path(output)

    logger = logging.getLogger(__name__)

    # Scale to the largest representable value; scaling by 2**16 would wrap
    # the brightest pixel around to 0 in the uint16 cast.
    max_value = numpy.iinfo(numpy.uint16).max

    try:

        if not external_jvm_control:
            # NOTE: these are DEBUG messages, so they stay hidden at the INFO
            # level configured above.
            logger.debug("Starting Java VM")
            javabridge.start_vm(class_path=bf.JARS, run_headless=True, max_heap_size="12G")
            logger.debug("Started Java VM")

        counter = 0
        for cif, fcs in zip(cif_files, fcs_files):
            print(f"Processing {cif}")

            labels = setup_directory_structure(output, fcs)

            reader = bf.formatreader.get_image_reader("reader", path=cif)
            r_length = javabridge.call(reader.metadata, "getImageCount", "()I")
            num_channels = javabridge.call(reader.metadata, "getChannelCount", "(I)I", 0)

            if debug:
                # Never ask for more series than the file contains.
                r_length = min(r_length, 100)

            if len(channels) == 0:
                crange = list(range(num_channels))
            else:
                crange = np.array(channels)-1

            lower_bound, upper_bound, image_chunks = get_min_max_in_file(reader, r_length, crange, nproc)
            if lower_bound is None or upper_bound is None:
                # No image was read from this file, so there is nothing to write.
                continue
            lower_bound = lower_bound.reshape(len(crange), 1, 1)
            upper_bound = upper_bound.reshape(len(crange), 1, 1)

            # A constant channel has zero range; divide by 1 instead so it
            # maps to 0 rather than producing NaN/inf.
            span = upper_bound - lower_bound
            span = numpy.where(span == 0, 1, span)

            a = 0  # index of the current image within this file's labels
            for images in tqdm(image_chunks, position=0, leave=False, mininterval=10):
                for im in images:
                    im = (((im - lower_bound) / span)*max_value).astype(numpy.uint16)
                    pages = [Image.fromarray(plane, mode="I;16") for plane in im]
                    pages[0].save(
                        output / str(labels[a]) / f"{counter}.tiff",
                        append_images = pages[1:],
                        save_all = True
                    )
                    counter += 1
                    a += 1

    finally:
        if not external_jvm_control:
            javabridge.kill_vm()

# Cell
@click.command(name="cif2tiff")
@click.argument("cif", type=click.Path(exists=True, file_okay=False))
@click.argument("output", type=click.Path(exists=False, file_okay=False))
@click.option("--channels", multiple=True, type=int, default=[], help="Images from these channels will be extracted. Default is to extract all. 1-based index.")
@click.option("--debug", is_flag=True, flag_value=True, help="Show debugging information. Limits output to 100 first cells.", default=False)
@click.option("--nproc", type=int, default=-1, help="Amount of processes to use.")
def convert_cmd(cif, output, channels, debug, nproc):
    """Convert all .cif files below CIF to TIFF files in OUTPUT.

    Every .cif file must have an .fcs file with the same name next to it.
    """
    if nproc == -1:
        from multiprocessing import cpu_count
        nproc = cpu_count()

    import glob
    # "**" only descends into sub-directories when recursive=True; sorting
    # makes the processing order (and thus the output numbering) reproducible.
    pattern = str(Path(cif) / "**" / "*.cif")
    cif_files = sorted(glob.glob(pattern, recursive=True))
    fcs_files = [str(Path(f).with_suffix(".fcs")) for f in cif_files]

    # convert() has no do_tqdm parameter; passing one raised a TypeError.
    convert(cif_files, fcs_files, output, channels, debug, nproc)