absl-py>=0.7.0
jax>=0.1.67
jaxlib>=0.1.47
numpy>=1.18.4
matplotlib>=2.0.1
flax>=0.3.6
optax>=0.1.1
tqdm>=4.63.0
