# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/


__authors__ = ["A.Mirone"]
__license__ = "MIT"
__date__ = "7/12/2022"

import logging
import numpy
import nxtomomill.nexus
from tomoscan.esrf.scan.hdf5scan import ImageKey, HDF5TomoScan
from tomoscan.esrf.scan.utils import cwd_context
from tomoscan.framereducerbase import ReduceMethod

from typing import Optional


from ..utils.utils import strip_extension

import h5py
import os

# NOTE(review): this configures the root logger as a side effect of importing
# the module; confirm this is wanted for library (non-script) usage.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by the helpers defined in this file.
_logger = logging.getLogger(__name__)


def extract_darks_flats(
    dataset_file_name: str,
    entry_name: str,
    save_intermediated: bool = False,
    target_filename: Optional[str] = None,
    target_entry_name: Optional[str] = None,
    method: str = "median",
    reuse_intermediated: bool = False,
    use_projections_for_flats: bool = False,
):
    """
    Compute (or reload) the reduced flats and darks of one NXtomo entry.

    The work is done with the current working directory temporarily set to
    the directory of ``dataset_file_name``.

    :param str dataset_file_name:
        Path of the NXtomo file to reduce.
    :param str entry_name:
        Entry of ``dataset_file_name`` to read.
    :param bool save_intermediated:
        If True, the reduced flats and darks are saved through a scan built
        from ``target_filename`` / ``target_entry_name``.
    :param Optional[str] target_filename:
        File of the scan used to save and/or reload the reduced frames.
        Like ``dataset_file_name`` it is interpreted relative to the caller's
        working directory: it is made absolute before the directory change.
    :param Optional[str] target_entry_name:
        Entry of the target scan. Falls back to ``entry_name`` when empty.
    :param str method:
        Reduction method forwarded to ``compute_reduced_flats``. The darks
        are reduced with the default method of ``compute_reduced_darks``.
    :param bool reuse_intermediated:
        If True nothing is computed: the reduced frames are loaded from the
        target scan.
    :param bool use_projections_for_flats:
        If True, frames whose image key value is 0 are re-flagged as flat
        field and frames whose image key value is 1 are flagged invalid.
        The edited NXtomo is written to ``<basename>_edited_keys_scan.nx``
        in the dataset directory (an existing file is replaced) and the
        reduction runs on that file. The file is not removed afterwards.

    :return: a tuple ``(refs, as_dict)`` where ``as_dict`` is
        ``{"flat": {"images", "meta"}, "dark": {"images", "meta"}}`` and
        ``refs`` is the ``__RefsDarks`` view built from that dictionary.
    :rtype: tuple
    """
    if not target_entry_name:
        target_entry_name = entry_name

    dirname, basename = os.path.split(dataset_file_name)
    if not dirname:
        dirname = "./"

    # The target scan is opened after the working directory has changed.
    # Resolve it now, otherwise a relative path such as "refs/start.nx" would
    # be looked up as "refs/refs/start.nx".
    if target_filename is not None:
        target_filename = os.path.abspath(target_filename)

    with cwd_context(dirname):
        if reuse_intermediated:
            scan = HDF5TomoScan(target_filename, target_entry_name)
            reduced_flats, metadata_flats = scan.load_reduced_flats(return_info=True)
            reduced_darks, metadata_darks = scan.load_reduced_darks(return_info=True)
        else:
            scan_file = basename
            if use_projections_for_flats:
                # The NXtomo only needs to be loaded when its keys are edited.
                nxt = nxtomomill.nexus.NXtomo()
                nxt.load(basename, data_path=entry_name)
                detector = nxt.instrument.detector

                where_proj = [k.value == 0 for k in detector.image_key]
                where_flat = [k.value == 1 for k in detector.image_key]

                detector.image_key_control[where_proj] = ImageKey.FLAT_FIELD
                detector.image_key_control[where_flat] = ImageKey.INVALID

                scan_file = f"{basename}_edited_keys_scan.nx"
                if os.path.isfile(scan_file):
                    os.remove(scan_file)
                nxt.save(scan_file, entry_name)

            scan = HDF5TomoScan(scan_file, entry_name)
            reduced_flats, metadata_flats = scan.compute_reduced_flats(
                method, return_info=True
            )
            reduced_darks, metadata_darks = scan.compute_reduced_darks(
                return_info=True
            )

        if save_intermediated:
            scan = HDF5TomoScan(target_filename, target_entry_name)
            scan.save_reduced_flats(reduced_flats, flats_infos=metadata_flats)
            scan.save_reduced_darks(reduced_darks, darks_infos=metadata_darks)

    return_dict = {
        "flat": {"images": reduced_flats, "meta": metadata_flats},
        "dark": {"images": reduced_darks, "meta": metadata_darks},
    }

    return __RefsDarks(return_dict, entry_name), return_dict


