"""Some utility functions used acrross the repository."""
import threading
from typing import Any, Tuple

import ipywidgets as ipw
import more_itertools as mit
import numpy as np
import traitlets
from aiida.plugins import DataFactory
from ase import Atoms
from ase.io import read

# Data classes obtained from AiiDA's DataFactory via the entry-point strings below.
# get_formula() in this module uses them for its isinstance() dispatch.
CifData = DataFactory("core.cif")  # pylint: disable=invalid-name
StructureData = DataFactory("core.structure")  # pylint: disable=invalid-name
TrajectoryData = DataFactory("core.array.trajectory")  # pylint: disable=invalid-name


def valid_arguments(arguments, valid_args):
    """Return the entries of ``arguments`` whose keys appear in ``valid_args``.

    Tuple and list values are collapsed into a single newline-separated
    string; any other value is passed through untouched. Keys that are not
    in ``valid_args`` are silently dropped.
    """

    def _normalise(item):
        # Sequences of strings become one multi-line string.
        return "\n".join(item) if isinstance(item, (tuple, list)) else item

    return {
        name: _normalise(item)
        for name, item in arguments.items()
        if name in valid_args
    }


def predefine_settings(obj, **kwargs):
    """Assign each keyword argument to the same-named attribute of ``obj``.

    Attributes are set one by one in the order given. As soon as a name is
    met that ``obj`` does not already have, an ``AttributeError`` is raised;
    attributes assigned before that point stay assigned.
    """
    for name in kwargs:
        if not hasattr(obj, name):
            raise AttributeError(f"'{obj}' object has no attribute '{name}'")
        setattr(obj, name, kwargs[name])


def get_ase_from_file(fname, format=None):  # pylint: disable=redefined-builtin
    """Read every frame of ``fname`` with ``ase.io.read`` and return the result.

    Raises ``ValueError`` when the reader returns nothing.
    """
    # CIF files are read with store_tags=True, see
    # https://wiki.fysik.dtu.dk/ase/ase/io/formatoptions.html#cif
    extra_options = {"store_tags": True} if format == "cif" else {}
    frames = read(fname, format=format, index=":", **extra_options)
    if frames:
        return frames
    raise ValueError(f"Could not read any information from the file {fname}")


def find_ranges(iterable):
    """Yield runs of consecutive numbers found in ``iterable``.

    A run of length one is yielded as the bare number; a longer run is
    yielded as a ``(first, last)`` tuple.
    """
    # Each group is materialised right away, before the next one is requested.
    for run in map(list, mit.consecutive_groups(iterable)):
        first, last = run[0], run[-1]
        yield first if len(run) == 1 else (first, last)


def list_to_string_range(lst, shift=1):
    """Convert a list such as [0, 2, 3, 4] into a string such as '1 3..5'.

    ``shift`` is added to every number, e.g. when a user interface counts
    from 1 while the list counts from 0.
    """
    pieces = []
    for item in find_ranges(sorted(lst)):
        if isinstance(item, tuple):
            low, high = item
            pieces.append(f"{low + shift}..{high + shift}")
        else:
            pieces.append(str(item + shift))
    return " ".join(pieces)


def string_range_to_list(strng, shift=-1):
    """Convert a string such as '1 3..5' into a list such as [0, 2, 3, 4].

    The string is split on whitespace. Every token must be either a plain
    non-negative integer or a range written as ``start..end`` (both ends
    inclusive).

    ``shift`` is added to every number, e.g. when a user interface counts
    from 1 while the list counts from 0.

    Returns a ``(numbers, success)`` tuple. On success ``numbers`` holds the
    standalone numbers first (in input order) followed by the expanded
    ranges (in input order). On any malformed token the result is
    ``([], False)``.
    """
    singles = []
    ranges = []
    for token in strng.split():
        if ".." in token:
            ranges.append(token)
        elif token.isdecimal():
            # isdecimal() rather than isdigit(): isdigit() also accepts
            # characters such as superscript digits that int() rejects.
            singles.append(int(token) + shift)
        else:
            return [], False

    for rng in ranges:
        try:
            # Fails on tokens like '1..2..3' (unpacking) or '..5' (int("")).
            start, end = rng.split("..")
            first, last = int(start), int(end)
        except ValueError:
            return [], False
        singles.extend(i + shift for i in range(first, last + 1))
    return singles, True


