import numpy as np
import pytest

from probnum.diffeq import ode
from probnum.diffeq.odefiltsmooth import KalmanODESolution, probsolve_ivp
from probnum.randvars import Constant


@pytest.fixture
def ivp():
    """Lotka-Volterra IVP on the time span [0.0, 0.25].

    The initial state is a deterministic (``Constant``) random variable with
    the two-component value [20.0, 20.0].
    """
    timespan = [0.0, 0.25]
    initial_state = Constant(np.full(2, 20.0))
    return ode.lotkavolterra(timespan, initial_state)


@pytest.mark.parametrize("method", ["EK0", "EK1"])
@pytest.mark.parametrize(
    "algo_order",
    [1, 2, 5],
)
@pytest.mark.parametrize("dense_output", [True, False])
@pytest.mark.parametrize("step", [0.01, None])
@pytest.mark.parametrize("diffusion_model", ["constant", "dynamic"])
@pytest.mark.parametrize("tolerance", [0.1, np.array([0.09, 0.10])])
def test_adaptive_solver_successful(
    ivp, method, algo_order, dense_output, step, diffusion_model, tolerance
):
    """The solver terminates successfully for all sorts of parametrizations.

    Every parametrized option, including ``diffusion_model``, is forwarded to
    ``probsolve_ivp``. Each parameter combination therefore runs a distinct
    solver configuration. ``tolerance`` is used for both ``atol`` and ``rtol``
    and may be a scalar or a per-component array.
    """
    f = ivp.rhs
    df = ivp.jacobian
    t0, tmax = ivp.timespan
    y0 = ivp.initrv.mean

    sol = probsolve_ivp(
        f,
        t0,
        tmax,
        y0,
        df=df,
        adaptive=True,
        atol=tolerance,
        rtol=tolerance,
        algo_order=algo_order,
        method=method,
        dense_output=dense_output,
        step=step,
        # Must be forwarded explicitly; otherwise both parametrized values
        # would exercise the same (default) diffusion model.
        diffusion_model=diffusion_model,
    )
    # Successful return value as documented
    assert isinstance(sol, KalmanODESolution)

    # Adaptive steps are not evenly distributed: the smallest step should be
    # well below the largest one if step-size adaptation actually happened.
    step_diff = np.diff(sol.locations)
    step_ratio = np.amin(step_diff) / np.amax(step_diff)
    assert step_ratio < 0.5


def test_wrong_method_raises_error(ivp):
    """Passing a method name that is not supported raises a ValueError."""
    rhs, (t0, tmax), y0 = ivp.rhs, ivp.timespan, ivp.initrv.mean

    # "UK" is a method name that no longer exists, so it must be rejected.
    with pytest.raises(ValueError):
        probsolve_ivp(rhs, t0, tmax, y0, method="UK")


def test_no_step_or_tol_info_raises_error(ivp):
    """Adaptive solving with no step size and no tolerances raises a ValueError."""
    rhs, (t0, tmax), y0 = ivp.rhs, ivp.timespan, ivp.initrv.mean

    # Neither a step size nor atol/rtol is given, so the solver has nothing to
    # control the step selection with.
    with pytest.raises(ValueError):
        probsolve_ivp(
            rhs,
            t0,
            tmax,
            y0,
            adaptive=True,
            step=None,
            atol=None,
            rtol=None,
        )


def test_wrong_diffusion_raises_error(ivp):
    """An unknown diffusion model name raises a ValueError."""
    f = ivp.rhs
    t0, tmax = ivp.timespan
    y0 = ivp.initrv.mean

    # Only recognised diffusion models (e.g. "constant", "dynamic") are
    # accepted; an arbitrary string must be rejected.
    with pytest.raises(ValueError):
        probsolve_ivp(f, t0, tmax, y0, diffusion_model="something_wrong")
