jax>=0.2.14
jaxlib>=0.1.67
jax_dataclasses>=1.2.0
numpy
overrides!=4

[testing]
flax
hypothesis[numpy]
pytest
pytest-xdist[psutil]
pytest-cov
