use super::{
    deep_map, visit_circuit_fallable, Circuit, CircuitConstructionError, CircuitNode, CircuitRc,
};
use crate::circuit::deep_map_preorder;
use crate::circuit::deep_map_unwrap_preorder;
use crate::{circuit::visit_circuit, pycall};
use pyo3::{pyfunction, PyObject, Python};
use std::collections::{HashMap, HashSet};
/// Collects every node visited by `visit_circuit` (starting at `circuit`) for
/// which `filter` returns `true`.
///
/// Matches are stored as `CircuitRc`s in a `HashSet`, so nodes that compare
/// equal are reported once.
pub fn filter_nodes(circuit: CircuitRc, filter: &dyn Fn(&Circuit) -> bool) -> HashSet<CircuitRc> {
    let mut matches = HashSet::new();
    visit_circuit(&circuit, &mut |node| {
        if !filter(node) {
            return;
        }
        matches.insert(node.clone().rc());
    });
    matches
}

/// Python-facing variant of `filter_nodes` whose predicate is a Python callable.
///
/// `filter` is called with each node visited by `visit_circuit_fallable`, and
/// its result is extracted as a `bool`. If the call fails, or its result cannot
/// be extracted, the error is converted by `pycall!` into a
/// `CircuitConstructionError` and returned instead of the set (presumably the
/// traversal stops there; this depends on `visit_circuit_fallable`).
#[pyfunction]
#[pyo3(name = "filter_nodes")]
pub fn filter_nodes_py(
    circuit: CircuitRc,
    filter: PyObject,
) -> Result<HashSet<CircuitRc>, CircuitConstructionError> {
    let mut kept = HashSet::new();
    visit_circuit_fallable(&circuit, &mut |node| {
        let is_match: Result<bool, CircuitConstructionError> =
            pycall!(filter, (node.clone().rc(),), CircuitConstructionError);
        if is_match? {
            kept.insert(node.clone().rc());
        }
        Ok(())
    })?;
    Ok(kept)
}

/// Rewrites `circuit` with `deep_map_unwrap_preorder`, substituting nodes via `map`.
///
/// Each visited node is wrapped as a `CircuitRc` and looked up in `map`.
/// A hit is replaced by a clone of the mapped value, and a miss is kept
/// unchanged. Whether the replacement's own children are visited afterwards is
/// decided by `deep_map_unwrap_preorder`, not here.
#[pyfunction]
pub fn replace_nodes(circuit: CircuitRc, map: HashMap<CircuitRc, CircuitRc>) -> CircuitRc {
    deep_map_unwrap_preorder(&circuit, &mut |node: &Circuit| -> CircuitRc {
        let node_rc = node.clone().rc();
        match map.get(&node_rc) {
            Some(replacement) => replacement.clone(),
            None => node_rc,
        }
    })
}

/// Exposes `deep_map_preorder` to Python, with `f` as the per-node mapping.
///
/// `f` is called with each node (as a `CircuitRc`) and must return a
/// `CircuitRc`. Failures are converted by `pycall!` into a
/// `CircuitConstructionError` and propagated.
#[pyfunction]
#[pyo3(name = "deep_map_preorder")]
pub fn deep_map_preorder_py(
    circuit: CircuitRc,
    f: PyObject,
) -> Result<CircuitRc, CircuitConstructionError> {
    let mut apply = |node: &Circuit| -> Result<CircuitRc, CircuitConstructionError> {
        pycall!(f, (node.clone().rc(),), CircuitConstructionError)
    };
    deep_map_preorder(&circuit, &mut apply)
}

/// Exposes `deep_map` to Python, with `f` as the per-node mapping.
///
/// `f` receives each node as a `CircuitRc` and must return a `CircuitRc`.
/// Errors raised by the call (or by converting its result) surface as a
/// `CircuitConstructionError` produced by `pycall!`.
#[pyfunction]
#[pyo3(name = "deep_map")]
pub fn deep_map_py(circuit: CircuitRc, f: PyObject) -> Result<CircuitRc, CircuitConstructionError> {
    let mut map_node = |node: &Circuit| -> Result<CircuitRc, CircuitConstructionError> {
        pycall!(f, (node.clone().rc(),), CircuitConstructionError)
    };
    deep_map(&circuit, &mut map_node)
}

