from __future__ import annotations

import string
from typing import cast

import numpy as np
import pandas as pd
import pytest

from anndata._core.xarray import Dataset2D
from anndata.compat import XDataArray, XDataset, XVariable
from anndata.tests.helpers import gen_typed_df

# Skip this whole test module when xarray cannot be imported.
pytest.importorskip("xarray")


@pytest.fixture
def df():
    """Return the frame produced by ``gen_typed_df(10)``."""
    size = 10
    frame = gen_typed_df(size)
    return frame


@pytest.fixture
def dataset2d(df):
    """Wrap the ``df`` fixture in a ``Dataset2D``.

    The dataframe is first converted with ``XDataset.from_dataframe``.
    """
    xr_dataset = XDataset.from_dataframe(df)
    return Dataset2D(xr_dataset)


def test_shape(df, dataset2d):
    """``Dataset2D.shape`` equals the shape of the source dataframe."""
    expected = df.shape
    assert dataset2d.shape == expected


def test_columns(df, dataset2d):
    """Column labels agree with the dataframe's once both are sorted."""
    actual = dataset2d.columns.sort_values()
    expected = df.columns.sort_values()
    assert np.all(actual == expected)


@pytest.mark.parametrize("same_columns", [True, False], ids=["same", "different"])
def test_columns_setter(df, dataset2d: Dataset2D, *, same_columns: bool):
    """Assigning ``columns`` warns or raises and leaves the data unchanged.

    Re-assigning the current columns is expected to emit a ``UserWarning``;
    assigning unrelated labels is expected to raise ``ValueError``. In both
    cases the dataset must still equal the copy taken beforehand.
    """
    snapshot = dataset2d.copy()
    if same_columns:
        expectation = pytest.warns(
            UserWarning, match=r"Renaming or reordering columns on `Dataset2D`"
        )
    else:
        expectation = pytest.raises(ValueError, match=r"Trying to rename the keys")
    with expectation:
        # Both the getter and the setter run inside the context manager.
        if same_columns:
            dataset2d.columns = dataset2d.columns
        else:
            dataset2d.columns = pd.Index(["not", "a", "column"])
    assert dataset2d.equals(snapshot)


def test_to_memory(df, dataset2d):
    """``to_memory`` yields a frame matching ``df`` in values, index and columns."""
    in_memory = dataset2d.to_memory()

    values_match = df == in_memory
    assert np.all(values_match)

    index_match = df.index == in_memory.index
    assert np.all(index_match)

    # Column order is not compared, only the sorted labels.
    expected_cols = df.columns.sort_values()
    actual_cols = in_memory.columns.sort_values()
    assert np.all(expected_cols == actual_cols)


def test_getitem(df, dataset2d):
    """Indexing by a single column name returns that column's values."""
    first_col = df.columns[0]
    matches = dataset2d[first_col] == df[first_col]
    assert np.all(matches)


def test_getitem_empty(df, dataset2d):
    """Selecting an empty column list keeps the index and yields zero columns."""
    no_columns = dataset2d[[]]
    n_rows = df.shape[0]
    assert no_columns.shape == (n_rows, 0)
    index_match = no_columns.index == dataset2d.index
    assert np.all(index_match)


def test_backed_property(dataset2d):
    """``is_backed`` starts falsy and reflects each value assigned to it."""
    assert not dataset2d.is_backed

    # Toggle on, then off again, checking the reported truthiness each time.
    for flag in (True, False):
        dataset2d.is_backed = flag
        assert bool(dataset2d.is_backed) is flag


def test_true_index_dim_column_subset(dataset2d, df):
    """Column subsetting still works after ``true_index_dim`` is set to a column.

    The first key becomes ``true_index_dim``; the next two keys are selected
    and the in-memory result is compared with the same selection on ``df``
    re-indexed by that first column.
    """
    names = iter(dataset2d.keys())
    index_col = next(names)
    dataset2d.true_index_dim = index_col

    # Pick two columns that are not the one used as the true index.
    selected = [next(names) for _ in range(2)]
    result = dataset2d[selected].to_memory()

    # Mirror the manual `true_index_dim` assignment on the reference frame.
    df.index = df[index_col]
    df.index.name = None
    pd.testing.assert_frame_equal(result, df[selected])


