Metadata-Version: 2.1
Name: ott-jax
Version: 0.3.0
Summary: OTT: Optimal Transport Tools in Jax.
Home-page: https://github.com/ott-jax/ott
Author-email: optimal.transport.tools@gmail.com
License: Apache 2.0
Project-URL: Documentation, https://ott-jax.readthedocs.io
Project-URL: Source Code, https://github.com/ott-jax/ott
Keywords: optimal transport,sinkhorn,wasserstein,jax
Classifier: Development Status :: 5 - Production/Stable
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Natural Language :: English
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: POSIX :: Linux
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Operating System :: Microsoft :: Windows
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Typing :: Typed
Requires-Python: >=3.7
Description-Content-Type: text/markdown
Provides-Extra: test
Provides-Extra: experimental
Provides-Extra: docs
Provides-Extra: dev
License-File: LICENSE

<img src="https://raw.githubusercontent.com/ott-jax/ott/main/docs/_static/images/logoOTT.png" width="10%" alt="logo">

# Optimal Transport Tools (OTT)
[![Downloads](https://pepy.tech/badge/ott-jax)](https://pypi.org/project/ott-jax/)
[![Tests](https://img.shields.io/github/workflow/status/ott-jax/ott/tests/main)](https://github.com/ott-jax/ott/actions/workflows/tests.yml)
[![Docs](https://img.shields.io/readthedocs/ott-jax/latest)](https://ott-jax.readthedocs.io/en/latest/)
[![Coverage](https://img.shields.io/codecov/c/github/ott-jax/ott/main)](https://app.codecov.io/gh/ott-jax/ott)

**See the [full documentation](https://ott-jax.readthedocs.io/en/latest/).**

## What is OTT-JAX?
OTT-JAX is a JAX-powered library for computing optimal transport at scale and on accelerators. It includes the fastest
implementation of the Sinkhorn algorithm you will find around. We have implemented all tweaks (scheduling,
acceleration, initializations) and extensions (low-rank), which can be used directly or within more advanced problems
(Gromov-Wasserstein, barycenters). Several JAX features, including
[JIT](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions),
[auto-vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap) and
[implicit differentiation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html),
work towards the goal of having end-to-end differentiable outputs. OTT-JAX is developed by a team of researchers
from Apple, Google and Meta, together with many academic contributors, including TU München, Oxford, ENSAE/IP Paris
and the Hebrew University.

## What is optimal transport?
Optimal transport can be loosely described as the branch of mathematics and optimization that studies
*matching problems*: given two families of points and a cost function on pairs of points, find a "good" (low-cost) way
to associate each point in the first family bijectively with a point in the second.

Such problems appear in all areas of science and are easy to describe, yet hard to solve. Indeed, while the optimal
matching between two sets of *n* points under a pairwise cost can be computed with the
[Hungarian algorithm](https://en.wikipedia.org/wiki/Hungarian_algorithm), doing so costs on the order of $n^3$
operations. The approach also lacks flexibility, since one may want to couple families of different sizes.
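
To make the matching problem concrete, here is a minimal brute-force sketch in plain Python. It is neither the Hungarian algorithm nor part of OTT: it simply enumerates all $n!$ bijections and keeps the cheapest one, so it is only usable for tiny inputs. The helper names `squared_euclidean` and `brute_force_matching` are hypothetical and chosen for illustration.

```python
import itertools
import math


def squared_euclidean(p, q):
    """Squared Euclidean distance between two points given as tuples."""
    return sum((pi - qi) ** 2 for pi, qi in zip(p, q))


def brute_force_matching(xs, ys, cost=squared_euclidean):
    """Return (best_cost, best_perm), where xs[i] is matched with ys[best_perm[i]].

    Illustrative only: the loop visits all n! permutations.
    """
    if len(xs) != len(ys):
        raise ValueError("a bijective matching needs families of equal size")
    best_cost, best_perm = math.inf, None
    for perm in itertools.permutations(range(len(ys))):
        total = sum(cost(xs[i], ys[j]) for i, j in enumerate(perm))
        if total < best_cost:
            best_cost, best_perm = total, perm
    return best_cost, best_perm


xs = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
ys = [(0.1, 0.9), (0.0, 0.1), (1.1, 0.0)]
best_cost, best_perm = brute_force_matching(xs, ys)
print(best_perm)  # index of the point in ys assigned to each point in xs
```

The factorial growth of this enumeration, and the equal-size requirement it enforces, are the two limitations that motivate better algorithms.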

Optimal transport extends all of this through faster algorithms (quadratic or even linear in $n$), along with numerous
generalizations that handle weighted sets of different sizes, partial matchings, and even more involved
so-called quadratic matching problems.
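
As a rough illustration of a quadratic-cost algorithm that handles weighted sets of different sizes, the following is a minimal NumPy sketch of Sinkhorn iterations for entropy-regularized optimal transport. It is not OTT's implementation: the function name `sinkhorn_sketch`, the fixed iteration count, and the choice of `epsilon` relative to the mean cost are assumptions made for this example, and the plain (non log-domain) updates can become numerically unstable for very small `epsilon`.

```python
import numpy as np


def sinkhorn_sketch(cost, a, b, epsilon, num_iters=500):
    """Approximate an entropy-regularized coupling by alternating scalings.

    cost: (n, m) pairwise cost matrix; a: (n,) and b: (m,) weights summing to 1.
    Each iteration costs O(n * m) because of the two matrix-vector products.
    """
    kernel = np.exp(-cost / epsilon)  # Gibbs kernel, shape (n, m)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iters):
        u = a / (kernel @ v)    # rescale rows towards marginal a
        v = b / (kernel.T @ u)  # rescale columns to match marginal b
    return u[:, None] * kernel * v[None, :]


rng = np.random.default_rng(0)
n, m, d = 12, 14, 2
x = rng.normal(size=(n, d)) + 1
y = rng.uniform(size=(m, d))
a = rng.uniform(size=n)
b = rng.uniform(size=m)
a, b = a / a.sum(), b / b.sum()

# Squared Euclidean cost between every pair of points.
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
P = sinkhorn_sketch(cost, a, b, epsilon=0.05 * cost.mean())

# Column sums match b after the last update; row sums approach a as
# the number of iterations grows.
row_error = np.abs(P.sum(axis=1) - a).max()
```

Because the last update rescales the columns, `P.sum(axis=0)` equals `b` up to rounding, while `row_error` measures how far the iterations are from convergence.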

In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly
(2D vectors, compared with the squared Euclidean distance):

## Example
```python
import jax
import jax.numpy as jnp
from ott.tools import transport

# Sample two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n, d)) + 1
y = jax.random.uniform(rngs[1], (m, d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
# Normalize the weights so that each set sums to 1.
a, b = a / jnp.sum(a), b / jnp.sum(b)

# Compute the coupling using the Sinkhorn algorithm.
ot = transport.solve(x, y, a=a, b=b)
P = ot.matrix
```

The call to `solve` above computes the optimal transport solution. The `ot` object contains a transport matrix
(here of size $12\times 14$) that quantifies a "link strength" between each point of the first point cloud and one or
more points of the second, as illustrated in the plot below. In this toy example, most choices were arbitrary, and
this is reflected in the crude `solve` API. We provide far more flexibility to define custom cost functions,
objectives, and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/).

![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/images/couplings.png)

## Citation
If you have found this work useful, please consider citing this reference:

```
@article{cuturi2022optimal,
  title={Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein},
  author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and
          Davis, Geoff and Teboul, Olivier},
  journal={arXiv preprint arXiv:2201.12324},
  year={2022}
}
```
