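# module computing the gamma-gamma (pair-production) absorption: opacities of the AGN photon fields and the Extragalactic Background Light attenuation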
from pathlib import Path
import numpy as np
import astropy.units as u
from astropy.io import fits
from astropy.constants import c, G, h, m_e, M_sun, sigma_T
from scipy.interpolate import interp2d
from ..utils.math import axes_reshaper, log
from ..utils.geometry import cos_psi, x_re_shell, mu_star_shell, x_re_ring
from ..utils.conversion import nu_to_epsilon_prime
from ..targets import PointSourceBehindJet, SSDisk, SphericalShellBLR, RingDustTorus


__all__ = ["sigma", "tau_disk_finke_2016", "Absorption", "ebl_files_dict", "EBL"]

# root directory of the agnpy package: the parent of the directory holding this module
agnpy_dir = Path(__file__).parent.parent
# paths (as strings) of the tabulated EBL absorption models shipped in
# data/ebl_models, keyed by the model names accepted by `EBL`
ebl_files_dict = {
    "franceschini": f"{agnpy_dir}/data/ebl_models/ebl_franceschini08.fits.gz",
    "dominguez": f"{agnpy_dir}/data/ebl_models/ebl_dominguez11.fits.gz",
    "finke": f"{agnpy_dir}/data/ebl_models/ebl_finke10.fits.gz",
}


def sigma(s):
    """photon-photon pair production cross section, Eq. 17 of [Dermer2009]

    Parameters
    ----------
    s : float or array_like
        dimensionless kinematic variable, s = gamma_cm^2, i.e. the
        centre-of-momentum speed is beta_cm = sqrt(1 - 1 / s). Pair production
        requires s >= 1. Plain numbers are expected (a Quantity is read in its
        own unit, so it should be in unscaled dimensionless units).

    Returns
    -------
    :class:`~astropy.units.Quantity`
        cross section with the same shape as `s`, zero below threshold (s < 1)
    """
    # float cast: integer arrays cannot be raised to negative powers, and
    # scalars / lists must be turned into arrays to be masked
    s = np.asarray(s, dtype=float)
    below_threshold = s < 1
    # below threshold the cross section vanishes: evaluate the formula at s = 1
    # there (where it is exactly zero) to avoid sqrt of negative numbers and 1 / 0;
    # NaN inputs are not "below threshold" and therefore still propagate as NaN
    s_above = np.where(below_threshold, 1.0, s)
    beta_cm = np.sqrt(1 - np.power(s_above, -1))
    prefactor = 3 / 16 * (1 - np.power(beta_cm, 2))
    term1 = (3 - np.power(beta_cm, 4)) * log((1 + beta_cm) / (1 - beta_cm))
    term2 = -2 * beta_cm * (2 - np.power(beta_cm, 2))
    values = np.where(below_threshold, 0.0, prefactor * (term1 + term2))
    return values * sigma_T


def tau_disk_finke_2016(nu, blob, disk, r):
    """Eq. 63 in [Finke2016]_, **for testing purposes only**

    The blob is assumed to move parallel to the jet axis (mu_s = 1), so the
    cosine of the angle between the two photons reduces to mu (cos_psi = mu).

    Parameters
    ----------
    nu : :class:`~astropy.units.Quantity`
        observed frequencies, converted to energies with the blob redshift
    blob : :class:`~agnpy.emission_regions.Blob`
        emission region, only its redshift `z` is used
    disk : :class:`~agnpy.targets.SSDisk`
        disk providing the radial grid `R_tilde`, `epsilon`, `phi_disk` and the
        parameters `R_g`, `l_Edd`, `M_8`, `eta`
    r : :class:`~astropy.units.Quantity`
        distance of the blob from the black hole

    Returns
    -------
    opacity for each frequency, integrated over the disk radii and over the
    distances along the jet from r up to 1e5 r
    """
    r_tilde = (r / disk.R_g).to_value("")
    # distances along the jet axis (units of R_g), from r_tilde to 1e5 r_tilde
    l_tilde = r_tilde * np.logspace(0, 5, 50)
    epsilon_1 = nu_to_epsilon_prime(nu, blob.z)
    # radii on axis 0, distances on axis 1; epsilon_1 broadcasts on the last axis
    R_tilde_grid, l_tilde_grid, _ = axes_reshaper(disk.R_tilde, l_tilde, epsilon_1)
    epsilon_disk = disk.epsilon(R_tilde_grid)
    phi_disk = disk.phi_disk(R_tilde_grid)
    # (R / l)^2 enters both the cosine and the geometric factor
    radius_ratio_2 = R_tilde_grid ** 2 / l_tilde_grid ** 2
    mu = np.sqrt(1 / (1 + radius_ratio_2))
    s = epsilon_1 * epsilon_disk * (1 - mu) / 2
    integrand = (
        l_tilde_grid ** (-2)
        * R_tilde_grid ** (-5 / 4)
        * phi_disk
        * (1 + radius_ratio_2) ** (-3 / 2)
        * (sigma(s) / sigma_T).to_value("")
        * (1 - mu)
    )
    # integrate over the disk radii first (axis 0), then over the distances
    over_radii = np.trapz(integrand, disk.R_tilde, axis=0)
    over_distances = np.trapz(over_radii, l_tilde, axis=0)
    normalisation = (
        1e7 * disk.l_Edd ** (3 / 4) * disk.M_8 ** (1 / 4) * disk.eta ** (-3 / 4)
    )
    return normalisation * over_distances


class Absorption:
    """class to compute the absorption due to gamma-gamma pair production

    Parameters
    ----------
    blob : :class:`~agnpy.emission_regions.Blob`
        emission region and electron distribution hitting the photon target
    target : :class:`~agnpy.targets`
        class describing the target photon field
    r : :class:`~astropy.units.Quantity`
        distance of the blob from the Black Hole (i.e. from the target photons)
    integrator : callable
        function performing all the numerical integrations, called as
        ``integrator(y, x=x, axis=0)``; defaults to :func:`numpy.trapz`
    """

    def __init__(self, blob, target, r, integrator=np.trapz):
        self.blob = blob
        self.target = target
        self.r = r
        self.integrator = integrator
        # default integration grids in cosine, azimuth and distance
        self.set_mu()
        self.set_phi()
        self.set_l()

    def _integrate(self, y, x):
        """integrate `y` over the sampling points `x` along its first axis"""
        return self.integrator(y, x=x, axis=0)

    def set_mu(self, mu_size=100):
        """set the array of cosines, `mu_size` points uniformly spaced in [-1, 1]"""
        self.mu_size = mu_size
        self.mu = np.linspace(-1, 1, self.mu_size)

    def set_phi(self, phi_size=50):
        """set the array of azimuth angles, `phi_size` points uniformly spaced
        in [0, 2 pi]"""
        self.phi_size = phi_size
        self.phi = np.linspace(0, 2 * np.pi, self.phi_size)

    def set_l(self, l_size=50, max_l=None):
        """set an array of distances over which to integrate, log-spaced from
        the blob distance `r` up to `max_l`

        Parameters
        ----------
        l_size : int
            number of distances
        max_l : :class:`~astropy.units.Quantity`, optional
            largest distance of the integration, 100 kpc if not specified
        """
        if max_l is None:
            max_l = 100 * u.kpc
        self.l_size = l_size
        self.l = (
            np.logspace(
                np.log10(self.r.to_value("cm")), np.log10(max_l.to_value("cm")), l_size
            )
            * u.cm
        )

    def opacity_point_source(self, nu):
        """opacity generated by a point source behind the jet

        Parameters
        ----------
        nu : `~astropy.units.Quantity`
            array of frequencies, in Hz, to compute the opacity, **note** these
            are observed frequencies (observer frame).
        """
        epsilon_1 = nu_to_epsilon_prime(nu, self.blob.z)
        s = self.target.epsilon_0 * epsilon_1 * (1 - self.blob.mu_s) / 2
        integral = (1 - self.blob.mu_s) * sigma(s) / self.r
        prefactor = self.target.L_0 / (4 * np.pi * self.target.epsilon_0 * m_e * c ** 3)
        return (prefactor * integral).to_value("")

    def opacity_disk(self, nu):
        """opacity generated by a Shakura Sunyaev disk

        Parameters
        ----------
        nu : `~astropy.units.Quantity`
            array of frequencies, in Hz, to compute the absorption, **note**
            these are observed frequencies (observer frame).

        Returns
        -------
        array with one opacity value per frequency
        """
        epsilon_1 = np.atleast_1d(nu_to_epsilon_prime(nu, self.blob.z))
        # distances from the BH in units of R_g, from r up to 1e5 r
        l_tilde = np.logspace(0, 5, 50) * (self.r / self.target.R_g).to_value("")
        # each distance defines a different range of cosines subtended by the
        # disk, so we loop explicitly over the distances; the geometry at each
        # distance is computed once and the integrand is vectorised over the
        # photon energies along a trailing axis
        integrals_phi = []
        for _l_tilde in l_tilde:
            mu = self.target.evaluate_mu_from_r_tilde(
                self.target.R_in_tilde, self.target.R_out_tilde, _l_tilde
            )
            _mu, _phi = axes_reshaper(mu, self.phi)
            epsilon = self.target.epsilon_mu(_mu, _l_tilde)
            phi_disk = self.target.phi_disk_mu(_mu, _l_tilde)
            # angle between the photons
            _cos_psi = cos_psi(self.blob.mu_s, _mu, _phi)
            # trailing axis for the energies -> shape (mu, phi, epsilon_1)
            s = (
                epsilon_1
                * np.expand_dims(epsilon, -1)
                * np.expand_dims(1 - _cos_psi, -1)
                / 2
            )
            cross_section = sigma(s).to_value("cm2")
            geometry = (
                phi_disk
                / epsilon
                / _mu
                / np.power(_l_tilde, 3)
                / np.power(np.power(_mu, -2) - 1, -3 / 2)
                * (1 - _cos_psi)
            )
            integrand = np.expand_dims(geometry, -1) * cross_section
            # integrate over the solid angle
            integral_mu = self._integrate(integrand, mu)
            integrals_phi.append(self._integrate(integral_mu, self.phi))
        # integrate over the distance: rows are distances, columns energies
        tau_epsilon_1 = self._integrate(np.stack(integrals_phi), l_tilde)
        prefactor = (3 * self.target.L_disk) / (
            np.power(4 * np.pi, 2)
            * self.target.eta
            * m_e
            * np.power(c, 3)
            * self.target.R_g
        )
        return tau_epsilon_1 * prefactor.to_value("cm-2")

    def opacity_shell_blr(self, nu):
        """opacity generated by a spherical shell Broad Line Region

        Parameters
        ----------
        nu : `~astropy.units.Quantity`
            array of frequencies, in Hz, to compute the opacity, **note** these
            are observed frequencies (observer frame).
        """
        epsilon_1 = nu_to_epsilon_prime(nu, self.blob.z)
        # multidimensional integration: axes are (mu, phi, l, epsilon_1)
        _mu, _phi, _l, _epsilon_1 = axes_reshaper(self.mu, self.phi, self.l, epsilon_1)
        x = x_re_shell(_mu, self.target.R_line, _l)
        _mu_star = mu_star_shell(_mu, self.target.R_line, _l)
        _cos_psi = cos_psi(self.blob.mu_s, _mu_star, _phi)
        s = _epsilon_1 * self.target.epsilon_line * (1 - _cos_psi) / 2
        integrand = (1 - _cos_psi) / np.power(x, 2) * sigma(s)
        integral_mu = self._integrate(integrand, self.mu)
        integral_phi = self._integrate(integral_mu, self.phi)
        integral = self._integrate(integral_phi, self.l)
        prefactor = (self.target.xi_line * self.target.L_disk) / (
            (4 * np.pi) ** 2 * self.target.epsilon_line * m_e * c ** 3
        )
        return (prefactor * integral).to_value("")

    def opacity_ring_torus(self, nu):
        """opacity generated by a ring Dust Torus

        Parameters
        ----------
        nu : `~astropy.units.Quantity`
            array of frequencies, in Hz, to compute the opacity, **note** these
            are observed frequencies (observer frame).
        """
        epsilon_1 = nu_to_epsilon_prime(nu, self.blob.z)
        # multidimensional integration: axes are (phi, l, epsilon_1)
        _phi, _l, _epsilon_1 = axes_reshaper(self.phi, self.l, epsilon_1)
        x = x_re_ring(self.target.R_dt, _l)
        _mu = _l / x
        _cos_psi = cos_psi(self.blob.mu_s, _mu, _phi)
        s = _epsilon_1 * self.target.epsilon_dt * (1 - _cos_psi) / 2
        integrand = (1 - _cos_psi) / np.power(x, 2) * sigma(s)
        integral_phi = self._integrate(integrand, self.phi)
        integral = self._integrate(integral_phi, self.l)
        prefactor = (self.target.xi_dt * self.target.L_disk) / (
            8 * np.pi ** 2 * self.target.epsilon_dt * m_e * c ** 3
        )
        return (prefactor * integral).to_value("")

    def tau(self, nu):
        """optical depth

        .. math::
            \\tau_{\\gamma \\gamma}(\\nu)

        Parameters
        ----------
        nu : `~astropy.units.Quantity`
            array of frequencies, in Hz, to compute the opacity, **note** these are
            observed frequencies (observer frame).

        Raises
        ------
        TypeError
            if no opacity is implemented for the type of `target`
        """
        if isinstance(self.target, PointSourceBehindJet):
            return self.opacity_point_source(nu)
        if isinstance(self.target, SSDisk):
            return self.opacity_disk(nu)
        if isinstance(self.target, SphericalShellBLR):
            return self.opacity_shell_blr(nu)
        if isinstance(self.target, RingDustTorus):
            return self.opacity_ring_torus(nu)
        # previously this fell through and silently returned None
        raise TypeError(
            "no gamma-gamma opacity implemented for target of type "
            f"{type(self.target).__name__}"
        )


class EBL:
    """Class representing the Extragalactic Background Light (EBL) absorption.
    Tabulated values of absorption as a function of redshift and energy according
    to the models of [Franceschini2008]_, [Finke2010]_, [Dominguez2011]_ are available
    in `data/ebl_models`.
    They are interpolated by `agnpy` and can be later evaluated for a given redshift
    and range of frequencies.

    Parameters
    ----------
    model : ["franceschini", "dominguez", "finke"]
        choose the reference for the EBL model

    Raises
    ------
    ValueError
        if `model` is not one of the available references
    """

    def __init__(self, model="franceschini"):
        # the available models are exactly the keys of `ebl_files_dict`
        if model not in ebl_files_dict:
            raise ValueError("No EBL model for the reference you specified")
        self.model_file = ebl_files_dict[model]
        # load the absorption table and build its interpolation
        self.load_absorption_table()
        self.interpolate_absorption_table()

    def load_absorption_table(self):
        """load the reference values from the table file to be interpolated later

        Sets `energy_ref` (geometric mean of the energy bin edges, in eV),
        `z_ref` (sorted, unique redshifts) and `values_ref` (one row of
        tabulated absorption values per entry of `z_ref`).
        """
        # context manager: the FITS file is closed once the arrays are copied
        with fits.open(self.model_file) as f:
            energies = f["ENERGIES"].data
            spectra = f["SPECTRA"].data
            self.energy_ref = (
                np.sqrt(energies["ENERG_LO"] * energies["ENERG_HI"]) * u.eV
            )
            # Franceschini file has repeated redshift rows: keep the first
            # occurrence of each redshift and pick the SAME rows of the table.
            # (np.unique(..., axis=0) on the values would sort the rows by
            # their values, not by redshift, misaligning them with `z_ref`)
            self.z_ref, unique_idx = np.unique(spectra["PARAMVAL"], return_index=True)
            self.values_ref = np.array(spectra["INTPSPEC"][unique_idx])

    def interpolate_absorption_table(self, kind="linear"):
        """interpolate the reference values, choose the kind of interpolation

        The interpolation is performed in (log10(energy / eV), redshift).

        Parameters
        ----------
        kind : {"linear", "cubic", "quintic"}
            kind of spline interpolation, passed to `scipy.interpolate.interp2d`
        """
        # NOTE(review): interp2d is deprecated since SciPy 1.10 and removed in
        # SciPy 1.14; migrating (e.g. to RegularGridInterpolator) also requires
        # changing the module-level import
        log10_energy_ref = np.log10(self.energy_ref.to_value("eV"))
        self.interpolated_model = interp2d(
            log10_energy_ref, self.z_ref, self.values_ref, kind=kind
        )

    def absorption(self, z, nu):
        """absorption values interpolated at redshift(s) `z` and frequencies `nu`

        Parameters
        ----------
        z : float or array_like
            redshift(s) of the source
        nu : :class:`~astropy.units.Quantity`
            frequencies (any spectral quantity convertible to energy)

        Returns
        -------
        :class:`numpy.ndarray`
            absorption values with shape (len(z), len(nu)), or (len(nu),) for a
            single redshift; the values follow the order of the inputs
        """
        energy = nu.to_value("eV", equivalencies=u.spectral())
        log10_energy = np.atleast_1d(np.log10(energy))
        z = np.atleast_1d(z)
        # interp2d evaluates on sorted coordinates and returns the result in
        # sorted order: sort explicitly, then scatter back to the input order
        energy_order = np.argsort(log10_energy, kind="mergesort")
        z_order = np.argsort(z, kind="mergesort")
        sorted_values = np.atleast_2d(
            self.interpolated_model(
                log10_energy[energy_order], z[z_order], assume_sorted=True
            )
        )
        values = np.empty_like(sorted_values)
        values[np.ix_(z_order, energy_order)] = sorted_values
        # a single redshift gives a 1D array, as interp2d itself does
        return values[0] if z.size == 1 else values
