jax>=0.2.14
jaxlib>=0.1.67
jax_dataclasses>=1.4.4
numpy
overrides!=4
tyro

[testing]
mypy
jax!=0.3.19
flax
hypothesis[numpy]
pytest
pytest-xdist[psutil]
pytest-cov