def true_index_set_from_existing_col(dataset2D: Dataset2D) -> None:
    """Set ``true_index_dim`` to the first key and check both index attributes.

    After the assignment ``index_dim`` must still be ``"index"`` while
    ``true_index_dim`` reports the chosen key.

    NOTE(review): the name has no ``test_`` prefix and the parameter is
    ``dataset2D`` rather than the ``dataset2d`` fixture, so under pytest's
    default collection rules this is presumably never run - confirm whether
    that is intentional.
    """
    keys = iter(dataset2D.keys())
    first_key = cast("str", next(keys))
    dataset2D.true_index_dim = first_key
    assert dataset2D.index_dim == "index"
    assert dataset2D.true_index_dim == first_key


def true_index_set_from_unknown_col(dataset2D: Dataset2D) -> None:
    """Expect ``ValueError`` when ``true_index_dim`` is set to ``"test"``.

    NOTE(review): the name has no ``test_`` prefix and the parameter is
    ``dataset2D`` rather than the ``dataset2d`` fixture, so under pytest's
    default collection rules this is presumably never run - confirm whether
    that is intentional.
    """
    expected_error = pytest.raises(ValueError, match=r"Unknown variable `test`\.")
    with expected_error:
        dataset2D.true_index_dim = "test"


def true_index_set_unset(dataset2D: Dataset2D) -> None:
    """Assign ``None`` to ``true_index_dim`` and expect it to equal ``index_dim``.

    NOTE(review): the name has no ``test_`` prefix and the parameter is
    ``dataset2D`` rather than the ``dataset2d`` fixture, so under pytest's
    default collection rules this is presumably never run - confirm whether
    that is intentional.
    """
    dataset2D.true_index_dim = None
    current = dataset2D.true_index_dim
    assert current == dataset2D.index_dim


def test_index_setting_from_existing_column(dataset2d: Dataset2D) -> None:
    """Assigning a column as ``index`` renames the index dimension to that column.

    The previous index dimension name must no longer appear among the
    columns, as ``index_dim``, or as the name of the index.
    """
    keys = iter(dataset2d.keys())
    col = cast("str", next(keys))
    previous_dim = dataset2d.index_dim

    dataset2d.index = dataset2d[col]

    assert dataset2d.index_dim == col
    names_in_use = [
        *dataset2d.columns,
        dataset2d.index_dim,
        dataset2d.index.name,
    ]
    assert previous_dim not in names_in_use


def test_index_setting_from_different_true_index_overrides_index(
    dataset2d: Dataset2D,
) -> None:
    """Assigning the ``true_index_dim`` column as ``index`` replaces ``index_dim``.

    ``true_index_dim`` is first pointed at the first key; that column is then
    assigned as the index. Afterwards both dimension attributes must agree and
    the previous index dimension name must no longer be in use.
    """
    keys = iter(dataset2d.keys())
    col = cast("str", next(keys))
    dataset2d.true_index_dim = col
    previous_dim = dataset2d.index_dim

    dataset2d.index = dataset2d[dataset2d.true_index_dim]

    assert dataset2d.index_dim == dataset2d.true_index_dim
    names_in_use = [
        *dataset2d.columns,
        dataset2d.index_dim,
        dataset2d.index.name,
    ]
    assert previous_dim not in names_in_use


def test_set_index_from_outside_dataset(dataset2d: Dataset2D) -> None:
    """Assigning an unrelated ``pd.Index`` replaces the index and its dimension.

    Random, unique 10-character labels are generated, ``true_index_dim`` is
    pointed at the first key, and then the new index is assigned. Afterwards
    the only coordinate left on the underlying dataset must be the new index.
    """
    alphabet = np.asarray(
        [*string.ascii_letters, *string.digits, *string.punctuation]
    )
    n_labels = dataset2d.shape[0]
    labels = set()
    # A set drops duplicates, so keep drawing until enough unique labels exist.
    while len(labels) < n_labels:
        drawn = np.random.choice(alphabet, size=10)
        labels.add("".join(drawn))
    new_idx = pd.Index(labels, name="test_index")

    first_key = next(iter(dataset2d.keys()))
    dataset2d.true_index_dim = first_key

    dataset2d.index = new_idx
    assert np.all(dataset2d.index == new_idx)
    assert dataset2d.index_dim == new_idx.name
    coord_names = list(dataset2d.ds.coords.keys())
    assert coord_names == [new_idx.name]


@pytest.fixture
def dataset_2d_one_column():
    """Build a ``Dataset2D`` with one categorical variable ``foo``.

    ``foo`` lies along ``obs_names``, whose coordinate values are ``[1, 2, 3]``.
    """
    foo = pd.array(["a", "b", "c"], dtype="category")
    xr_dataset = XDataset(
        {"foo": ("obs_names", foo)},
        coords={"obs_names": [1, 2, 3]},
    )
    return Dataset2D(xr_dataset)


