#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__copyright__ = """ This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Laboratory of Physical Chemistry, Reiher Group.
See LICENSE.txt for details.
"""

# Standard library and numerical (NumPy/SciPy) imports
import numpy as np
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union, Generator
from scipy.sparse.csgraph import connected_components
from scipy.spatial import distance_matrix

# Third party imports
import scine_database as db
import scine_utilities as utils

# Local application imports
from .lebedev_sphere import LebedevSphere
from .unit_circle import UnitCircle
from . import ReactiveComplexes


def assemble_reactive_complex(
    atoms1: utils.AtomCollection,
    atoms2: utils.AtomCollection,
    lhs_list: List[int],
    rhs_list: List[int],
    x_alignment_0: Optional[List[float]] = None,
    x_alignment_1: Optional[List[float]] = None,
    x_rotation: float = 0.0,
    x_spread: float = 2.0,
    displacement: float = 0.0,
):
    """
    Assembles a reactive complex from the parameters generated by the
    InterReactiveComplexes class.

    Parameters
    ----------
    atoms1, atoms2 :: utils.AtomCollection
        The atoms of both structures, that are to be combined in the reactive
        complex. ``atoms1`` refers to the LHS and ``atoms2`` to the RHS.
    lhs_list, rhs_list :: List[int]
        Indices of the reactive-site atoms, given with respect to the
        individual structures: ``lhs_list`` indexes ``atoms1`` and
        ``rhs_list`` indexes ``atoms2``.
    x_alignment_0 :: List[float], length=9, optional
        Row-major flattened 3x3 rotation matrix for the first structure
        (index 0) that aligns the reaction coordinate along the x-axis
        (pointing towards +x). The rotation assumes that the mean position
        (centroid) of all atoms in the reactive site (``lhs_list``) is
        shifted into the origin. If ``None`` or empty, no rotation is applied.
    x_alignment_1 :: List[float], length=9, optional
        Row-major flattened 3x3 rotation matrix for the second structure
        (index 1) that aligns the reaction coordinate along the x-axis
        (pointing towards -x). The rotation assumes that the mean position
        (centroid) of all atoms in the reactive site (``rhs_list``) is
        shifted into the origin. If ``None`` or empty, no rotation is applied.
    x_rotation :: float
        Rotation angle (in radians) around the x-axis that is applied to the
        second structure after ``x_alignment_0`` and ``x_alignment_1`` have
        been applied.
    x_spread :: float
        After all rotations, the first structure is shifted by ``-x_spread``
        and the second by ``+x_spread`` along the x-axis, i.e. their
        separation grows by ``2 * x_spread``.
    displacement :: float
        Adds a reproducible pseudo-random displacement (fixed seed 42) to all
        atoms. Each atom's displacement is drawn uniformly from a cube of
        half-edge ``displacement / sqrt(3)``, so its length never exceeds
        ``displacement``. The global NumPy random state is left untouched.

    Returns
    -------
    utils.AtomCollection
        The reactive complex structure (atoms of ``atoms1`` followed by those
        of ``atoms2``).
    List[int], List[int]
        The LHS list (unchanged) and the RHS list with its indices shifted by
        the number of atoms in ``atoms1`` so that both refer to the reactive
        complex structure.
    """

    def _alignment_matrix(alignment: Optional[List[float]]) -> np.ndarray:
        # A missing alignment (None or empty) is interpreted as "no rotation"
        if alignment is None or len(alignment) == 0:
            return np.identity(3)
        return np.array(alignment).reshape((3, 3))

    elements1 = atoms1.elements
    elements2 = atoms2.elements
    coordinates1 = atoms1.positions
    coordinates2 = atoms2.positions
    sites1 = lhs_list
    sites2 = rhs_list
    # RHS indices expressed with respect to the combined structure
    shifted_rhs_list = [idx + len(elements1) for idx in sites2]
    # Place the mean position of each reactive site into the origin
    coord1 = coordinates1 - np.mean(coordinates1[sites1], axis=0)
    coord2 = coordinates2 - np.mean(coordinates2[sites2], axis=0)
    # Rotate the reaction coordinates onto the x-axis (towards each other)
    r = _alignment_matrix(x_alignment_0)
    coord1 = (r.T.dot(coord1.T)).T
    r = _alignment_matrix(x_alignment_1)
    coord2 = (r.T.dot(coord2.T)).T
    # Rotate the second structure around the x-axis
    angle = x_rotation
    x_rot = np.array([[1.0, 0.0, 0.0], [0.0, np.cos(angle), -np.sin(angle)], [0.0, np.sin(angle), np.cos(angle)]])
    coord2 = x_rot.dot(coord2.T).T
    # Displace coordinates of the molecules along the x-axis
    coord2 += np.array([x_spread, 0.0, 0.0])
    coord1 -= np.array([x_spread, 0.0, 0.0])
    # Small seeded random displacement; a local generator yields the same
    # numbers as seeding the global RNG with 42 but does not alter global state
    rng = np.random.RandomState(42)
    coord1 += displacement * (rng.rand(*coord1.shape) - 0.5) * 2.0 / np.sqrt(3.0)
    coord2 += displacement * (rng.rand(*coord2.shape) - 0.5) * 2.0 / np.sqrt(3.0)
    start_atoms = utils.AtomCollection(elements1 + elements2, np.concatenate((coord1, coord2), axis=0))
    return start_atoms, lhs_list, shifted_rhs_list


class InterReactiveComplexes(ReactiveComplexes):
    """
    Class to generate reactive complexes from two structures.
    """

    class Options:
        """
        Settings that control how :class:`InterReactiveComplexes` builds
        reactive complexes. Only the attributes listed in ``__slots__`` exist.
        """

        __slots__ = [
            "number_rotamers",
            "number_rotamers_two_on_two",
            "multiple_attack_points",
        ]

        def __init__(self):
            #: int
            #:     Number of rotamers generated for reactive complexes in which at
            #:     least one reactive center is a single atom. (default: 2)
            self.number_rotamers: int = 2
            #: int
            #:     Number of rotamers generated for reactive complexes in which
            #:     both reactive centers consist of two atoms. (default: 1)
            self.number_rotamers_two_on_two: int = 1
            #: bool
            #:     If True, several attack points may be considered for each
            #:     reactive center of an intermolecular reactive pair; if False,
            #:     only a single one. (default: True)
            self.multiple_attack_points: bool = True

    def __init__(self):
        """Create a generator that uses a fresh set of default options."""
        super().__init__()
        # Assigned after the base-class initializer has run, so these options take precedence
        self.options = self.Options()
        # Private (name-mangled) cache dictionary
        self.__cache = dict()

    @staticmethod
    def _rotation_to_vector(to_rotate: np.ndarray, direction: np.ndarray):
        """
        Generates a proper rotation matrix (determinant +1) that rotates the
        row vector 'to_rotate' into the direction 'direction', i.e.
        ``to_rotate @ r`` (equivalently ``(r.T.dot(to_rotate.T)).T``) is
        parallel to 'direction'.

        Parameters
        ----------
        to_rotate : np.ndarray of shape (3,)
            Vector to be rotated.
        direction : np.ndarray of shape (3,)
            Vector to be rotated on.

        Returns
        -------
        np.ndarray of shape (3,3)
            Rotation matrix. For parallel vectors this is the identity. For
            antiparallel vectors it is a rotation by 180 degrees around an axis
            perpendicular to 'to_rotate'. A point inversion (-identity) would
            also flip 'to_rotate' but would mirror chiral structures, so it is
            never returned.
        """

        # Normalize input vectors
        to_rotate_n = to_rotate / np.linalg.norm(to_rotate)
        direction_n = direction / np.linalg.norm(direction)

        # Generate rotation matrix (Rodrigues' formula)
        v = np.cross(direction_n, to_rotate_n)
        c = np.dot(direction_n, to_rotate_n)  # Cosine of angle between vectors
        s = np.linalg.norm(v)  # Sine of angle between vectors
        ident = np.identity(3)

        if s != 0:
            # If to_rotate, direction not (anti-)parallel
            k = np.array([[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]])
            return ident + k + k.dot(k) * ((1 - c) / (s ** 2))
        if c > 0:
            # Parallel; the sign test is robust against c being 0.999... after rounding
            return ident
        # Antiparallel: rotate by 180 degrees around a unit axis u perpendicular
        # to to_rotate, R = 2 u u^T - I (symmetric, det +1, maps to_rotate onto -to_rotate).
        # Crossing with the coordinate axis least aligned with to_rotate gives a
        # well-conditioned perpendicular axis.
        helper_axis = ident[int(np.argmin(np.abs(to_rotate_n)))]
        u = np.cross(to_rotate_n, helper_axis)
        u = u / np.linalg.norm(u)
        return 2.0 * np.outer(u, u) - ident

    @staticmethod
    def _calculate_x_shift(
        coord1: np.ndarray, elem1: List[utils.ElementType], coord2: np.ndarray, elem2: List[utils.ElementType]
    ) -> float:
        """
        If two molecules are too close to each other, calculates a shift to be
        applied along the x-axis in order to separate the two molecules.
        The shift has to be applied to both molecules individually, where the sign is important.
        coord1 must be shifted in - x-direction ([-extra_shift,0,0]), coord2 in + x-direction ([extra_shift,0,0]).

        Two atoms are too close if their distance is smaller than the sum of
        their van-der-Waals radii. For each such pair, the shift that restores
        exactly this minimum distance is computed. The largest of these shifts
        is returned.

        Parameters
        ----------
        coord1, coord2 : np.ndarray of shape (n,3) and (m,3)
            Atom positions of both molecules. They are not modified.
        elem1, elem2 : List[utils.ElementType] of length n and m
            Element types of the atoms of both molecules.

        Returns
        -------
        float
            The additional shift (>= 0) required to separate the molecules along
            the x-axis. It is 0.0 if no atom pair is too close.
        """
        pos1 = np.asarray(coord1, dtype=float).reshape(-1, 3)
        pos2 = np.asarray(coord2, dtype=float).reshape(-1, 3)
        # Pair atoms with their elements; like zip, surplus entries are ignored
        radii1 = np.array([utils.ElementInfo.vdw_radius(e) for e, _ in zip(elem1, pos1)], dtype=float)
        radii2 = np.array([utils.ElementInfo.vdw_radius(e) for e, _ in zip(elem2, pos2)], dtype=float)
        pos1 = pos1[: len(radii1)]
        pos2 = pos2[: len(radii2)]
        if len(pos1) == 0 or len(pos2) == 0:
            return 0.0

        # diff[a, b] = pos2[b] - pos1[a] for every atom pair
        diff = pos2[np.newaxis, :, :] - pos1[:, np.newaxis, :]
        dist = np.linalg.norm(diff, axis=-1)
        min_dist = radii1[:, np.newaxis] + radii2[np.newaxis, :]
        too_close = dist < min_dist
        if not np.any(too_close):
            return 0.0

        # Shifting coord1 by -s and coord2 by +s increases the x-component of
        # every pair vector by 2s. Solve |(dx + 2s, dy, dz)| = min_dist for the positive root.
        dx = diff[..., 0][too_close]
        shifts = 0.5 * (-dx + np.sqrt(dx * dx + min_dist[too_close] ** 2 - dist[too_close] ** 2))
        return float(max(0.0, shifts.max()))

    @staticmethod
    def _prune_buried_points(
        indices: Union[Tuple[int], Tuple[int, int]],
        coords: np.ndarray,
        element_types: List[utils.ElementType],
        points: np.ndarray,
        vdw_scaling: float = 0.7,
    ) -> np.ndarray:
        """
        Removes candidate points that lie inside the (scaled) van-der-Waals
        sphere of a neighbouring atom.

        Only atoms whose index is not in 'indices' and which lie closer than
        10.0 to the centroid of the atoms in 'indices' are considered as
        neighbours. A point is discarded if its distance to any such neighbour
        is below that neighbour's vdW radius times 'vdw_scaling'.

        Parameters
        ----------
        indices : Union[Tuple[int], Tuple[int, int]]
            Indices of the atoms around whose centroid the points are centered.
        coords : np.ndarray of shape (n,3)
            Atom positions.
        element_types : List[utils.ElementType] of size n
            Element types of the atoms.
        points : np.ndarray of shape (x,3)
            Candidate points centered around the centroid of the atoms in
            'indices'.
        vdw_scaling : float
            Scaling factor for the vdW radii; a smaller factor keeps more points.

        Returns
        -------
        np.ndarray of shape (k,3)
            The points that survived, in their original order (float array,
            shape (0,3) if none survived).
        """
        centroid = np.mean(coords[list(indices)], axis=0)
        # (position, scaled vdW radius) of every neighbouring atom that can bury a point
        blockers = [
            (position, vdw_scaling * utils.ElementInfo.vdw_radius(element))
            for idx, (position, element) in enumerate(zip(coords, element_types))
            if idx not in indices and np.linalg.norm(position - centroid) < 10.0
        ]

        def _is_buried(candidate) -> bool:
            return any(np.linalg.norm(position - candidate) < radius for position, radius in blockers)

        survivors = [candidate for candidate in points if not _is_buried(candidate)]
        result = np.zeros((len(survivors), 3))
        for row, candidate in enumerate(survivors):
            result[row, :] = candidate[:3]
        return result

    @staticmethod
    def _prune_close_attack_points(
        points: np.ndarray, repulsion: np.ndarray, r: float, min_angle: float = 20.0
    ) -> np.ndarray:
        """
        Merges attack points that lie close together on a sphere and keeps only
        the lowest-repulsion point of every such cluster ("valley").

        Two points are linked if their Euclidean distance does not exceed the
        chord length belonging to 'min_angle' on a sphere of radius 'r'.
        Clusters are the connected components of the resulting graph.

        Parameters
        ----------
        points : np.ndarray
            The attack points, which should already be pruned by repulsion.
        repulsion : np.ndarray
            The repulsion value of every point.
        r : float
            Radius of the sphere on which the points are located.
        min_angle : float, optional
            Minimum angle in degrees between two points for them to be treated
            as separate, by default 20.

        Returns
        -------
        np.ndarray
            One point per cluster, ordered by cluster label.
        """
        # Chord length 2 r sin(theta / 2) subtended by min_angle (given in degrees)
        max_chord = 2 * r * np.sin(min_angle * utils.PI / 360.0)
        linked = (distance_matrix(points, points) <= max_chord).astype(int)
        n_clusters, cluster_of = connected_components(linked, directed=False, return_labels=True)

        # Index of the lowest-repulsion point in each cluster (others masked out)
        representatives = [
            np.argmin(np.ma.masked_where(cluster_of != cluster, repulsion))
            for cluster in range(n_clusters)
        ]

        kept = np.empty((len(representatives), 3))
        for row, idx in enumerate(representatives):
            kept[row] = points[idx]
        return kept

    def _prune_by_repulsion(
        self,
        indices: Union[Tuple[int], Tuple[int, int]],
        coords: np.ndarray,
        element_types: List,
        points: np.ndarray,
        nearest_neighbors,
        radius: float,
        min_angle_distance: float = 20.0,
    ) -> np.ndarray:
        """
        Prunes points from the given list of points centered around the centroid of the atoms with
        the indices 'indices'. All points (:math:`\\{i\\}`) will be assigned a repulsion
        value based on their distance to all other atoms :math:`N \\setminus \\{N_i\\}`
        with :math:`N_i` being the atom(s) point :math:`i` is centered on.

        .. math:: C_i = \\sum^{N \\setminus N_i}_{I} \\frac{1}{||\\vec{r_i} - \\vec{R_I}||^6}

        A point that coincides with an atom gets an infinite repulsion.
        If ``self.options.multiple_attack_points`` is ``True`` all finite, locally
        best points (i.e. with a repulsion lower than or equal to that of all of
        their nearest neighbors) survive. Otherwise only the single point with
        the minimal repulsion is returned.
        A cutoff of 15 a.u. around the centroid of the atoms in 'indices' is
        used to prescreen all atoms included in :math:`N`.

        If several surviving points have an angular distance of less than
        'min_angle_distance', only the one with the minimum repulsion is kept.

        Parameters
        ----------
        indices : Union[Tuple[int], Tuple[int, int]]
            The indices of the atoms around whose centroid the points are centered.
        coords : np.ndarray of shape (n,3)
            Atom positions.
        element_types : List[utils.ElementType] of size n
            Element types of the atoms (atoms beyond its length are ignored).
        points : np.ndarray of shape (x,3)
            A 2D array (matrix) of points centered around the centroid of the
            atoms in 'indices'.
        nearest_neighbors : List[List[int]] of length x
            The indices of the nearest neighbors of all points.
        radius : float
            The radius of the sphere or circle the points are located on.
        min_angle_distance : float
            The minimum angle in degrees required between two attack points such
            that both of them are kept, by default 20.

        Returns
        -------
        np.ndarray of shape (k,3)
            The surviving attack point(s). If multiple attack points are enabled
            and at least one local minimum survives, ``k >= 1`` points are
            returned. Otherwise the single point with minimal repulsion
            (``k == 1``) is returned. If 'points' is empty, an empty array of
            shape (0,0) is returned.
        """
        if len(points) == 0:
            return np.zeros((0, 0))

        # Identify the atoms (excluding the site atoms) close to the site centroid;
        # only these contribute to the repulsion
        close_atoms = []
        centroid = np.mean(coords[list(indices)], axis=0)
        for idx, (c, _) in enumerate(zip(coords, element_types)):
            if idx in indices:
                continue
            if np.linalg.norm(c - centroid) < 15.0:
                close_atoms.append(c)

        # Calculate the repulsion C_i of every point
        repulsion = []
        for p in points:
            r: float = 0.0
            for c in close_atoms:
                dist = np.linalg.norm(c - p)
                if dist <= 0.0:
                    r += float("inf")
                else:
                    d2 = dist * dist
                    r += float(1.0 / (d2 * d2 * d2))
            repulsion.append(r)

        if self.options.multiple_attack_points:
            # Keep every finite local minimum: no nearest neighbor is strictly better
            keepers = []
            for i, nn in enumerate(nearest_neighbors):
                if np.isinf(repulsion[i]):
                    continue
                for n in nn:
                    if repulsion[n] < repulsion[i]:
                        break
                else:
                    keepers.append(i)

            repulsion_pruned = np.zeros((len(keepers), 3))
            pruned_repulsion = []
            for i, k in enumerate(keepers):
                repulsion_pruned[i] = points[k]
                pruned_repulsion.append(repulsion[k])
            # Rule out valleys of many close attack points
            pruned = self._prune_close_attack_points(
                repulsion_pruned, np.asarray(pruned_repulsion), radius, min_angle=min_angle_distance
            )

            if pruned.size > 0:
                return pruned

        # Fall back to the single best point
        pruned = np.zeros((1, 3))
        min_idx = repulsion.index(min(repulsion))
        pruned[0] = points[min_idx]
        return pruned

    def _get_attack_points_per_atom(
        self, coords: np.ndarray, element_types: List, vdw_scaling: float = 0.7, indices: Union[List[int], None] = None
    ) -> Dict[Tuple[int], np.ndarray]:
        """
        Generates the viable attack points for each atom in the given molecule.
        If `self.options.multiple_attack_points` is `True` multiple attack
        points per atom can be generated, otherwise only one.

        The candidate points are the Lebedev sphere points scaled to the
        atom's van-der-Waals radius and centered on the atom. They are then
        pruned by repulsion and by burial within neighbouring atoms.

        Parameters
        ----------
        coords : np.ndarray of shape (n,3)
            Atom positions.
        element_types :: List[utils.ElementType] of size n
            Element types of elements.
        vdw_scaling : float
            A scaling factor for the vdW radii, a smaller factor generates more
            attack points as there will be less pruning of possible attack
            directions based on neighbouring atoms. (default: 0.7)
        indices : Union[List[int], None]
            A list of atom indices.
            If given, only generates attack points for the referenced atoms

        Returns
        -------
        Dict[Tuple[int], np.ndarray]
            Maps ``(i,)`` to an array of shape (k,3) with the attack points of
            atom ``i``. Atoms without any surviving attack point are omitted.
        """

        lebedev = LebedevSphere()
        initial_points = np.asarray(lebedev.points)
        nearest_neighbors = lebedev.nearest_neighbors
        # Set for O(1) membership tests while iterating over all atoms
        requested = None if indices is None else set(indices)

        per_atom: Dict[Tuple[int], np.ndarray] = {}
        for i, e in enumerate(element_types):
            if requested is not None and i not in requested:
                continue
            vdw_radius = utils.ElementInfo.vdw_radius(e)
            # Sphere points scaled to the vdW radius and centered on atom i
            # (the multiplication creates a new array, initial_points stays untouched)
            possible_directions = initial_points * vdw_radius + coords[i]
            possible_directions = self._prune_by_repulsion(
                (i,), coords, element_types, possible_directions, nearest_neighbors, vdw_radius
            )
            possible_directions = self._prune_buried_points(
                (i,), coords, element_types, possible_directions, vdw_scaling=vdw_scaling
            )
            if possible_directions.size != 0:
                per_atom[(i,)] = possible_directions
        return per_atom

    def _get_attack_points_per_atom_pair(
        self, coords: np.ndarray, element_types: List, valid_pairs: List[Tuple[int, ...]], vdw_scaling: float = 0.7
    ) -> Dict[Tuple[int, int], np.ndarray]:
        """
        Generates the viable attack points for each of the given atom pairs.
        If `self.options.multiple_attack_points` is `True` multiple attack
        points per atom pair can be generated, otherwise only one.

        For a pair (i, j), the candidate points lie on a circle of radius
        ``0.5 * (vdW(i) + vdW(j))``. The circle is centered at the midpoint of
        the two atoms and its normal is aligned with the i->j axis. The
        candidates are then pruned by repulsion and by burial within
        neighbouring atoms.

        Parameters
        ----------
        coords : np.ndarray of shape (n,3)
            Atom positions.
        element_types :: List[utils.ElementType] of size n
            Element types of elements.
        valid_pairs : List[Tuple[int, ...]]
            The atom pairs (i, j) to generate attack points for.
        vdw_scaling : float
            A scaling factor for the vdW radii, a smaller factor generates more
            attack points as there will be less pruning of possible attack
            directions based on neighbouring atoms.  (default: 0.7)

        Returns
        -------
        Dict[Tuple[int, int], np.ndarray]
            Maps each pair ``(i, j)`` with at least one surviving attack point
            to an array of shape (k,3) containing these points.

        Raises
        ------
        RuntimeError
            If the same pair appears more than once in 'valid_pairs' and
            yields attack points.
        """
        unit_circle = UnitCircle()
        circle_points = np.asarray(unit_circle.points)
        # Embed the planar circle points in 3D (z = 0); size derived from the data
        initial_points = np.append(circle_points, np.zeros((circle_points.shape[0], 1)), axis=1)
        nearest_neighbors = unit_circle.nearest_neighbors
        circle_normal = np.array([0.0, 0.0, 1.0])  # Initial points are in the xy-plane

        per_atom_pair: Dict[Tuple[int, int], np.ndarray] = {}

        for i, j in valid_pairs:
            e_i = element_types[i]
            e_j = element_types[j]
            circle_radius = 0.5 * (utils.ElementInfo.vdw_radius(e_i) + utils.ElementInfo.vdw_radius(e_j))
            # The multiplication creates a new array, initial_points stays untouched
            possible_directions = initial_points * circle_radius
            # Rotate circle s.t. its normal aligns with the interatom axis
            interatom = coords[j] - coords[i]
            r = self._rotation_to_vector(circle_normal, interatom)
            possible_directions = (r.T.dot(possible_directions.T)).T
            # Move exactly between atoms
            possible_directions += 0.5 * (coords[i] + coords[j])
            # Prune
            possible_directions = self._prune_by_repulsion(
                (i, j), coords, element_types, possible_directions, nearest_neighbors, circle_radius
            )
            possible_directions = self._prune_buried_points(
                (i, j), coords, element_types, possible_directions, vdw_scaling=vdw_scaling
            )
            # Since indices are stored, empty entries are not necessary
            if possible_directions.size != 0:
                if (i, j) in per_atom_pair:
                    raise RuntimeError("Requested attack points for the same atom pair twice.")
                per_atom_pair[(i, j)] = possible_directions

        return per_atom_pair

    def _set_up_rotamers(
        self,
        coordinates1: np.ndarray,
        elem1: List[utils.ElementType],
        sites1: List[int],
        attack_points1: np.ndarray,
        coordinates2: np.ndarray,
        elem2: List[utils.ElementType],
        sites2: List[int],
        attack_points2: np.ndarray,
    ) -> List[Tuple[np.ndarray, np.ndarray, float, float]]:
        """
        Returns the operations to align the given molecules such that
        'centroid(sites1)'--'attack_point1'--'attack_point2'--'centroid(sites2)'
        are aligned along the x-axis. This is done for every combination of a
        point in 'attack_points1' with a point in 'attack_points2'. Then
        rotamers are generated:

        - If sites1 and/or sites2 only contains one index, i.e., at least one
          reaction center is an atom, rotates the second molecule around the
          x-axis in equidistant steps to generate
          'self.options.number_rotamers' rotamers.
        - If sites1 and sites2 each contain two indices they are first aligned
          such that sites1[0] faces sites2[0] and sites1[1] faces sites2[1].
          The vector sites2[1]->sites2[0] is made parallel to
          sites1[1]->sites1[0], followed by a small symmetry-breaking rotation
          of 0.1 rad around the x-axis. If
          'self.options.number_rotamers_two_on_two' is larger than one, further
          rotamers are generated by rotating within plus/minus 90 degrees.
        - No rotamers (a single entry with angle 0.0) are generated if a
          reactant with fewer than three atoms has a single-atom site, or if
          such a reactant contains only one element and
          'self.options.number_rotamers' equals 2.

        The spread of every rotamer is increased along the x-axis if any of
        its atoms are too close to each other (see ``_calculate_x_shift``).

        Parameters
        ----------
        coordinates1, coordinates2 : np.ndarray of shape (n,3) and (m,3)
            Atom positions of both molecules
        elem1, elem2 : List[utils.ElementType] of size n and m
            Element types of elements of both molecules.
        sites1, sites2 : List[int] of size 1 or 2
            Indices defining the reactive sites: The reactive site is the mean
            position of the atoms with these indices.
        attack_points1, attack_points2 :: np.ndarray of shape (a,3) and (b,3)
            Points of attack for the two sites given (a for site1 and b for
            site2).

        Returns
        -------
        rotamer_operations: List[Tuple(alignment1, alignment2, angle, total_spread)]
            A list of Tuples, each Tuple containing the operations to set up one rotamer.

            alignment1, alignment2: np.array, np.array
                Flattened rotation matrices aligning the two sites along the
                x-axis (rotations assume that the mean position of each site is
                translated into the origin)
            angle: float
                Angle of rotation around the x-axis (radians)
            total_spread: float
                Spread to be applied along the x-axis between the two structures.
        """

        def x_rotation(angle: float) -> np.ndarray:
            # Rotation matrix around the x-axis by 'angle' (radians)
            return np.array(
                [[1.0, 0.0, 0.0], [0.0, np.cos(angle), -np.sin(angle)],
                 [0.0, np.sin(angle), np.cos(angle)]]
            )

        tuple_list: List[Tuple[np.ndarray, np.ndarray, float, float]] = []
        # Get average coordinate and vdW radii
        reactive_center1 = np.mean(coordinates1[sites1], axis=0)
        reactive_center2 = np.mean(coordinates2[sites2], axis=0)
        vdw_average: float = np.mean(  # type: ignore
            [utils.ElementInfo.vdw_radius(el) for el in list(
                np.array(elem1)[sites1]) + list(np.array(elem2)[sites2])]
        )
        # Both molecules with their reactive centers translated into the origin
        centered1 = coordinates1 - reactive_center1
        centered2 = coordinates2 - reactive_center2

        # ---- Quantities independent of the attack points (computed once) ----
        len_sites = [len(sites1), len(sites2)]  # The type of attack site on each reactant
        len_coords = np.array([len(centered1), len(centered2)])  # Number of atoms per reactant
        n_elements = [len(set(elem1)), len(set(elem2))]  # Number of distinct elements per reactant
        both_sites_diatomic = all(n_site == 2 for n_site in len_sites)
        x = np.array([1.0, 0.0, 0.0])
        # Small rotation around the x-axis to distort symmetry of two-on-two alignments
        r_breaksymm = x_rotation(0.1)

        # Check whether rotamers shall be generated
        # TODO Take symmetry into account
        generate_rotamers = True
        # Get all indices of molecules with less than 3 atoms
        small_mols = np.where(len_coords < 3)[0]
        if len(small_mols) > 0:
            if any(len_sites[i] < 2 for i in small_mols):
                # For all reactants with < 3 atoms the site has to be of length 2
                # Otherwise they are either monoatomic or the site is
                #  on the axis of a diatomic molecule
                generate_rotamers = False
            # TODO Check this
            elif any(n_elements[i] == 1 for i in small_mols) and self.options.number_rotamers == 2:
                # If any of the reactants is diatomic with twice the same element and
                # if the set number of rotamers is 2 there will be a 180 degrees rotation.
                # It is superfluous for homoatomic diatomics
                generate_rotamers = False

        rotations: Optional[List[Tuple[float, np.ndarray]]] = None
        if generate_rotamers:
            if both_sites_diatomic:
                if self.options.number_rotamers_two_on_two == 1:
                    angles = [0.0]
                else:
                    # If both sites are diatomic restrict rotamers to half circle where
                    #  sites are within +-90 degrees of each other
                    # Endpoint logic ensures that angle zero is always included and
                    # 90 removed instead in case of even number of rotamers requested
                    angles = list(np.linspace(
                        -np.pi / 2,
                        np.pi / 2.0,
                        num=self.options.number_rotamers_two_on_two,
                        endpoint=bool(self.options.number_rotamers_two_on_two % 2),
                    ))
            else:
                # Equidistant angles in [0, 2*pi)
                angles = list(np.linspace(0.0, 2 * np.pi, num=self.options.number_rotamers, endpoint=False))
            rotations = [(angle, x_rotation(angle)) for angle in angles]

        # ---- Loop over all attack point combinations ----
        for attack_point1 in attack_points1:
            for attack_point2 in attack_points2:
                # Generate the relative attack direction from the attack points
                direction1 = attack_point1 - reactive_center1
                direction2 = attack_point2 - reactive_center2
                # Face S1 in direction +x
                r = self._rotation_to_vector(direction1, x)
                alignment1 = r.flatten()
                coord1 = (r.T.dot(centered1.T)).T
                # Face S2 in direction -x
                r = self._rotation_to_vector(direction2, -1.0 * x)
                coord2 = (r.T.dot(centered2.T)).T
                # Parallelize if both centers are defined by two atoms
                if both_sites_diatomic:
                    interatom1 = coord1[sites1[0]] - coord1[sites1[1]]
                    interatom2 = coord2[sites2[0]] - coord2[sites2[1]]
                    # Rotate interatom2 to align with interatom1
                    r_parallel = self._rotation_to_vector(interatom2, interatom1)
                    coord2 = (r_parallel.T.dot(coord2.T)).T
                    # Combine two rotation operations
                    r = r.dot(r_parallel)
                    coord2 = (r_breaksymm.dot(coord2.T)).T
                    # Combine with previous rotation; r_breaksymm is applied directly
                    # to coord2, so its transpose enters the stored (transposed) form
                    r = r.dot(r_breaksymm.T)

                # Store all rotations of coord2 in alignment2
                alignment2 = r.flatten()

                # Displace coordinates of the molecules along the x-axis
                total_spread = vdw_average
                coord2 += np.array([vdw_average, 0.0, 0.0])
                coord1 -= np.array([vdw_average, 0.0, 0.0])
                # Check if none of the atoms are too close and shift eventually
                extra_shift = self._calculate_x_shift(coord1, elem1, coord2, elem2)
                coord2 += np.array([extra_shift, 0.0, 0.0])
                coord1 -= np.array([extra_shift, 0.0, 0.0])
                total_spread += extra_shift

                if rotations is None:
                    tuple_list.append((alignment1, alignment2, 0.0, total_spread))
                    continue
                for angle, x_rot in rotations:
                    coord4 = x_rot.dot(coord2.T).T
                    # Check again if none of the atoms are too close and shift eventually
                    rotamer_shift = self._calculate_x_shift(coord1, elem1, coord4, elem2)
                    tuple_list.append((alignment1, alignment2, angle, total_spread + rotamer_shift))
        return tuple_list

    def generate_reactive_complexes(
        self, structure1: db.Structure, structure2: db.Structure, reactive_inter_coords: List[List[Tuple[int, int]]]
    ) -> Generator[Tuple[List[Tuple[int, int]], np.ndarray, np.ndarray, float, float], None, None]:
        """
        Generates a set of reactive complexes for two given structures arising from
        the given intermolecular reactive pairs.

        Attack points are generated lazily for each reactive atom / atom pair and
        cached per structure ID. After each call, the cache only holds the entries
        of the two structures of that call.

        Parameters
        ----------
        structure1, structure2 :: scine_database.Structure (Scine::Database::Structure)
            The two structures for which a set of reactive complexes is to be
            generated. The structures have to be linked to a collection.
        reactive_inter_coords :: List[List[Tuple[int, int]]]
            A list of trial reaction coordinates. Each trial reaction coordinate
            is a list of one or two intermolecular reactive atom pairs. Each
            reactive pair tuple has to be ordered such that its first element
            belongs to structure1 and the second to structure2.
            The indices are expected to be on structure level, i.e. the first
            atom of structure2 has index 0 and not index n_atoms(structure1).

        Yields
        ------
        inter_coord :: List[Tuple[int, int]]
            The trial reaction coordinate (one or two atom pairs) taken from
            ``reactive_inter_coords`` that this reactive complex belongs to.
            First atom per pair belongs to structure1, second to structure2.
            Coordinates without attack points on either structure (e.g. buried
            sites) yield nothing.
        align1, align2 :: np.ndarray
            Rotation matrices aligning the two sites along the x-axis (rotations
            assume that the geometric mean of the reactive atoms of each
            structure is translated into the origin)
        xrot :: float
            Angle of rotation around the x-axis (in radians).
        spread :: float
            Spread to be applied along the x-axis between the two structures.

        Raises
        ------
        RuntimeError
            If a trial reaction coordinate contains more than two atom pairs.
        """
        # Get structure one data
        atoms1 = structure1.get_atoms()
        coordinates1 = atoms1.positions
        elements1 = atoms1.elements
        id1 = str(structure1.get_id())

        # Get structure two data
        atoms2 = structure2.get_atoms()
        coordinates2 = atoms2.positions
        elements2 = atoms2.elements
        # Must be structure2's own ID: keying it with structure1's ID makes both structures
        #  share one cache entry, so later calls would reuse the wrong attack points
        id2 = str(structure2.get_id())

        def load_cache(structure_id: str) -> Tuple[Dict, set, set]:
            # Cached attack points plus the atoms / atom pairs they were generated for;
            #  fresh empty containers if the structure is not cached
            if structure_id in self.__cache:
                entry = self.__cache[structure_id]
                return entry["points"], entry["atoms"], entry["pairs"]
            return {}, set(), set()

        attack_points1, attacked_atoms1, attacked_pairs1 = load_cache(id1)
        if id2 == id1:
            # Same structure on both sides: share the containers so that attack points generated
            #  for either side are kept in the single cache entry instead of one side being dropped
            attack_points2, attacked_atoms2, attacked_pairs2 = attack_points1, attacked_atoms1, attacked_pairs1
        else:
            attack_points2, attacked_atoms2, attacked_pairs2 = load_cache(id2)

        # Determine the reactive atoms / atom pairs per structure that have no attack points yet
        new_attacked_atoms1: set = set()
        new_attacked_pairs1: set = set()
        new_attacked_atoms2: set = set()
        new_attacked_pairs2: set = set()
        sides = (
            (0, attacked_atoms1, attacked_pairs1, new_attacked_atoms1, new_attacked_pairs1),
            (1, attacked_atoms2, attacked_pairs2, new_attacked_atoms2, new_attacked_pairs2),
        )
        for coord in reactive_inter_coords:
            if len(coord) > 2:
                raise RuntimeError("More than two interstructural coordinates are not supported")
            for side, known_atoms, known_pairs, new_atoms, new_pairs in sides:
                # Unique reactive atoms of this structure: either a single atom (one pair, or two
                #  pairs sharing the atom) or a distinct atom pair; with at most two pairs per
                #  coordinate there can never be more than two unique atoms
                site = tuple(sorted({pair[side] for pair in coord}))
                if len(site) == 1 and site[0] not in known_atoms:
                    new_atoms.add(site[0])
                elif len(site) == 2 and site not in known_pairs:
                    new_pairs.add(site)

        # Generate attack points around the new atoms and atom pairs
        attack_points1.update(self._get_attack_points_per_atom(
            coordinates1, elements1, indices=list(new_attacked_atoms1)))
        attack_points2.update(self._get_attack_points_per_atom(
            coordinates2, elements2, indices=list(new_attacked_atoms2)))
        attack_points1.update(self._get_attack_points_per_atom_pair(
            coordinates1, elements1, list(new_attacked_pairs1)))
        attack_points2.update(self._get_attack_points_per_atom_pair(
            coordinates2, elements2, list(new_attacked_pairs2)))

        # Update cache; only the two structures of this call are kept
        attacked_atoms1.update(new_attacked_atoms1)
        attacked_pairs1.update(new_attacked_pairs1)
        attacked_atoms2.update(new_attacked_atoms2)
        attacked_pairs2.update(new_attacked_pairs2)
        self.__cache = {
            id1: {
                "points": attack_points1,
                "atoms": attacked_atoms1,
                "pairs": attacked_pairs1,
            },
            id2: {
                "points": attack_points2,
                "atoms": attacked_atoms2,
                "pairs": attacked_pairs2,
            }
        }

        # Generate requested complexes
        for coord in reactive_inter_coords:
            # Unique reactive atoms per structure in first-occurrence order; no set because
            #  the order is relevant for the alignment in the rotamer generation
            sites1 = list(dict.fromkeys(pair[0] for pair in coord))
            sites2 = list(dict.fromkeys(pair[1] for pair in coord))

            # Attack points are looked up by the sorted site indices; the unsorted
            #  lists are still passed on for the alignment
            sorted_sites1 = tuple(sorted(sites1))
            sorted_sites2 = tuple(sorted(sites2))
            # If all attack points are buried for one of the reactants there is nothing to set up
            if sorted_sites1 not in attack_points1 or sorted_sites2 not in attack_points2:
                continue

            results = self._set_up_rotamers(
                coordinates1, elements1, sites1, attack_points1[sorted_sites1],
                coordinates2, elements2, sites2, attack_points2[sorted_sites2],
            )
            for align1, align2, xrot, spread in results:
                yield coord, align1, align2, xrot, spread
