numpy>=1.23.3
scipy>=1.9.1

[jax]
jax>=0.4.1
jaxlib>=0.4.1
optax>=0.1.4

[pytorch]
torch>=2.0.0