def test_dataset_2d_set_dataarray(dataset_2d_one_column):
    """A matching ``XDataArray`` can be assigned as a new column."""
    new_column = XDataArray(
        np.arange(3),
        coords={"obs_names": [1, 2, 3]},
        dims="obs_names",
        name="bar",
    )
    dataset_2d_one_column["bar"] = new_column
    assert dataset_2d_one_column["bar"].dims == ("obs_names",)
    assert dataset_2d_one_column["bar"].equals(new_column)


def test_dataset_2d_set_dataset(dataset_2d_one_column):
    """A matching ``XDataset`` can be assigned to a list of column keys."""
    columns = ["foo", "bar"]
    foo_values = np.arange(3)
    bar_values = np.arange(3) + 3
    replacement = XDataset(
        data_vars={
            "foo": ("obs_names", foo_values),
            "bar": ("obs_names", bar_values),
        },
        coords={"obs_names": [1, 2, 3]},
    )
    dataset_2d_one_column[columns] = replacement
    dim_names = tuple(dataset_2d_one_column[columns].ds.sizes.keys())
    assert dim_names == ("obs_names",)
    assert dataset_2d_one_column[columns].equals(replacement)


@pytest.mark.parametrize(
    "setter",
    [
        pd.array(["e", "f", "g"], dtype="category"),
        ("obs_names", pd.array(["e", "f", "g"], dtype="category")),
    ],
    ids=["array", "tuple_with_array"],
)
def test_dataset_2d_set_extension_array(dataset_2d_one_column, setter):
    """Assigning an extension array stores that very array under ``obs_names``.

    ``setter`` is either the bare categorical array or a ``(dim, array)``
    tuple. In both cases the stored column must have the single dimension
    ``obs_names`` and its ``data`` must be the array object that was passed.
    """
    dataset_2d_one_column["bar"] = setter
    assert dataset_2d_one_column["bar"].dims == ("obs_names",)
    # Pick the expected object first: a conditional expression binds looser
    # than `is`, so writing `x is a if cond else b` inside one assert would
    # only test the truthiness of `b` in the bare-array case.
    expected = setter[1] if isinstance(setter, tuple) else setter
    assert dataset_2d_one_column["bar"].data is expected


# Each case pairs a malformed value with the error-message pattern expected
# when it is assigned to the key "bar" of the one-column fixture.
@pytest.mark.parametrize(
    ("da", "pattern"),
    [
        # Dataset whose only coordinate is named `foo`.
        pytest.param(
            XDataset(
                data_vars={"bar": ("obs_names", np.arange(3))},
                coords={"foo": ("obs_names", np.arange(3))},
            ),
            r"Dataset should have coordinate obs_names",
            id="coord_name_dataset",
        ),
        # DataArray whose only coordinate is named `foo`.
        pytest.param(
            XDataArray(
                np.arange(3),
                coords={"foo": ("obs_names", np.arange(3))},
                dims="obs_names",
                name="bar",
            ),
            r"DataArray should have coordinate obs_names",
            id="coord_name",
        ),
        # DataArray named `not_bar` while being assigned to key "bar".
        pytest.param(
            XDataArray(
                np.arange(3),
                coords={"obs_names": np.arange(3)},
                dims=("obs_names",),
                name="not_bar",
            ),
            r"DataArray should have name bar, found not_bar",
            id="dataarray_name",
        ),
        # Dataset with a two-dimensional variable.
        pytest.param(
            XDataset(
                data_vars={
                    "foo": (["obs_names", "not_obs_names"], np.arange(9).reshape(3, 3))
                },
                coords={"obs_names": np.arange(3), "not_obs_names": np.arange(3)},
            ),
            r"Dataset should have only one dimension",
            id="multiple_dims_dataset",
        ),
        # Two-dimensional DataArray.
        pytest.param(
            XDataArray(
                np.arange(9).reshape(3, 3),
                coords={"obs_names": np.arange(3), "not_obs_names": np.arange(3)},
                dims=("obs_names", "not_obs_names"),
            ),
            r"DataArray should have only one dimension",
            id="multiple_dims_dataarray",
        ),
        # Two-dimensional Variable.
        pytest.param(
            XVariable(
                data=np.arange(9).reshape(3, 3),
                dims=("obs_names", "not_obs_names"),
            ),
            r"Variable should have only one dimension",
            id="multiple_dims_variable",
        ),
        # Dataset whose dimension is `other`; `obs_names` is only a coordinate name.
        pytest.param(
            XDataset(
                data_vars={"foo": ("other", np.arange(3))},
                coords={"obs_names": ("other", np.arange(3))},
            ),
            r"Dataset should have dimension obs_names",
            id="name_conflict_dataset",
        ),
        # Variable along `not_obs_names`.
        pytest.param(
            XVariable(
                data=np.arange(3),
                dims="not_obs_names",
            ),
            r"Variable should have dimension obs_names, found not_obs_names",
            id="name_conflict_variable",
        ),
        # DataArray along `not_obs_names`.
        pytest.param(
            XDataArray(
                np.arange(3),
                coords=[np.arange(3)],
                dims="not_obs_names",
            ),
            r"DataArray should have dimension obs_names, found not_obs_names",
            id="name_conflict_dataarray",
        ),
        # `(dim, values)` tuple naming `not_obs_names` as the dimension.
        pytest.param(
            ("not_obs_names", [1, 2, 3]),
            r"Setting value tuple should have first entry",
            id="tuple_bad_dim",
        ),
        # Dimension given as a one-element tuple holding `not_obs_names`.
        pytest.param(
            (("not_obs_names",), [1, 2, 3]),
            r"Dimension tuple should have only",
            id="nested_tuple_bad_dim",
        ),
        # Dimension given as a two-element tuple.
        pytest.param(
            (("obs_names", "bar"), [1, 2, 3]),
            r"Dimension tuple is too long",
            id="nested_tuple_too_long",
        ),
    ],
)
def test_dataset_2d_set_with_bad_obj(da, pattern, dataset_2d_one_column):
    """Assigning ``da`` to ``"bar"`` must raise ``ValueError`` matching ``pattern``."""
    with pytest.raises(ValueError, match=pattern):
        dataset_2d_one_column["bar"] = da


