use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use pyo3::types::{PyDict, PyTuple};
use pyo3::{py_run, wrap_pyfunction, AsPyRef, PyCell};

mod common;

// Minimal class holding one integer. Used by `mut_ref_arg` to check that a method can receive
// another instance of the same class as a mutable `PyRefMut` argument.
#[pyclass]
struct MutRefArg {
    // Returned by `get`; `set_other` overwrites it with 100 on the *argument* instance.
    n: i32,
}

#[pymethods]
impl MutRefArg {
    // Returns this instance's `n`.
    fn get(&self) -> PyResult<i32> {
        Ok(self.n)
    }
    // Sets `n` to 100 on `other`, not on `self`. The point of the test is that a method holding
    // only a shared borrow of `self` can still take a mutable borrow (`PyRefMut`) of a
    // different instance passed in from Python.
    fn set_other(&self, mut other: PyRefMut<MutRefArg>) -> PyResult<()> {
        other.n = 100;
        Ok(())
    }
}

/// A method taking `PyRefMut<MutRefArg>` must be able to mutate a *different* instance that
/// Python passes in as the argument.
#[test]
fn mut_ref_arg() {
    let guard = Python::acquire_gil();
    let py = guard.python();

    // Two independent instances, both starting at zero.
    let first = Py::new(py, MutRefArg { n: 0 }).unwrap();
    let second = Py::new(py, MutRefArg { n: 0 }).unwrap();

    // Expose both to Python under the names used by the snippet below.
    let locals = [("inst1", &first), ("inst2", &second)].into_py_dict(py);
    py.run("inst1.set_other(inst2)", None, Some(locals)).unwrap();

    // `set_other` writes into its argument, so the second instance must now hold 100.
    assert_eq!(second.as_ref(py).borrow().n, 100);
}

// Pyclass returned by `get_zero`, used to check that functions can return custom classes.
#[pyclass]
struct PyUsize {
    // `#[pyo3(get)]` generates a Python getter (no setter), so Python code can read `.value`.
    #[pyo3(get)]
    pub value: usize,
}

// Free function returning a pyclass value. It is callable directly from Rust, and from Python
// once wrapped with `wrap_pyfunction!`.
#[pyfunction]
fn get_zero() -> PyResult<PyUsize> {
    let zero = PyUsize { value: 0 };
    Ok(zero)
}

/// Checks that a custom class can be returned from an ordinary function, and that such a
/// function is usable both from Rust and from Python.
#[test]
fn return_custom_class() {
    let guard = Python::acquire_gil();
    let py = guard.python();

    // Rust side: `get_zero` is still a plain Rust function returning the struct.
    let from_rust = get_zero().unwrap();
    assert_eq!(from_rust.value, 0);

    // Python side: wrap it as a Python callable. The binding keeps the name `get_zero` because
    // the Python snippet refers to the variable by that name.
    let get_zero = wrap_pyfunction!(get_zero)(py);
    py_assert!(py, get_zero, "get_zero().value == 0");
}

/// A plain Rust tuple of primitives converts into an equal Python tuple.
#[test]
fn intopytuple_primitive() {
    let guard = Python::acquire_gil();
    let py = guard.python();

    // `tup` is the name the Python snippets below refer to.
    let tup = (1, 2, "foo");

    // Whole-tuple equality first, then each element individually.
    py_assert!(py, tup, "tup == (1, 2, 'foo')");
    py_assert!(py, tup, "tup[0] == 1");
    py_assert!(py, tup, "tup[1] == 2");
    py_assert!(py, tup, "tup[2] == 'foo'");
}

// Field-less pyclass used to build tuples of class instances in the tuple-conversion tests.
#[pyclass]
struct SimplePyClass {}

/// A Rust tuple of pyclass cells converts into a Python tuple of two distinct instances of the
/// same class.
#[test]
fn intopytuple_pyclass() {
    let guard = Python::acquire_gil();
    let py = guard.python();

    let first = PyCell::new(py, SimplePyClass {}).unwrap();
    let second = PyCell::new(py, SimplePyClass {}).unwrap();
    let tup = (first, second);

    py_assert!(py, tup, "type(tup[0]).__name__ == 'SimplePyClass'");
    py_assert!(py, tup, "type(tup[0]).__name__ == type(tup[1]).__name__");
    // No `__eq__` is defined for SimplePyClass in this file, so separate allocations are
    // expected to compare unequal.
    py_assert!(py, tup, "tup[0] != tup[1]");
}

