jax==0.3.5
jaxlib==0.3.5
optax>=0.1.0
chex==0.1.3
distrax>=0.1.2
tensorflow-probability==0.16.0
tqdm>=4.0.0
ml-collections==0.1.0
protobuf==3.19.0
jaxtyping

[cuda]
jax[cuda]

[dev]
black
isort
pylint
flake8
pytest