@pytest.mark.parametrize(
    "data",
    [
        np.arange(3),
        XDataArray(np.arange(3), dims="obs_names", name="obs_names"),
    ],
)
def test_dataset_2d_set_index(data, dataset_2d_one_column):
    """Item assignment to the index dimension key must raise ``KeyError``."""
    expected_error = pytest.raises(
        KeyError, match="Cannot set the index dimension obs_names"
    )
    with expected_error:
        dataset_2d_one_column["obs_names"] = data


# Each case pairs an invalid constructor argument with the expected error
# message pattern and exception type.
@pytest.mark.parametrize(
    ("ds", "pattern", "error"),
    [
        # Variable along `obs_names`, coordinate along `not_obs_names`.
        pytest.param(
            XDataset(
                {"foo": ("obs_names", pd.array(["a", "b", "c"], dtype="category"))},
                coords={"obs_names": ("not_obs_names", [1, 2, 3])},
            ),
            r"Dataset should have exactly one dimension",
            ValueError,
            id="more_than_one_dimension",
        ),
        # Two coordinates along the same dimension.
        pytest.param(
            XDataset(
                {"foo": ("obs_names", pd.array(["a", "b", "c"], dtype="category"))},
                coords={
                    "obs_names": ("obs_names", [1, 2, 3]),
                    "not_obs_names": ("obs_names", [1, 2, 3]),
                },
            ),
            r"Dataset should have exactly one coordinate",
            ValueError,
            id="more_than_one_coord",
        ),
        # Coordinate `obs_names` lies along a dimension named `not_obs_names`.
        pytest.param(
            XDataset(
                {"foo": ("not_obs_names", pd.array(["a", "b", "c"], dtype="category"))},
                coords={
                    "obs_names": ("not_obs_names", [1, 2, 3]),
                },
            ),
            r"does not match coordinate",
            ValueError,
            id="coord_dim_mismatch",
        ),
        # Both the variable and the coordinate are two-dimensional.
        pytest.param(
            XDataset(
                {"foo": (("obs", "obs1"), np.arange(9).reshape(3, 3))},
                coords={
                    "obs_names": (("obs", "obs1"), np.arange(9).reshape(3, 3)),
                },
            ),
            r"Dataset should have exactly one",
            ValueError,
            id="multi_dim_coord",
        ),
        # A plain dict instead of an xarray dataset.
        pytest.param(
            dict(foo="bar"),
            r"Expected an xarray Dataset",
            TypeError,
            id="non_ds_init",
        ),
    ],
)
def test_init_errors(ds, pattern, error):
    """Constructing ``Dataset2D(ds)`` must raise ``error`` matching ``pattern``."""
    with pytest.raises(error, match=pattern):
        Dataset2D(ds)
