numpy>=1.12
jax>=0.1.59
matplotlib
dataclasses
msgpack

[testing]
pytest
pytest-xdist
tensorflow_datasets
