jax>=0.4.0
jaxlib>=0.4.0
jaxtyping
ml-collections==0.1.0
distrax>=0.1.2

[cuda]
jax[cuda]

[dev]
black
isort
pylint
flake8
pytest
pytest-cov
