use std::{collections::HashSet, iter::zip};

use crate::{
    circuit::Add,
    tensor_util::{TensorAxisIndex, TensorIndex},
    union_find::UnionFind,
    unwrap,
    util::{dedup_with_order, is_unique, EinInt},
};

use super::{
    algebraic_rewrite::add_make_broadcasts_explicit, Circuit, CircuitNode, Concat, Einsum, Index,
    Rearrange,
};

use crate::pyo3_prelude::*;

/// yay! fun when you find a linear time version of a thing you did super slow bc you're lazy
/// transpose so you have ints used by each diag on each axis, hash.
pub fn diags_intersection(diags: &Vec<Vec<EinInt>>) -> Vec<EinInt> {
    assert!(!diags.is_empty() && diags.iter().all(|x| x.len() == diags[0].len()));
    let transpose: Vec<Vec<EinInt>> = (0..diags[0].len())
        .map(|i| diags.iter().map(|d| d[i]).collect::<Vec<_>>())
        .collect();
    let deduped: Vec<&[EinInt]> = dedup_with_order(
        &transpose
            .iter()
            .map(|x: &Vec<u8>| &x[..])
            .collect::<Vec<_>>(),
    );
    transpose
        .iter()
        .map(|x| deduped.iter().position(|z| *z == &x[..]).unwrap() as u8)
        .collect()
}

/// Union several diagonal patterns defined over the same list of axes.
///
/// Within each pattern, axes with equal labels are joined in a union-find
/// structure, so two axes end up tied if a chain of equal labels connects them
/// in any of the patterns. Each output entry is `uf.find(axis)` for that axis,
/// i.e. the labels are whatever representatives `UnionFind` chooses, not
/// compact `0..k` ids.
///
/// # Panics
/// If `diags` is empty or its patterns have different lengths.
pub fn diags_union(diags: &Vec<Vec<EinInt>>) -> Vec<EinInt> {
    assert!(
        !diags.is_empty() && diags.iter().all(|x| x.len() == diags[0].len()),
        "diags_union needs at least one pattern and equal-length patterns"
    );
    let n_axes = diags[0].len();
    let mut uf = UnionFind::new(n_axes);
    for d in diags {
        for (i1, int1) in d.iter().enumerate() {
            // Only pairs with i1 < i2, visited in the same order as a full scan.
            for (i2, int2) in d.iter().enumerate().skip(i1 + 1) {
                if int1 == int2 {
                    uf.union(i1, i2);
                }
            }
        }
    }
    (0..n_axes).map(|i| uf.find(i) as EinInt).collect()
}