def get_formula(data_node):
    """Return the chemical formula of an AiiDA Data node.

    Handles TrajectoryData, StructureData and CifData nodes (checked in
    that order); any other type raises ``ValueError``.
    """
    if isinstance(data_node, TrajectoryData):
        # TrajectoryData can only hold structures with the same chemical
        # formula, so looking at the first step is enough.
        first_step = data_node.get_stepids()[0]
        structure = data_node.get_step_structure(first_step)
        return structure.get_formula()

    if isinstance(data_node, StructureData):
        return data_node.get_formula()

    if isinstance(data_node, CifData):
        return data_node.get_ase().get_chemical_formula()

    raise ValueError(f"Cannot get formula from node {type(data_node)}")


class PinholeCamera:
    """Camera described by a 4x4 matrix, used to map screen vectors into 3D."""

    def __init__(self, matrix):
        """Store ``matrix`` (16 values or a 4x4 array) reshaped to 4x4 and transposed."""
        self.matrix = np.reshape(matrix, (4, 4)).transpose()

    def screen_to_vector(self, move_vector):
        """Convert a vector from screen coordinates to a normalized 3D vector.

        ``move_vector`` is read but never modified; any sequence accepted by
        ``np.array`` (list, tuple, array) may be passed. The result holds the
        first three components of the transformed, unit-length vector.
        """
        # Work on a copy so that the caller's vector is left untouched.
        screen = np.array(move_vector)
        screen[0] = -screen[0]  # the x axis seem to be reverted in nglview.
        # Append a zero fourth component before applying the 4x4 inverse.
        res = np.append(screen, [0])
        res = self.inverse_matrix.dot(res)
        res /= np.linalg.norm(res)
        return res[0:3]

    @property
    def inverse_matrix(self):
        """Inverse of the stored camera matrix, recomputed on every access."""
        return np.linalg.inv(self.matrix)


class _StatusWidgetMixin(traitlets.HasTraits):
    """Show temporary messages for example for status updates.
    This is a mixin class that is meant to be part of an inheritance
    tree of an actual widget with a 'value' traitlet that is used
    to convey a status message. See the non-private classes below
    for examples.
    """

    message = traitlets.Unicode(default_value="", allow_none=True)
    # Separator placed between stacked messages; subclasses may override it.
    new_line = "\n"

    def __init__(self, clear_after=3, *args, **kwargs):
        """Set up the message stack.

        ``clear_after`` is the default interval, passed to ``threading.Timer``,
        after which a message is removed again. Remaining arguments are
        forwarded to the next class in the MRO.
        """
        self._clear_timer = None
        self._clear_after = clear_after
        self._message_stack = []
        super().__init__(*args, **kwargs)

    def _clear_value(self):
        """Drop the oldest stacked message and refresh the widget .value.

        When the stack is already empty, .value is set to an empty string.
        """
        if self._message_stack:
            self._message_stack.pop(0)
            self.value = self.new_line.join(self._message_stack)
        else:
            self.value = ""

    def show_temporary_message(self, value, clear_after=None):
        """Show a temporary message and clear it after the given interval.

        ``value`` is appended to the currently shown messages; an empty or
        ``None`` value is ignored. ``clear_after`` overrides the interval given
        at construction time for this message only; a falsy value (``None``,
        0) falls back to that default.
        """
        if not value:
            return

        interval = clear_after or self._clear_after
        self._message_stack.append(value)
        self.value = self.new_line.join(self._message_stack)

        # Start new timer that will clear the value after the resolved interval
        # (the per-call override when given, otherwise the instance default).
        self._clear_timer = threading.Timer(interval, self._clear_value)
        self._clear_timer.start()
        # Reset the trait so that setting the same text again counts as a change.
        self.message = None


class StatusHTML(_StatusWidgetMixin, ipw.HTML):
    """HTML widget that displays short-lived status messages."""

    # Stacked messages are separated by an HTML line break.
    new_line = "<br>"

    # Ideally this observer would live in _StatusWidgetMixin, but that does
    # not work for an unknown reason.
    @traitlets.observe("message")
    def _observe_message(self, change):
        """Display the newly assigned ``message`` as a temporary message."""
        new_message = change["new"]
        self.show_temporary_message(new_message)


def ase2spglib(ase_structure: Atoms) -> Tuple[Any, Any, Any]:
    """
    Build an spglib cell tuple ``(lattice, positions, numbers)`` from an ase
    Atoms instance, following the format defined at
    https://spglib.github.io/spglib/python-spglib.html#crystal-structure-cell
    """
    return (
        ase_structure.get_cell(),  # lattice
        ase_structure.get_scaled_positions(),  # positions
        ase_structure.get_atomic_numbers(),  # numbers
    )
