jax>=0.1.67
jaxlib>=0.1.47
optax
chex
distrax>=0.1.2
tensorflow-probability>=0.16.0
tqdm>=4.0.0
ml-collections==0.1.0
jaxtyping>=0.0.2

[cuda]
jax[cuda]

[dev]
black
isort
pylint
flake8
pytest
