import click
import numpy as np
import caiman as cm
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf.params import CNMFParams
import psutil
import pandas as pd
import traceback
from pathlib import Path
from shutil import move as move_file
import os
import time
from datetime import datetime

if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
    from mesmerize_core import set_parent_raw_data_path, load_batch
    from mesmerize_core.utils import IS_WINDOWS
else:  # when running with local backend
    from ..batch_utils import set_parent_raw_data_path, load_batch
    from ..utils import IS_WINDOWS


def _get_n_processes() -> int:
    """Return the number of worker processes to request for the local cluster.

    Uses the integer in the ``MESMERIZE_N_PROCESSES`` environment variable when
    it is set and parses as an int; otherwise falls back to one less than the
    CPU count reported by ``psutil``, never less than 1.
    """
    env_value = os.environ.get("MESMERIZE_N_PROCESSES")
    if env_value is not None:
        try:
            return int(env_value)
        except ValueError:
            pass  # unparsable override -> use the CPU-count default below
    # psutil.cpu_count() can return None when undetermined, and "cpus - 1" is 0
    # on a single-CPU machine; a process pool needs at least one worker.
    return max((psutil.cpu_count() or 1) - 1, 1)


def _save_projections(images, output_dir: Path, uuid) -> dict:
    """Save NaN-aware mean, std and max projections of ``images`` along axis 0.

    Each projection is written to ``<output_dir>/<uuid>_<type>_projection.npy``.

    Returns:
        dict mapping "mean" / "std" / "max" to the absolute-or-relative ``Path``
        (as built from ``output_dir``) of the saved ``.npy`` file.
    """
    proj_paths = dict()
    for proj_type in ["mean", "std", "max"]:
        # np.nanmean / np.nanstd / np.nanmax ignore NaN frames/pixels
        p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
        proj_paths[proj_type] = output_dir.joinpath(
            f"{uuid}_{proj_type}_projection.npy"
        )
        np.save(str(proj_paths[proj_type]), p_img)
    return proj_paths


