numpy>=1.23.3

[jax]
jax>=0.4.1
jaxlib>=0.4.1

[pytorch]
torch>=1.12.1
