Compute OT with pymot: tutorial
===============================

This tutorial demonstrates how to use the optimal transport ``ot`` module of ``pymot``.
A Jupyter notebook version of this tutorial will be available soon.

Load ``pymot`` and necessary modules
------------------------------------

.. code:: python3

    import torch
    import pymot.moments
    import pymot.ot

    dtype = torch.float64  # double precision
    device = torch.device('cpu')  # or cuda if available

(Sliced-)Wasserstein distance
-----------------------------

Let :math:`\mu_1,\mu_2` be two probability measures defined on :math:`\mathbb{R}` and :math:`p\geq1` be some scalar.
One may be interested in computing (a particular instance of) the Wasserstein distance defined as

.. math::
    :label: eq:def:wasserstein

    \texttt{W}_p(\mu_1,\mu_2) \triangleq \inf_{\pi} \int_{\mathbb{R}\times \mathbb{R}} |x-y|^p \,\mathrm{d}\pi(x,y)

where the infimum is taken over all joint probability measures :math:`\pi` with marginals :math:`\mu_1,\mu_2`.

Alternatively, when :math:`\mu_1,\mu_2` are probability measures defined on :math:`\mathbb{R}^d` with :math:`d\geq1`, ``pymot`` can compute the Sliced-Wasserstein distance defined as (see [] for a more rigorous definition)

.. math::
    :label: eq:def:sliced-wasserstein

    \texttt{SW}_p(\mu_1,\mu_2) \triangleq \int_{\mathbb{S}_{d-1}} \texttt{W}_p(\mathrm{P}_{\theta\#}\mu_1, \mathrm{P}_{\theta\#}\mu_2) \, \mathrm{d}\theta

where :math:`\mathbb{S}_{d-1}` denotes the unit sphere of :math:`\mathbb{R}^d` and :math:`\mathrm{P}_{\theta\#}` the push-forward by the projection onto direction :math:`\theta`.
Note that the SW distance coincides with the W distance when :math:`d=1`.

(Sliced-)Wasserstein distance with ``pymot``
--------------------------------------------

``pymot`` provides several procedures to compute the (Sliced-)Wasserstein distance from samples and/or moments.
Let us first generate a dataset:

.. code:: python3

    num_slices = 10
    num_samples = 1000
    num_moments = 20
    p = 2.

    # One row of num_samples samples per slice
    # (the original listing used num_moments here, leaving num_samples unused)
    X_1 = torch.randn([num_slices, num_samples], dtype=dtype, device=device)
    X_2 = 5 + torch.randn([num_slices, num_samples], dtype=dtype, device=device)

Case 1: distances from samples
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Assume first that samples from the two distributions are available.
The SW distance :eq:`eq:def:sliced-wasserstein` can be computed with :code:`sliced_wasserstein`:

.. code:: python3

    wass = pymot.ot.sliced_wasserstein(X_1, X_2, p, dtype=dtype, device=device)

    # Expected value ~5^2
    print(wass)

Case 2: one distribution is only known through its moments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Assume that only a set of moments of the second distribution is available.
The SW distance :eq:`eq:def:sliced-wasserstein` can still be computed with :code:`sliced_wasserstein`:

.. code:: python3

    # type_moments is not defined in this tutorial: set it beforehand to a
    # moment type supported by pymot.moments
    moments_2 = pymot.moments.estimate_moments(X_2, num_moments, type_moments, dtype=dtype)
    wass = pymot.ot.sliced_wasserstein(X_1, moments_2, p, dtype=dtype, device=device)

    # Expected value ~5^2
    print(wass)

Case 3: only moments are available
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Assume that only moments of the two distributions are available.
The SW distance :eq:`eq:def:sliced-wasserstein` can still be computed with :code:`sliced_wasserstein`:

.. code:: python3

    # type_moments must be defined beforehand (see Case 2)
    moments_1 = pymot.moments.estimate_moments(X_1, num_moments, type_moments, dtype=dtype)
    moments_2 = pymot.moments.estimate_moments(X_2, num_moments, type_moments, dtype=dtype)
    wass = pymot.ot.sliced_wasserstein(moments_1, moments_2, p, dtype=dtype, device=device)

    # Expected value ~5^2
    print(wass)
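
To make the sample-based case more concrete, the following is a minimal sketch, written with the Python standard library only, of how the sliced quantity defined above can be estimated from samples. It is not ``pymot``'s implementation and does not use its API: the function names ``wasserstein_1d``, ``random_direction`` and ``sliced_wasserstein_mc`` are hypothetical. The sketch assumes that both samples have the same size, uses the fact that in dimension one the optimal coupling between two empirical measures pairs the sorted samples, and replaces the integral over the unit sphere by a Monte Carlo average over random directions. As in the definition above, no :math:`1/p` root is applied.

```python
import math
import random


def wasserstein_1d(xs, ys, p=2.0):
    """Transport cost (no 1/p root) between two 1-D samples of equal size."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes samples of equal size")
    # In 1-D, the optimal coupling pairs the sorted samples.
    pairs = zip(sorted(xs), sorted(ys))
    return sum(abs(a - b) ** p for a, b in pairs) / len(xs)


def random_direction(d, rng):
    """Draw a direction uniformly on the unit sphere of R^d."""
    while True:
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(c * c for c in v))
        if norm > 0.0:
            return [c / norm for c in v]


def sliced_wasserstein_mc(X, Y, p=2.0, num_directions=100, seed=0):
    """Monte Carlo estimate of the sliced cost between two d-dimensional samples.

    X and Y are lists of points, each point being a list of d floats.
    """
    rng = random.Random(seed)
    d = len(X[0])
    total = 0.0
    for _ in range(num_directions):
        theta = random_direction(d, rng)
        # Project every point onto the direction theta.
        x_proj = [sum(t * c for t, c in zip(theta, x)) for x in X]
        y_proj = [sum(t * c for t, c in zip(theta, y)) for y in Y]
        total += wasserstein_1d(x_proj, y_proj, p)
    return total / num_directions


# Usage: two 1-D Gaussian samples whose means differ by 5, as in the tutorial.
rng = random.Random(1)
x = [rng.gauss(0.0, 1.0) for _ in range(2000)]
y = [5.0 + rng.gauss(0.0, 1.0) for _ in range(2000)]
w = wasserstein_1d(x, y, p=2.0)
print(w)  # close to 5^2, up to sampling noise
```

Sorting makes each one-dimensional evaluation cost :math:`O(n \log n)`, which is why slicing is attractive in higher dimensions. The number of directions trades accuracy for run time.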