def run_algo(batch_path, uuid, data_path: str = None):
    """Run CNMFE for one batch item and record the results in the batch file.

    Steps:
      1. Load the batch dataframe and select the row whose ``uuid`` matches.
      2. Memmap the item's input movie and save mean/std/max projections.
      3. Compute the correlation and peak-to-noise-ratio (PNR) images on frames
         subsampled by ``params["downsample_ratio"]``.
      4. If ``params["do_cnmfe"]`` is true, run CNMF with 1p ("corr_pnr")
         initialization, evaluate components and save the model as HDF5.
      5. Move the memmap into the item's output directory.

    Output files go to ``<batch dir>/<uuid>/``. Paths stored in the outputs
    dict are relative to the batch directory. Any exception raised during
    steps 2-5 is captured as ``{"success": False, "traceback": ...}`` instead
    of propagating. The row's ``outputs``, ``ran_time`` and ``algo_duration``
    columns are then updated, and the dataframe is pickled back to
    ``batch_path``.

    Args:
        batch_path: path to the batch dataframe pickle.
        uuid: uuid of the batch item to run.
        data_path: parent raw data path, passed to ``set_parent_raw_data_path``
            so that ``df.paths.resolve`` can resolve the input movie path.
    """
    algo_start = time.time()
    set_parent_raw_data_path(data_path)

    df = load_batch(batch_path)
    item = df[df["uuid"] == uuid].squeeze()

    # resolve full path of the input movie
    input_movie_path = str(df.paths.resolve(item["input_movie_path"]))

    output_dir = Path(batch_path).parent.joinpath(str(uuid))
    output_dir.mkdir(parents=True, exist_ok=True)
    batch_dir = output_dir.parent  # output paths are stored relative to this

    params = item["params"]
    print("cnmfe params:", params)

    # adapted from current demo notebook
    n_processes = _get_n_processes()
    # Start cluster for parallel processing
    c, dview, n_processes = cm.cluster.setup_cluster(
        backend="local", n_processes=n_processes, single_thread=False
    )

    try:
        print("making memmap")
        fname_new = cm.save_memmap(
            [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
        )

        gSig = params["main"]["gSig"][0]

        Yr, dims, T = cm.load_memmap(fname_new)
        # frames along axis 0: shape (T, *dims)
        images = np.reshape(Yr.T, [T] + list(dims), order="F")

        proj_paths = _save_projections(images, output_dir, uuid)

        downsample_ratio = params["downsample_ratio"]
        # temporal subsampling (every downsample_ratio-th frame) for the summary images
        cn_filter, pnr = cm.summary_images.correlation_pnr(
            images[::downsample_ratio], swap_dim=False, gSig=gSig
        )

        pnr_output_path = output_dir.joinpath(f"{uuid}_pn.npy").resolve()
        cn_output_path = output_dir.joinpath(f"{uuid}_cn.npy").resolve()

        np.save(str(pnr_output_path), pnr, allow_pickle=False)
        np.save(str(cn_output_path), cn_filter, allow_pickle=False)

        d = dict()  # for output

        if params["do_cnmfe"]:
            cnmfe_defaults = {
                "method_init": "corr_pnr",
                "n_processes": n_processes,
                "only_init": True,  # for 1p
                "center_psf": True,  # for 1p
                "normalize_init": False,  # for 1p
            }
            # user-supplied "main" params override the 1p defaults above
            cnmfe_params = CNMFParams(params_dict={**cnmfe_defaults, **params["main"]})
            cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, params=cnmfe_params)
            print("Performing CNMFE")
            cnm = cnm.fit(images)
            print("evaluating components")
            cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

            cnmf_hdf5_path = output_dir.joinpath(f"{uuid}.hdf5").resolve()
            cnm.save(str(cnmf_hdf5_path))

            # save output paths to outputs dict
            d["cnmf-hdf5-path"] = cnmf_hdf5_path.relative_to(batch_dir)

        # projections are saved regardless of do_cnmfe, so always record them
        for proj_type, proj_path in proj_paths.items():
            d[f"{proj_type}-projection-path"] = proj_path.relative_to(batch_dir)

        cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
        if IS_WINDOWS:
            Yr._mmap.close()  # accessing private attr but windows is annoying otherwise
        move_file(fname_new, cnmf_memmap_path)

        d.update(
            {
                "cnmf-memmap-path": cnmf_memmap_path.relative_to(batch_dir),
                "corr-img-path": cn_output_path.relative_to(batch_dir),
                "pnr-image-path": pnr_output_path.relative_to(batch_dir),
                "success": True,
                "traceback": None,
            }
        )

    except Exception:
        # record algorithm failures in the batch instead of crashing; interrupts
        # (KeyboardInterrupt / SystemExit) are allowed to propagate
        d = {"success": False, "traceback": traceback.format_exc()}
    finally:
        # always tear down the cluster, even if an interrupt is propagating
        cm.stop_server(dview=dview)

    # Add dictionary to output column of series
    df.loc[df["uuid"] == uuid, "outputs"] = [d]
    # Add ran timestamp to ran_time column of series
    df.loc[df["uuid"] == uuid, "ran_time"] = datetime.now().isoformat(timespec="seconds", sep="T")
    df.loc[df["uuid"] == uuid, "algo_duration"] = str(round(time.time() - algo_start, 2)) + " sec"
    # save dataframe to disc
    df.to_pickle(batch_path)


@click.command()
@click.option("--batch-path", type=str)
@click.option("--uuid", type=str)
@click.option("--data-path")
def main(batch_path, uuid, data_path: str = None):
    """Command-line entry point that runs CNMFE for a single batch item.

    The ``--batch-path``, ``--uuid`` and ``--data-path`` options are forwarded
    unchanged to ``run_algo``.
    """
    run_algo(batch_path=batch_path, uuid=uuid, data_path=data_path)


# Script entry point: lets this file be launched directly (e.g. as a
# subprocess), parsing its options through the click command above.
if __name__ == "__main__":
    main()
