jax>=0.2.5
jaxlib>=0.1.57
numpy>=1.18.5
multipledispatch==0.6.0
packaging==20.4
chex>=0.0.4
tfp-nightly==0.12.0.dev20201123
ml-collections==0.1.0
numpyro==0.6.0

[dev]
black
isort
pylint
flake8

[docs]
furo==2020.12.30b24
nbsphinx==0.8.1
nb-black==1.0.7
matplotlib==3.3.3
sphinx-copybutton==0.3.5

[tests]
pytest
