Compute OT with pymot: tutorial

This tutorial demonstrates how to use the optimal transport module (ot) of pymot. A Jupyter notebook version of this tutorial will be available soon.

Load pymot and necessary modules

import torch

import pymot.moments
import pymot.ot

dtype = torch.float64  # double precision
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(Sliced-)Wasserstein distance

Let \(\mu_1,\mu_2\) be two probability measures on \(\mathbb{R}\) and let \(p\geq1\). One may be interested in computing (the \(p\)-th power of) the Wasserstein distance, defined as

\[\texttt{W}_p(\mu_1,\mu_2) \triangleq \inf_{\pi} \int_{\mathbb{R}\times \mathbb{R}} |x-y|^p \,\mathrm{d}\pi(x,y) \tag{1}\]

where the infimum is taken over all joint probability measures \(\pi\) on \(\mathbb{R}\times\mathbb{R}\) with marginals \(\mu_1\) and \(\mu_2\).
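For two empirical measures on \(\mathbb{R}\) made of the same number \(n\) of equally weighted samples \(x_1,\dots,x_n\) and \(y_1,\dots,y_n\), the infimum in (1) is attained by matching sorted samples, so (1) reduces to \(\frac{1}{n}\sum_{i=1}^n |x_{(i)}-y_{(i)}|^p\), where \(x_{(i)}\) denotes the \(i\)-th smallest sample. The plain-Python sketch below illustrates this; it does not use pymot, and wasserstein_1d is a hypothetical name.

```python
import random


def wasserstein_1d(x, y, p=2.0):
    """Empirical version of (1) for two 1D samples of equal size.

    For p >= 1, the optimal coupling in 1D pairs the i-th smallest x
    with the i-th smallest y, so the infimum reduces to an average over
    sorted pairs. Returns the quantity of (1), i.e. without a 1/p root.
    """
    if len(x) != len(y):
        raise ValueError("this sketch assumes samples of equal size")
    xs, ys = sorted(x), sorted(y)
    return sum(abs(a - b) ** p for a, b in zip(xs, ys)) / len(xs)


# Usage: two samples of size 1000 from N(0, 1) and N(5, 1)
rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(1000)]
y = [5.0 + rng.gauss(0.0, 1.0) for _ in range(1000)]
print(wasserstein_1d(x, y, p=2.0))  # close to 5**2 = 25, up to sampling noise
```

Sorting costs \(O(n\log n)\), which is why one-dimensional optimal transport is cheap compared with the general case.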

Alternatively, when \(\mu_1,\mu_2\) are probability measures on \(\mathbb{R}^d\), pymot can compute the Sliced-Wasserstein distance, defined as (see [] for a more rigorous definition)

\[\texttt{SW}_p(\mu_1,\mu_2) \triangleq \int_{\mathbb{S}_{d-1}} \texttt{W}_p(\mathrm{P}_{\theta\#}\mu_1, \mathrm{P}_{\theta\#}\mu_2) \, \mathrm{d}\theta \tag{2}\]

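where \(\mathbb{S}_{d-1}\) denotes the unit sphere of \(\mathbb{R}^d\), \(\mathrm{d}\theta\) the uniform probability measure on it, and \(\mathrm{P}_{\theta\#}\mu\) the push-forward of \(\mu\) by the projection \(x\mapsto\langle\theta,x\rangle\) onto direction \(\theta\). Note that the SW distance coincides with the W distance when \(d=1\).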
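In practice, the integral in (2) is commonly approximated by Monte Carlo: draw random directions \(\theta\) uniformly on the sphere, project both samples onto each direction, and average the one-dimensional values (1). The sketch below does this in plain Python for equal-size point clouds, assuming \(\mathrm{d}\theta\) is the uniform probability measure; the function names and the num_directions parameter are hypothetical and not part of pymot, whose own handling of slices may differ.

```python
import math
import random


def wasserstein_1d(x, y, p=2.0):
    """Quantity (1) between two 1D samples of equal size (sorted matching)."""
    xs, ys = sorted(x), sorted(y)
    return sum(abs(a - b) ** p for a, b in zip(xs, ys)) / len(xs)


def random_direction(d, rng):
    """Draw theta uniformly on the unit sphere of R^d (normalized Gaussian vector)."""
    while True:
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(c * c for c in v))
        if norm > 0.0:
            return [c / norm for c in v]


def sliced_wasserstein_mc(X, Y, p=2.0, num_directions=50, seed=0):
    """Monte Carlo estimate of (2): average of (1) over random 1D slices.

    Hypothetical helper. X and Y are lists of points in R^d (each point a
    list of d floats), with len(X) == len(Y).
    """
    rng = random.Random(seed)
    d = len(X[0])
    total = 0.0
    for _ in range(num_directions):
        theta = random_direction(d, rng)
        # Push both samples forward by the projection x -> <theta, x>
        px = [sum(t * c for t, c in zip(theta, point)) for point in X]
        py = [sum(t * c for t, c in zip(theta, point)) for point in Y]
        total += wasserstein_1d(px, py, p)
    return total / num_directions


# Usage: for d = 1 the only directions are theta = +1 and theta = -1,
# so the estimate reduces to (1) between the two samples
print(sliced_wasserstein_mc([[0.0], [1.0]], [[2.0], [3.0]], p=2.0))
```

The number of directions trades accuracy for cost: the Monte Carlo error decreases roughly like one over the square root of num_directions.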

(Sliced-)Wasserstein distance with pymot

pymot provides several procedures to compute the (Sliced-)Wasserstein distance from samples, from moments, or from a mix of both.

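Let us first generate a toy dataset: each of the num_slices rows of X_1 (resp. X_2) holds samples from \(\mathcal{N}(0,1)\) (resp. \(\mathcal{N}(5,1)\)).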

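num_slices   = 10    # number of 1D slices (rows)
num_samples  = 1000  # number of samples per slice
num_moments  = 20    # number of moments used in Cases 2 and 3
p = 2.

# Each row is a slice: N(0, 1) samples for X_1, N(5, 1) samples for X_2
X_1 = torch.randn([num_slices, num_samples], dtype=dtype, device=device)
X_2 = 5 + torch.randn([num_slices, num_samples], dtype=dtype, device=device)

# NOTE: `type_moments`, used in Cases 2 and 3, is not defined in this
# tutorial; set it to a moment type accepted by
# pymot.moments.estimate_moments before running those cases.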

Case 1: Distance from samples

Assume first that samples from both distributions are available. The SW distance (2) can then be computed with pymot.ot.sliced_wasserstein:

wass = pymot.ot.sliced_wasserstein(X_1, X_2, p,
                                   dtype=dtype, device=device)

# Expected value ~5^2
print(wass)
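The expected value in the comment can be checked in closed form, assuming (as the comment suggests) that the returned value is the quantity defined in (1), i.e. without a \(1/p\) root. For two Gaussians on \(\mathbb{R}\) and \(p=2\), this quantity is known explicitly, and every slice of X_1 and X_2 holds samples from \(\mathcal{N}(0,1)\) and \(\mathcal{N}(5,1)\):

```latex
\texttt{W}_2\big(\mathcal{N}(m_1,\sigma_1^2),\mathcal{N}(m_2,\sigma_2^2)\big) = (m_1-m_2)^2 + (\sigma_1-\sigma_2)^2
\quad\Longrightarrow\quad
\texttt{W}_2\big(\mathcal{N}(0,1),\mathcal{N}(5,1)\big) = (0-5)^2 + (1-1)^2 = 25 = 5^2
```

Since every slice compares the same pair of distributions, averaging over slices yields a value close to 25, up to sampling noise.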

Case 2: One distribution is known only through its moments

Assume now that only a set of moments of the second distribution is available. The SW distance (2) can still be computed with pymot.ot.sliced_wasserstein, by passing these moments in place of the samples:

moments_2 = pymot.moments.estimate_moments(X_2,
                                           num_moments,
                                           type_moments,
                                           dtype=dtype)

wass = pymot.ot.sliced_wasserstein(X_1, moments_2, p,
                                   dtype=dtype, device=device)

# Expected value ~5^2
print(wass)
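This tutorial does not specify which moments pymot.moments.estimate_moments returns for a given type_moments. As an illustration only, assuming raw (power) moments \(m_k = \frac{1}{n}\sum_{i=1}^n x_i^k\), \(k=1,\dots,K\), computed slice by slice, a plain-Python sketch could look as follows (raw_moments and raw_moments_per_slice are hypothetical names, not part of pymot):

```python
def raw_moments(samples, num_moments):
    """Empirical raw moments m_1, ..., m_K of a 1D sample.

    Hypothetical helper for illustration; pymot may use another moment
    family depending on `type_moments`.
    """
    n = len(samples)
    return [sum(x ** k for x in samples) / n for k in range(1, num_moments + 1)]


def raw_moments_per_slice(X, num_moments):
    """Apply raw_moments to every slice (row) of a table of samples."""
    return [raw_moments(row, num_moments) for row in X]


# Usage on a small table with two slices of three samples each
table = [[1.0, 2.0, 3.0], [0.0, 0.0, 3.0]]
print(raw_moments_per_slice(table, 2))  # [[2.0, 4.666666666666667], [1.0, 3.0]]
```

Raw moments of high order grow quickly with the spread of the data, which is one reason a library may offer other moment types through a parameter such as type_moments.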

Case 3: Only moments are available

Assume finally that only moments of the two distributions are available. The SW distance (2) can still be computed with pymot.ot.sliced_wasserstein:

moments_1 = pymot.moments.estimate_moments(X_1,
                                           num_moments,
                                           type_moments,
                                           dtype=dtype)

moments_2 = pymot.moments.estimate_moments(X_2,
                                           num_moments,
                                           type_moments,
                                           dtype=dtype)

wass = pymot.ot.sliced_wasserstein(moments_1, moments_2, p,
                                   dtype=dtype, device=device)

# Expected value ~5^2
print(wass)
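How pymot.ot.sliced_wasserstein recovers the distance from moments alone is not described in this tutorial. To give some intuition on why two moment vectors carry information about the distance, the sketch below uses a different, classical result (the Gelbrich bound) rather than pymot's method: for \(p=2\) and distributions on \(\mathbb{R}\) with means \(a_1,a_2\) and standard deviations \(\sigma_1,\sigma_2\), the quantity in (1) satisfies \(\texttt{W}_2(\mu_1,\mu_2)\geq(a_1-a_2)^2+(\sigma_1-\sigma_2)^2\), with equality when both measures are Gaussian. The function name gelbrich_w2_lower_bound is hypothetical, and the inputs are assumed to be raw moments \([m_1, m_2, \dots]\).

```python
import math


def gelbrich_w2_lower_bound(moments_a, moments_b):
    """Lower bound on the quantity (1) for p = 2, from raw moments of two 1D laws.

    Hypothetical helper, not pymot's method. Only the first two raw moments
    are used: mean m_1 and variance m_2 - m_1**2. Exact for Gaussians.
    """
    mean_a, mean_b = moments_a[0], moments_b[0]
    # max(., 0.0) guards against tiny negative variances caused by rounding
    std_a = math.sqrt(max(moments_a[1] - mean_a ** 2, 0.0))
    std_b = math.sqrt(max(moments_b[1] - mean_b ** 2, 0.0))
    return (mean_a - mean_b) ** 2 + (std_a - std_b) ** 2


# Exact raw moments of N(0, 1): m_1 = 0, m_2 = 1; of N(5, 1): m_1 = 5, m_2 = 26
print(gelbrich_w2_lower_bound([0.0, 1.0], [5.0, 26.0]))  # 25.0
```

Only the first two moments enter this bound, whereas the tutorial passes num_moments moments to pymot; for the Gaussian toy data the bound is tight, which matches the expected value of about \(5^2\).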