Compute OT with pymot: tutorial
This tutorial demonstrates how to use the optimal transport module ot of pymot.
A Jupyter notebook version of this tutorial will be available soon.
Load pymot and necessary modules
import torch
import pymot.moments
import pymot.ot
dtype = torch.float64  # double precision
device = torch.device('cpu')  # or torch.device('cuda') if available
(Sliced-)Wasserstein distance
Let \(\mu_1,\mu_2\) be two probability measures defined on \(\mathbb{R}\) and \(p\geq1\) be some scalar. One may be interested in computing (a particular instance of) the Wasserstein distance defined as
\[\texttt{W}_p(\mu_1,\mu_2) \triangleq \inf_{\pi} \int_{\mathbb{R}\times \mathbb{R}} |x-y|^p \,\mathrm{d}\pi(x,y) \tag{1}\]
where the infimum is taken over all joint probability measures \(\pi\) on \(\mathbb{R}\times\mathbb{R}\) with marginals \(\mu_1\) and \(\mu_2\).
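For two empirical measures with the same number of equally weighted samples, the infimum in (1) is attained by matching the sorted samples, so (1) reduces to the mean of \(|x_{(i)}-y_{(i)}|^p\) over the sorted samples. The sketch below illustrates this in plain Python. It is independent of pymot, and the function name wasserstein_1d is hypothetical.

```python
def wasserstein_1d(x, y, p=2.0):
    """Empirical W_p as defined in (1) (no 1/p-th root), equal sample sizes."""
    if len(x) != len(y):
        raise ValueError("this sketch assumes equally many samples in x and y")
    # In one dimension, the optimal coupling pairs the sorted samples.
    xs, ys = sorted(x), sorted(y)
    return sum(abs(a - b) ** p for a, b in zip(xs, ys)) / len(xs)


x = [2.0, 0.0, 1.0]
y = [7.0, 5.0, 6.0]  # the same samples shifted by 5
print(wasserstein_1d(x, y, p=2.0))  # 25.0
```

Shifting every sample by 5 gives a cost of \(5^2\) for \(p=2\), whatever the order in which the samples are listed.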
Alternatively, when \(\mu_1,\mu_2\) are probability measures defined on \(\mathbb{R}^d\), pymot can compute the Sliced-Wasserstein distance, defined as (see [] for a more rigorous definition)
\[\texttt{SW}_p(\mu_1,\mu_2) \triangleq \int_{\mathbb{S}_{d-1}} \texttt{W}_p(\mathrm{P}_{\theta\#}\mu_1, \mathrm{P}_{\theta\#}\mu_2) \, \mathrm{d}\theta \tag{2}\]
where \(\mathbb{S}_{d-1}\) denotes the unit sphere of \(\mathbb{R}^d\) and \(\mathrm{P}_{\theta\#}\mu\) denotes the pushforward of \(\mu\) by the projection onto direction \(\theta\). Note that the SW distance coincides with the W distance when \(d=1\).
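The integral in (2) can be approximated by averaging (1) over randomly drawn directions. The following is a minimal plain-Python sketch of this Monte Carlo idea, not pymot's implementation. It assumes that \(\mathrm{d}\theta\) is the uniform probability measure on the sphere and that both point clouds have the same number of points. The names wasserstein_1d, random_direction and sliced_wasserstein_mc are hypothetical.

```python
import math
import random


def wasserstein_1d(x, y, p=2.0):
    """Empirical W_p of (1) for two equally sized 1D samples (sorted matching)."""
    xs, ys = sorted(x), sorted(y)
    return sum(abs(a - b) ** p for a, b in zip(xs, ys)) / len(xs)


def random_direction(d, rng):
    """Draw a direction uniformly on the unit sphere of R^d."""
    while True:
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(v * v for v in g))
        if norm > 0.0:  # a normalized Gaussian vector is uniform on the sphere
            return [v / norm for v in g]


def sliced_wasserstein_mc(X, Y, p=2.0, num_directions=500, seed=0):
    """Monte Carlo estimate of (2); X and Y are lists of d-dimensional points."""
    rng = random.Random(seed)
    d = len(X[0])
    total = 0.0
    for _ in range(num_directions):
        theta = random_direction(d, rng)
        # Project both point clouds onto the direction theta.
        proj_x = [sum(t * v for t, v in zip(theta, pt)) for pt in X]
        proj_y = [sum(t * v for t, v in zip(theta, pt)) for pt in Y]
        total += wasserstein_1d(proj_x, proj_y, p)
    return total / num_directions


# d = 1: every direction is +1 or -1, so each term is W_p itself
X = [[0.0], [1.0], [2.0]]
Y = [[5.0], [6.0], [7.0]]
print(sliced_wasserstein_mc(X, Y, p=2.0))
```

With \(d=1\) the estimate does not depend on the drawn directions, which illustrates the remark that the SW and W distances coincide in that case. For \(d>1\) the result is a random estimate whose accuracy improves with num_directions.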
(Sliced-)Wasserstein distance with pymot
pymot provides several procedures to compute the (Sliced-)Wasserstein distance from samples and/or moments.
Let us first generate a dataset:
num_slices = 10
num_samples = 1000
num_moments = 20
p = 2.
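# One row of num_samples draws per slice (layout assumed here):
# standard normal samples for X_1, the same law shifted by 5 for X_2
X_1 = torch.randn([num_slices, num_samples], dtype=dtype, device=device)
X_2 = 5 + torch.randn([num_slices, num_samples], dtype=dtype, device=device)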
Case 1: distances from samples
Assume first that samples from the two distributions are available.
The SW distance (2) can be computed with sliced_wasserstein:
wass = pymot.ot.sliced_wasserstein(X_1, X_2, p,
                                   dtype=dtype, device=device)
# Expected value ~5^2
print(wass)
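The expected value \(5^2\) can be checked against a closed form, assuming each slice of X_1 and X_2 is drawn from \(\mathcal{N}(0,1)\) and \(\mathcal{N}(5,1)\) respectively, as in the dataset above. For two one-dimensional Gaussians and \(p=2\), definition (1), which does not take the \(1/p\)-th root, gives
\[\texttt{W}_2\big(\mathcal{N}(m_1,\sigma_1^2), \mathcal{N}(m_2,\sigma_2^2)\big) = (m_1-m_2)^2 + (\sigma_1-\sigma_2)^2,\]
so that here
\[\texttt{W}_2 = (0-5)^2 + (1-1)^2 = 25.\]
This value is the same for every slice, so if sliced_wasserstein averages (1) over the slices, the printed value should be close to 25 up to sampling error.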
Case 2: one distribution is known only through its moments
Assume that only a set of moments from the second distribution is available.
The SW distance (2) can still be computed with sliced_wasserstein:
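# NOTE: type_moments is not defined in this tutorial; set it beforehand
# to a moment type supported by pymot.moments.
moments_2 = pymot.moments.estimate_moments(X_2,
                                           num_moments,
                                           type_moments,
                                           dtype=dtype)
wass = pymot.ot.sliced_wasserstein(X_1, moments_2, p,
                                   dtype=dtype, device=device)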
# Expected value ~5^2
print(wass)
Case 3: only moments are available
Assume that only moments of the two distributions are available.
The SW distance (2) can still be computed with sliced_wasserstein:
moments_1 = pymot.moments.estimate_moments(X_1,
                                           num_moments,
                                           type_moments,
                                           dtype=dtype)
moments_2 = pymot.moments.estimate_moments(X_2,
                                           num_moments,
                                           type_moments,
                                           dtype=dtype)
wass = pymot.ot.sliced_wasserstein(moments_1, moments_2, p,
                                   dtype=dtype, device=device)
# Expected value ~5^2
print(wass)