class __RefsDarks:
    """
    Holder of one reduced flat and one reduced dark with their machine current.

    The frames are taken either from a dictionary shaped as
    ``{"flat"|"dark": {"images": ..., "meta": ...}}`` or, when something else
    than a dictionary is given, from the ``<stem>_flat.h5`` / ``<stem>_dark.h5``
    files derived from it with ``strip_extension``.
    """

    def __init__(self, dict_or_file_name, entry_name):
        self.dict_or_file_name = dict_or_file_name
        self.entry_name = entry_name
        self.flat_image, self.flat_current = self._take_image_and_meta("flat")
        self.dark_image, self.dark_current = self._take_image_and_meta("dark")

    def _take_image_and_meta(self, what) -> tuple:
        """
        Fetch the frame and the machine current registered under ``what``.

        :param str what: either "flat" or "dark"
        :return: a tuple as (image, Optional[current:float])
        :rtype: tuple
        """
        source = self.dict_or_file_name
        if isinstance(source, dict):
            return self._take_from_dict(source[what])
        return self._take_from_file(what)

    def _take_from_dict(self, group) -> tuple:
        """Pick the first numerically-keyed image of ``group`` and its current."""
        images = group["images"]
        picked = None
        for key in images:
            if not (isinstance(key, int) or key.isnumeric()):
                continue
            if picked is not None:
                # extra frames are only reported, the first one is kept
                _logger.warning(" more than one image found ")
                continue
            picked = images[key]

        currents = group["meta"].machine_electric_current
        current = currents[0] if len(currents) > 0 else None
        return picked, current

    def _take_from_file(self, what) -> tuple:
        """Read the single numerically-named dataset of the ``what`` file."""
        file_name_tmp = f"{strip_extension(self.dict_or_file_name)}_{what}.h5"
        with h5py.File(file_name_tmp, "r") as h5_file:
            # NOTE(review): the entry is looked up but the data is then read
            # from the root-level `what` group; confirm the expected layout.
            _ = h5_file[self.entry_name]
            group = h5_file[what]
            current = group["machine_electric_current"][()][0]
            picked = None
            for key in group:
                if not key.isnumeric():
                    continue
                if picked is not None:
                    raise ValueError(
                        f" more than one image found in {file_name_tmp}"
                    )
                picked = group[key][()]
        return picked, current