// TODO: some of the per-variant logic below should be extracted into helpers.
/// Push traced (diagonal) axes of `einsum`'s arguments down into those arguments.
///
/// An argument is a candidate when its ints are not unique, i.e. the outer
/// einsum takes a diagonal over some of that argument's axes. Depending on the
/// argument's node kind, the diagonal is moved below that node:
///
/// - `Add`: broadcasts are made explicit (falling back to the add itself if
///   `add_make_broadcasts_explicit` returns `None`). Each summand is wrapped in
///   `Einsum::new_trace` with the argument's ints, and the outer einsum uses the
///   deduplicated ints.
/// - `Rearrange`: only output axes that appear in `child_axis_map()[0]` may stay
///   tied. The surviving diagonal is traced on the rearranged node, and the
///   duplicated output tuples are filtered out of the rearrange spec.
/// - `Index`: axes indexed by `Single` are skipped and only identity-indexed
///   axes may stay tied. The trace is applied to the indexed node, and the index
///   is reduced to one entry per remaining input axis.
/// - `Concat`: the concat axis is excluded from the diagonal, and every
///   concatenated node is traced with the remaining diagonal.
///
/// Arguments that cannot be rewritten are kept unchanged. Returns `None` when no
/// argument was rewritten; otherwise returns a new einsum with the original out
/// axes and name.
///
/// # Panics
/// The `try_new` constructors are unwrapped, so a construction failure panics.
#[pyfunction]
pub fn einsum_push_down_trace(einsum: &Einsum) -> Option<Einsum> {
    let mut did_anything = false;
    let new_args = einsum
        .args
        .iter()
        .map(|(node, ints)| {
            // Repeated ints mean the outer einsum takes a diagonal of this argument.
            if !is_unique(ints) {
                match &***node {
                    Circuit::Add(child) => {
                        // have to rearrange: make broadcasts explicit first (presumably so
                        // every summand has the add's full rank and accepts `ints`).
                        // `unwrap_or_else` avoids cloning the add when it isn't needed.
                        let explicit_add =
                            add_make_broadcasts_explicit(child).unwrap_or_else(|| child.clone());
                        let new_add = Add::try_new(
                            explicit_add
                                .nodes
                                .iter()
                                .map(|node| {
                                    Einsum::new_trace(node.clone(), ints.clone(), None).rc()
                                })
                                .collect(),
                            None,
                        )
                        .unwrap()
                        .rc();
                        did_anything = true;
                        return (new_add, dedup_with_order(ints));
                    }
                    Circuit::Rearrange(child) => {
                        let axis_map = &child.child_axis_map()[0];
                        // 0 for output axes that some child input axis maps onto, a
                        // distinct value otherwise, so only mapped axes can stay tied.
                        let ints_for_intersection = (0..child.info().rank())
                            .map(|i| {
                                if axis_map.iter().any(|z| *z == Some(i)) {
                                    0
                                } else {
                                    i as u8 + 1
                                }
                            })
                            .collect();
                        let intersected_ints =
                            diags_intersection(&vec![ints.clone(), ints_for_intersection]);
                        let intersected_deduped = dedup_with_order(&intersected_ints);
                        // Only rewrite if some diagonal survives the intersection.
                        if intersected_ints.len() != intersected_deduped.len() {
                            // Output tuples of every axis that is not the first axis of
                            // its diagonal group; those axes disappear from the rearrange.
                            let tuples_to_remove: HashSet<Box<[EinInt]>> =
                                zip(intersected_ints.iter().enumerate(), &child.spec.output_ints)
                                    .filter_map(|((i, trace_int), oints)| {
                                        if intersected_ints
                                            .iter()
                                            .position(|z| z == trace_int)
                                            .unwrap()
                                            != i
                                        {
                                            Some(oints[..].into())
                                        } else {
                                            None
                                        }
                                    })
                                    .collect();
                            // Ints for the child's input axes: the diagonal label of the
                            // output axis it maps to, or a fresh value >= rank for unmapped
                            // axes (`diags_intersection` labels are group ids, hence < rank).
                            let thingy: Vec<EinInt> = axis_map
                                .iter()
                                .enumerate()
                                .map(|(i, x)| {
                                    if let Some(z) = x {
                                        intersected_ints[*z]
                                    } else {
                                        i as u8 + child.info().rank() as u8
                                    }
                                })
                                .collect();
                            // can't just use new_trace bc which one gets eliminated is determined by output, not input
                            // so for instance `aba` might go to `ba` instead of normal `ab`
                            let new_trace = Einsum::try_new(
                                vec![(child.node.clone(), thingy.clone())],
                                (0..child.node.info().rank())
                                    .filter(|i| {
                                        !tuples_to_remove.contains(&child.spec.input_ints[*i][..])
                                    })
                                    .map(|i| thingy[i])
                                    .collect(),
                                None,
                            )
                            .unwrap()
                            .rc();

                            let new_rearrange = Rearrange::try_new(
                                new_trace,
                                child.spec.filter_all_tuples(&tuples_to_remove),
                                None,
                            )
                            .unwrap()
                            .rc();
                            // Outer ints: for each surviving group, the original int at the
                            // group's first axis.
                            let up_axes = intersected_deduped
                                .iter()
                                .map(|i| {
                                    ints[intersected_ints.iter().position(|z| z == i).unwrap()]
                                })
                                .collect();
                            did_anything = true;
                            return (new_rearrange, up_axes);
                        }
                    }
                    Circuit::Index(child) => {
                        // One entry per output axis (`Single` axes are dropped): 0 for
                        // identity-indexed axes, a distinct value otherwise, so only
                        // identity axes can stay tied.
                        let ints_for_intersection = zip(&child.index.0, &child.node.info().shape)
                            .enumerate()
                            .filter_map(|(i, (idx, l))| {
                                if matches!(idx, TensorAxisIndex::Single(_)) {
                                    return None;
                                }
                                Some(if idx.is_identity(*l) {
                                    0_u8
                                } else {
                                    (i + 1) as u8
                                })
                            })
                            .collect();
                        let intersected_ints =
                            diags_intersection(&vec![ints_for_intersection, ints.clone()]);
                        let intersected_deduped = dedup_with_order(&intersected_ints);
                        // Only rewrite if some diagonal survives the intersection.
                        if intersected_ints.len() != intersected_deduped.len() {
                            did_anything = true;
                            let child_map = &child.child_axis_map()[0];
                            // Ints for the indexed node's input axes: the diagonal label
                            // of the output axis it maps to, or a fresh value >= the
                            // node's rank for axes without an output axis.
                            let intersected_ints_input: Vec<u8> = child_map
                                .iter()
                                .enumerate()
                                .map(|(i, x)| {
                                    if let Some(z) = x {
                                        intersected_ints[*z] as u8
                                    } else {
                                        (i + child.node.info().rank()) as u8
                                    }
                                })
                                .collect();
                            let intersected_ints_input_deduped =
                                dedup_with_order(&intersected_ints_input);
                            let new_trace = Einsum::new_trace(
                                child.node.clone(),
                                intersected_ints_input.clone(),
                                None,
                            )
                            .rc();
                            // Outer ints: for each surviving group, the original int at the
                            // group's first axis.
                            let up_axes = intersected_deduped
                                .iter()
                                .map(|i| {
                                    ints[intersected_ints.iter().position(|z| z == i).unwrap()]
                                })
                                .collect();
                            return (
                                Index::try_new(
                                    new_trace,
                                    // One index entry per distinct input int, taken from the
                                    // first input axis carrying that int.
                                    TensorIndex(
                                        intersected_ints_input_deduped
                                            .iter()
                                            .map(|i| {
                                                child.index.0[intersected_ints_input
                                                    .iter()
                                                    .position(|z| z == i)
                                                    .unwrap()]
                                                .clone()
                                            })
                                            .collect(),
                                    ),
                                    None,
                                )
                                .unwrap()
                                .rc(),
                                up_axes,
                            );
                        }
                    }
                    Circuit::Concat(child) => {
                        // Give the concat axis a fresh int so it never joins a diagonal.
                        let mut new_raw_ints = ints.clone();
                        let concat_axis_int = new_raw_ints.iter().max().unwrap_or(&0) + 1;
                        new_raw_ints[child.axis] = concat_axis_int;
                        let deduped_here = dedup_with_order(&new_raw_ints);
                        // Only rewrite if a diagonal remains once the concat axis is excluded.
                        if deduped_here.len() != new_raw_ints.len() {
                            let new_concat_axis = deduped_here
                                .iter()
                                .position(|x| *x == concat_axis_int)
                                .unwrap();
                            let new_concat = Concat::try_new(
                                child
                                    .nodes
                                    .iter()
                                    .map(|x| {
                                        Einsum::new_trace(x.clone(), new_raw_ints.clone(), None)
                                            .rc()
                                    })
                                    .collect(),
                                new_concat_axis,
                                None,
                            )
                            .unwrap()
                            .rc();
                            // Outer ints: the fresh concat int maps back to the original
                            // int of the concat axis; all others are unchanged.
                            let up_axes = deduped_here
                                .iter()
                                .map(|x| {
                                    if *x == concat_axis_int {
                                        ints[child.axis]
                                    } else {
                                        *x
                                    }
                                })
                                .collect();
                            did_anything = true;
                            return (new_concat, up_axes);
                        }
                    }
                    _ => {}
                }
            }
            // Not rewritable: keep the argument unchanged.
            (node.clone(), ints.clone())
        })
        .collect();
    if !did_anything {
        return None;
    }
    Some(Einsum::try_new(new_args, einsum.out_axes.clone(), einsum.name_cloned()).unwrap())
}