/// Applies the Python callable `updater` to every node selected by `matcher`.
///
/// Matching runs first, over the original `circuit`, via `filter_nodes_py`.
/// The circuit is then rebuilt with `deep_map_preorder`: nodes found in the
/// matched set are passed to `updater`, and all others are returned unchanged.
/// Errors from either callable are propagated as `CircuitConstructionError`.
#[pyfunction]
#[pyo3(name = "update_nodes")]
pub fn update_nodes_py(
    circuit: CircuitRc,
    matcher: PyObject,
    updater: PyObject,
) -> Result<CircuitRc, CircuitConstructionError> {
    let matched = filter_nodes_py(circuit.clone(), matcher)?;
    deep_map_preorder(&circuit, &mut |x| {
        let node = x.clone().rc();
        if !matched.contains(&node) {
            return Ok(node);
        }
        pycall!(updater, (node,), CircuitConstructionError)
    })
}

/// A route from a root circuit down to one of its descendants, written as a
/// sequence of child indices. Each entry selects which child to descend into
/// at the corresponding level (see `path_get` and `update_path`).
pub type CircuitPath = Vec<usize>;

/// Follows `path` from `circuit` and returns the node it reaches.
///
/// Each entry of `path` is the index of the child to descend into at that
/// level. An empty path returns `circuit` itself. Returns `None` as soon as an
/// index is out of range for the current node's children.
#[pyfunction]
pub fn path_get(circuit: CircuitRc, path: CircuitPath) -> Option<CircuitRc> {
    let mut cur = circuit;
    for i in path {
        // `nth` walks the child iterator directly instead of collecting every
        // child into a Vec per level, and yields `None` when `i` is out of range.
        // The lookup is bound first so the iterator's borrow of `cur` ends
        // before `cur` is reassigned.
        let next = cur.children().nth(i)?;
        cur = next;
    }
    Some(cur)
}

/// Returns a copy of `circuit` in which the node reached by following `path`
/// has been replaced by `updater(node)`.
///
/// `path` is a sequence of child indices, consumed one level at a time. Every
/// node along the path is rebuilt with `map_children_enumerate`, and children
/// that are not on the path are carried over as `clone().rc()` copies. An empty
/// `path` applies `updater` to `circuit` itself.
///
/// If an index in `path` matches none of a node's children, nothing below that
/// node is recursed into and `updater` is never called. This assumes
/// `map_children_enumerate` passes each child's position as `i`. The result is
/// then the circuit rebuilt with its original children, unlike `path_get`,
/// which returns `None` for such a path.
///
/// `updater` only needs to be `FnMut`, so stateful closures are accepted. Any
/// `Fn` closure accepted before still works, because every `Fn` is also `FnMut`.
///
/// # Errors
/// Propagates any error returned by `updater` or by `map_children_enumerate`.
pub fn update_path<F>(
    circuit: &Circuit,
    path: &CircuitPath,
    updater: &mut F,
) -> Result<CircuitRc, CircuitConstructionError>
where
    F: FnMut(&Circuit) -> Result<CircuitRc, CircuitConstructionError>,
{
    // Descends one level per call; `remaining` is the unconsumed tail of the path.
    fn recurse<F>(
        circuit: &Circuit,
        remaining: &[usize],
        updater: &mut F,
    ) -> Result<CircuitRc, CircuitConstructionError>
    where
        F: FnMut(&Circuit) -> Result<CircuitRc, CircuitConstructionError>,
    {
        let (target, rest) = match remaining.split_first() {
            Some((&target, rest)) => (target, rest),
            // Path exhausted: this is the node to update.
            None => return updater(circuit),
        };
        circuit
            .map_children_enumerate(&mut |i, child| {
                if i == target {
                    recurse(child, rest, updater)
                } else {
                    Ok(child.clone().rc())
                }
            })
            .map(|rebuilt| rebuilt.rc())
    }
    recurse(circuit, path, updater)
}

#[pyfunction]
#[pyo3(name = "update_path")]
pub fn update_path_py(
    circuit: CircuitRc,
    path: CircuitPath,
    updater: PyObject,
) -> Result<CircuitRc, CircuitConstructionError> {
    update_path(&circuit, &path, &mut |x| {
        pycall!(updater, (x.clone().rc(),), CircuitConstructionError)
    })
}