/// `PyTuple::new` accepts an iterator over borrowed primitives.
#[test]
fn pytuple_primitive_iter() {
    let guard = Python::acquire_gil();
    let py = guard.python();

    let values: [u32; 3] = [1, 2, 3];
    let tup = PyTuple::new(py, values.iter());
    py_assert!(py, tup, "tup == (1, 2, 3)");
}

/// `PyTuple::new` accepts an iterator over borrowed pyclass cells and produces a tuple of two
/// distinct instances of the same class.
#[test]
fn pytuple_pyclass_iter() {
    let gil = Python::acquire_gil();
    let py = gil.python();

    let cells = [
        PyCell::new(py, SimplePyClass {}).unwrap(),
        PyCell::new(py, SimplePyClass {}).unwrap(),
    ];
    let tup = PyTuple::new(py, cells.iter());

    py_assert!(py, tup, "type(tup[0]).__name__ == 'SimplePyClass'");
    // Compare element 0 against element 1 (not against itself) so the check is meaningful.
    py_assert!(py, tup, "type(tup[0]).__name__ == type(tup[1]).__name__");
    py_assert!(py, tup, "tup[0] != tup[1]");
}

// Pickle-able class. `dict` gives instances a `__dict__`, so arbitrary attributes can be set and
// round-tripped. `module = "test_module"` sets the module name recorded for the class;
// `test_pickle` registers a module of that name in `sys.modules` so unpickling can find it.
#[pyclass(dict, module = "test_module")]
struct PickleSupport {}

#[pymethods]
impl PickleSupport {
    // Python-visible constructor. Unpickling calls the class with no arguments (see
    // `__reduce__`), so this must exist.
    #[new]
    fn new() -> PickleSupport {
        Self {}
    }

    // Pickle protocol hook: returns `(callable, args, state)`. Unpickling calls
    // `callable(*args)` (the class, with no arguments) and then restores `state`, which is
    // the instance `__dict__`, onto the new object.
    pub fn __reduce__<'py>(
        slf: &'py PyCell<Self>,
        py: Python<'py>,
    ) -> PyResult<(PyObject, &'py PyTuple, PyObject)> {
        let this = slf.to_object(py);
        let class = this.getattr(py, "__class__")?;
        let state = this.getattr(py, "__dict__")?;
        let no_args = PyTuple::empty(py);
        Ok((class, no_args, state))
    }
}

/// Inserts `module` into `sys.modules` under its own name, so Python code (including `pickle`)
/// can resolve it by import path.
fn add_module(py: Python, module: &PyModule) -> PyResult<()> {
    let sys = py.import("sys")?;
    // `sys.modules` is always present in a running interpreter; its absence would be a bug.
    let modules = sys
        .dict()
        .get_item("modules")
        .unwrap()
        .downcast::<PyDict>()?;
    let name = module.name()?;
    modules.set_item(name, module)
}

/// Instances of a `dict`-enabled pyclass survive a pickle round trip, including the attributes
/// stored in their `__dict__`.
#[test]
fn test_pickle() {
    let guard = Python::acquire_gil();
    let py = guard.python();

    // Register the class in a real module and publish that module in `sys.modules`, so that
    // `pickle` can locate `test_module.PickleSupport` when unpickling.
    let module = PyModule::new(py, "test_module").unwrap();
    module.add_class::<PickleSupport>().unwrap();
    add_module(py, module).unwrap();

    let inst = PyCell::new(py, PickleSupport {}).unwrap();
    py_run!(
        py,
        inst,
        r#"
        inst.a = 1
        assert inst.__dict__ == {'a': 1}

        import pickle
        inst2 = pickle.loads(pickle.dumps(inst))

        assert inst2.__dict__ == {'a': 1}
    "#
    );
}

/// Calling `iter()` on a non-iterable object must fail with an error rather than crash, and must
/// not leave a pending Python exception behind.
#[test]
fn incorrect_iter() {
    let guard = Python::acquire_gil();
    let py = guard.python();

    // A Python int is not iterable.
    let obj = 13isize.to_object(py);
    let any = obj.as_ref(py);

    // Should not segfault.
    assert!(any.iter().is_err());

    // A follow-up eval checks that the failed `iter()` left no exception state behind.
    let follow_up = py.eval("print('Exception state should not be set.')", None, None);
    assert!(follow_up.is_ok());
}