#[test]
fn test_diags_intersection_union() {
    let ex = vec![vec![0, 1, 0, 0], vec![0, 1, 0, 1]];
    // Only axes 0 and 2 share a label in both patterns; ids are handed out in
    // order of first appearance.
    let inter = diags_intersection(&ex);
    assert_eq!(inter, vec![0, 1, 0, 2]);
    // Pattern 0 ties {0, 2, 3} and pattern 1 ties {0, 2} and {1, 3}, so the
    // transitive closure ties every axis. Representatives are chosen by
    // `UnionFind`, so only check that all axes agree.
    let union = diags_union(&ex);
    assert_eq!(union.len(), 4);
    assert!(union.iter().all(|&x| x == union[0]), "{:?}", union);
}

/// Pull a diagonal shared by every summand of `add` out into an outer einsum.
///
/// Applies only when every node of `add` is an `Einsum`. Each summand's axes are
/// labelled in add-space:
/// - the broadcast prefix (from `nodes_and_rank_differences`) gets fresh labels;
/// - axes of size > 1 are labelled by their einsum out-int;
/// - size-1 axes get fresh labels.
///
/// Broadcast-prefix and size-1 axes therefore never join a diagonal. Axes tied
/// in *every* summand (`diags_intersection`) form the shared diagonal. Each
/// summand is rebuilt to output one axis per group, and the new `Add` is wrapped
/// in `Einsum::new_diag` to restore the original axes.
///
/// Returns `None` if some node is not an einsum or no diagonal is shared.
///
/// # Panics
/// `Einsum::try_new` and `Add::try_new` are unwrapped.
#[pyfunction]
pub fn add_pull_diags(add: &Add) -> Option<Einsum> {
    if !add
        .nodes
        .iter()
        .all(|x| matches!(&***x, Circuit::Einsum(_)))
    {
        return None;
    }
    let nodes_and_rdifs = add.nodes_and_rank_differences();
    let diags = nodes_and_rdifs
        .iter()
        .map(|(node, rank_difference)| {
            let einsum = unwrap!(&***node, Circuit::Einsum);
            let rdif = *rank_difference;
            // Real axes use labels in `rdif..=rdif + max_out_int`. Fresh labels for
            // size-1 axes must start strictly above that range. Otherwise the first
            // size-1 axis would share a label with the axis holding the max int and
            // be mistaken for a diagonal.
            let fresh_base =
                rdif + einsum.out_axes.iter().max().map_or(0, |&m| m as usize + 1);
            (0..rdif as u8)
                .chain(
                    zip(&einsum.info().shape, &einsum.out_axes)
                        .enumerate()
                        .map(|(i, (l, x))| {
                            if *l > 1 {
                                x + rdif as u8
                            } else {
                                (fresh_base + i) as u8
                            }
                        }),
                )
                .collect()
        })
        .collect();
    let overall_diags = diags_intersection(&diags);
    let overall_diags_deduped = dedup_with_order(&overall_diags);
    if overall_diags_deduped.len() == overall_diags.len() {
        return None;
    }
    let new_add = Add::try_new(
        nodes_and_rdifs
            .iter()
            .map(|(node, rdif)| {
                let ein = unwrap!(&***node, Circuit::Einsum);
                // This summand only covers the add axes after its broadcast prefix.
                let own_diags = &overall_diags[*rdif..];
                Einsum::try_new(
                    ein.args.clone(),
                    dedup_with_order(own_diags)
                        .iter()
                        .map(|i| ein.out_axes[own_diags.iter().position(|z| z == i).unwrap()])
                        .collect(),
                    None,
                )
                .unwrap()
                .rc()
            })
            .collect(),
        add.name_cloned(),
    )
    .unwrap()
    .rc();
    Some(Einsum::new_diag(new_add, overall_diags, None))
}
