einops>=0.3
flax
jax
jaxlib