def flat_reducer(
    scan_filename: str,
    ref_start_filename: str,
    ref_end_filename: str,
    mixing_factor: float,
    entry_name: str = "entry0000",
    median_or_mean: str = ReduceMethod.MEAN.value,
    save_intermediated: bool = False,
    reuse_intermediated: bool = False,
):
    """
    Build the reduced flat and darks of a scan from two reference scans.

    Flats and darks are extracted from the two reference scans (their
    projections are used as flats) and from the target scan. Each reference
    flat is dark-subtracted, rescaled to the machine current of the target
    scan and offset by the dark of the target scan::

        flat_k = (flat_ref_k - dark_ref_k) * current / current_ref_k + dark_scan

    The saved flat is the interpolation
    ``(1 - mixing_factor) * flat_start + mixing_factor * flat_end``. It is
    saved, together with the reduced darks of the target scan, as reduced
    frames of ``scan_filename``.

    :param str scan_filename:
        The target scan: a nexus file for which the reduced flat is created
        from the scans given by ``ref_start_filename`` and
        ``ref_end_filename``.

    :param str ref_start_filename:
        The scan with projections to be used as reference for the beginning
        of the measurement.

    :param str ref_end_filename:
        The scan with projections to be used as reference for the end of the
        measurement.

    :param float mixing_factor:
        Weight of the end reference in the interpolation (the start
        reference is weighted by ``1 - mixing_factor``).

    :param str entry_name:
        The entry name, used for the three files. Defaults to "entry0000".

    :param str median_or_mean:
        Either "mean" or "median". Default is "mean".

    :param bool save_intermediated:
        Save the intermediate flats and darks of the two reference scans
        (``ref_start_filename``, ``ref_end_filename``) for later usage.
        Defaults to False.

    :param bool reuse_intermediated:
        Reuse the intermediate flats and darks of the reference scans when
        all of them are already present; otherwise compute and save them.

    :raises ValueError: if ``median_or_mean`` is invalid, if no machine
        electric current can be found for the target scan, or if a reference
        scan misses a flat, a dark or its machine electric current.
    :raises RuntimeError: if the current read back from the flat metadata
        differs from the one just stored.
    """
    intermediated_are_reusable = False
    if reuse_intermediated:
        required_files = [
            f"{strip_extension(ref_filename, _logger)}_{kind}.hdf5"
            for ref_filename in (ref_start_filename, ref_end_filename)
            for kind in ("darks", "flats")
        ]
        intermediated_are_reusable = all(
            os.path.exists(fn) for fn in required_files
        )

    # Saving the intermediates is enforced when they are meant to be reused
    # but are not available yet.
    save_intermediated = save_intermediated or (
        reuse_intermediated and not intermediated_are_reusable
    )

    if median_or_mean not in [ReduceMethod.MEAN.value, ReduceMethod.MEDIAN.value]:
        message = f""" the "median_or_mean" parameter must be one of {[ReduceMethod.MEAN.value, ReduceMethod.MEDIAN.value]}.
        It was {median_or_mean}
        """
        raise ValueError(message)

    fd_start, _ = extract_darks_flats(
        ref_start_filename,
        entry_name,
        target_filename=ref_start_filename,
        save_intermediated=save_intermediated,
        method=median_or_mean,
        reuse_intermediated=intermediated_are_reusable,
        use_projections_for_flats=True,
    )

    fd_end, _ = extract_darks_flats(
        ref_end_filename,
        entry_name,
        target_filename=ref_end_filename,
        save_intermediated=save_intermediated,
        method=median_or_mean,
        reuse_intermediated=intermediated_are_reusable,
        use_projections_for_flats=True,
    )
    fd_sample, fd_as_dict = extract_darks_flats(
        scan_filename,
        entry_name,
        method=median_or_mean,
        use_projections_for_flats=False,
    )
    reduced_infos = fd_as_dict["flat"]["meta"]

    scan = HDF5TomoScan(scan_filename, entry_name)
    current = fd_sample.flat_current
    if current is None:
        # Handle the case where fd_sample does not contain any flat frame:
        # take the first current found in the NXtomo.
        currents = scan.electric_current
        if currents is not None and len(currents) > 0:
            current = currents[0]  # pylint: disable=E1136

    if current is None:
        raise ValueError(
            f"Unable to find any machine electric current from {scan_filename}. Unable to compute reduced darks and flats"
        )

    # Fail with an explicit message instead of an opaque TypeError raised by
    # the arithmetic below when a reference is incomplete.
    for ref_filename, fd_ref in (
        (ref_start_filename, fd_start),
        (ref_end_filename, fd_end),
    ):
        if fd_ref.flat_image is None or fd_ref.dark_image is None:
            raise ValueError(
                f"Unable to find both a reduced flat and a reduced dark for the reference scan {ref_filename}"
            )
        if fd_ref.flat_current is None:
            raise ValueError(
                f"Unable to find any machine electric current for the reference scan {ref_filename}"
            )
    if fd_sample.dark_image is None:
        raise ValueError(f"Unable to find any reduced dark for {scan_filename}")

    # compute reduced flats and dark
    flat0 = (
        fd_start.flat_image - fd_start.dark_image
    ) * current / fd_start.flat_current + fd_sample.dark_image
    flat1 = (
        fd_end.flat_image - fd_end.dark_image
    ) * current / fd_end.flat_current + fd_sample.dark_image

    flat = (1 - mixing_factor) * flat0 + mixing_factor * flat1

    reduced_flats = {0: flat}

    # save reduced flats and dark
    reduced_infos.machine_electric_current = numpy.array([current])
    reduced_infos.count_time = reduced_infos.count_time[:1]
    # Read the value back to make sure the metadata holds the current used
    # for the normalisation above.
    if current != reduced_infos.machine_electric_current[0]:
        raise RuntimeError(
            " Coherence check failed. Total non sense: the code is broken."
        )

    scan.save_reduced_flats(reduced_flats, flats_infos=reduced_infos)

    reduced_darks = fd_as_dict["dark"]["images"]
    reduced_infos = fd_as_dict["dark"]["meta"]
    scan.save_reduced_darks(reduced_darks, darks_infos=reduced_infos)
