Introduction to the minibatch Wasserstein distance properties

Learning with minibatch Wasserstein

What happens when we use the Wasserstein distance with minibatches? Results from the paper "Learning with minibatch Wasserstein: asymptotic and gradient properties", published at AISTATS 2020.

Towards Data Science
Apr 1, 2020

Optimal transport has become very popular in machine learning applications such as generative modelling [1] and domain adaptation [2]. In these applications, one wants to minimize a statistical distance between source data and target data. For this purpose, the Wasserstein distance has become a standard tool. It can be computed with either the primal [2, 3] or the dual formulation [1], and it relies on minibatches for optimization. Unfortunately, the dual formulation can lead to numerical instabilities because it requires the gradient of a continuous function, hence the use of the primal formulation. A review of the pros and cons of each formulation for learning purposes can be found in [section 9.4, 4].

However, computing the primal Wasserstein distance between minibatches is not equivalent to computing it between the full measures. In this story, I will describe the minibatch Wasserstein distance, for which the consequences of the minibatch paradigm on the loss had so far been left unjustified. A full review of the presented results can be found in our AISTATS 2020 paper [5].

Disclaimer: for the sake of simplicity, I will present results for the Wasserstein distance, but all of them extend to the other optimal transport variants. We also consider a general ground cost. Furthermore, we will not distinguish between sets of elements and their measures.

Wasserstein Distance

Based on the Kantorovich problem, the Wasserstein distance measures the distance between two distributions by seeking the minimal displacement cost between the measure α and the measure β, according to a ground metric C. Let α (size n) and β (size n) be two discrete, bounded, uniform measures and let C be the corresponding cost matrix (size n × n). The Wasserstein distance is defined as:
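In this notation, and writing Π for the transportation plan (a symbol assumed here for illustration), this is the standard Kantorovich linear program with uniform weights 1/n on both marginals:

```latex
W(\alpha, \beta) = \min_{\Pi \in E(\alpha, \beta)} \langle \Pi, C \rangle,
\qquad
E(\alpha, \beta) = \left\{ \Pi \in \mathbb{R}_{+}^{n \times n} \;:\;
\Pi \mathbf{1}_n = \tfrac{1}{n} \mathbf{1}_n,\;
\Pi^{\top} \mathbf{1}_n = \tfrac{1}{n} \mathbf{1}_n \right\}
```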

Eq.(1): Wasserstein distance

Where ⟨·,·⟩ is the Frobenius product and E(α, β) is the set of constraints. The Wasserstein distance has to be computed between the full measures α and β. Unfortunately, it has cubic complexity in the number of data points, O(n^3), making it unsuitable for Big Data applications. Variants of the OT problem have been proposed, such as entropic OT and the Sinkhorn divergence, but they still have quadratic complexity. To overcome this complexity, one can compute the Wasserstein distance between minibatches of the source and target measures.

Minibatch Wasserstein Distance

Using the minibatch strategy is appealing as it gives cubic complexity in the minibatch size, O(m^3). However, the primal formulation of optimal transport is not a sum over samples, so using minibatches is not equivalent to Eq.(1). Indeed, it does not compute the Wasserstein distance but the expectation of the Wasserstein distance over minibatches sampled from the input measures. Formally, it computes:
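In symbols, assuming A and B denote minibatches of m samples drawn independently from α and β (the product-measure notation is a reconstruction from the surrounding text):

```latex
\mathbb{E}_{(A, B) \sim \alpha^{\otimes m} \otimes \beta^{\otimes m}} \left[ W(A, B) \right]
```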

Eq. (2) : expectation of Wasserstein distance over batches

Where m is the batch size. As it is not equivalent to the original problem, it is interesting to understand this new loss. We will review the consequences on the transportation plan, the asymptotic statistical properties and, finally, the gradient properties for first-order optimization methods.

Estimated minibatch Wasserstein

Let us first design an estimator. Eq. (2) can be estimated with the following estimator:
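Writing P_m(α) for the set of all minibatches of m distinct samples taken from the n samples of α (a notation assumed here), and U for the estimator, a complete two-sample U-statistic of this kind has the form:

```latex
U(\alpha, \beta) = \binom{n}{m}^{-2}
\sum_{A \in \mathcal{P}_m(\alpha)} \sum_{B \in \mathcal{P}_m(\beta)} W(A, B)
```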

Eq.(3): estimator of Eq. (2)

Where the sum is taken over all possible minibatch measures A and B in the source and target measures. However, there are too many minibatch terms to compute. Fortunately, we can rely on a subsampled quantity. We denote by D_k a set of cardinality k whose elements are minibatch couples drawn uniformly at random. We define:
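With Ũ denoting the subsampled estimator and (A, B) a minibatch couple in D_k (the subscript k on Ũ is a notational assumption), the average over the drawn couples reads:

```latex
\widetilde{U}_k(\alpha, \beta) = \frac{1}{k} \sum_{(A, B) \in D_k} W(A, B)
```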

Eq.(4) : incomplete estimator of Eq. (2)

Where k is the number of minibatch couples. We can also compute a similar estimator for the transportation plan in order to estimate the effects of the minibatch paradigm on the transportation plan, and more specifically, the connections between samples (full details of the construction in the paper). The idea is to average the connections between samples to get the averaged transportation plan.
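As a sketch of that averaging, let Π_{A,B} denote the optimal plan between minibatches A and B, padded with zeros into an n × n matrix so that plans from different minibatches can be summed (the symbols Π_m, Π_{A,B} and P_m are assumptions here; the exact construction is in the paper):

```latex
\Pi_m(\alpha, \beta) = \binom{n}{m}^{-2}
\sum_{A \in \mathcal{P}_m(\alpha)} \sum_{B \in \mathcal{P}_m(\beta)} \Pi_{A, B}
```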

Eq.(5) : minibatch Wasserstein plan estimator

Of course, an incomplete estimator exists and follows the same construction as Eq. (4):
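Assuming Π_{A,B} is the optimal plan between minibatches A and B, padded with zeros into an n × n matrix, and D_k is the set of k drawn minibatch couples, the subsampled plan would read:

```latex
\Pi_k(\alpha, \beta) = \frac{1}{k} \sum_{(A, B) \in D_k} \Pi_{A, B}
```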

Eq.(6) : minibatch Wasserstein plan incomplete estimator

Now that we can estimate our quantities, we will give a small but useful example.

1D case example: Minibatch Sliced Wasserstein

The 1D case is a particular case of interest. It is interesting because we have access to a closed form of the Wasserstein distance when the data lie in 1D, and we can then compute the OT plan easily. The 1D case is also the foundation of a widely used distance, the Sliced Wasserstein distance. The formula to compute the full minibatch OT plan (Eq. (5)) can be found in the paper.
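As a minimal sketch of the 1D closed form combined with the incomplete estimator of Eq. (4), the snippet below sorts both samples, matches them in order, and averages the cost over k random minibatch couples. The squared-distance ground cost, the function names and the Gaussian toy data are assumptions made for illustration; this is not the paper's code.

```python
import random


def wasserstein_1d(xs, ys):
    """Exact OT cost between two uniform 1D samples of equal size.

    Assumes a squared-distance ground cost; in 1D the optimal plan
    simply matches the sorted points in order.
    """
    assert len(xs) == len(ys)
    return sum((a - b) ** 2 for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


def minibatch_wasserstein_1d(xs, ys, m, k, rng):
    """Average the 1D OT cost over k minibatch couples of size m."""
    total = 0.0
    for _ in range(k):
        batch_x = rng.sample(xs, m)  # minibatch drawn without replacement
        batch_y = rng.sample(ys, m)
        total += wasserstein_1d(batch_x, batch_y)
    return total / k


rng = random.Random(0)
source = [rng.gauss(0.0, 1.0) for _ in range(200)]
target = [rng.gauss(3.0, 1.0) for _ in range(200)]

full = wasserstein_1d(source, target)
estimate = minibatch_wasserstein_1d(source, target, m=20, k=100, rng=rng)
print("full distance:", full)
print("minibatch estimate:", estimate)
```

When m equals n, every minibatch is the full sample, so the estimator reduces to the full distance.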

We consider 20 data points in the source and target domains with uniform weights and plot the averaged transportation plan for several OT scenarios. The experiment shows the difference between the minibatch OT plan for different batch sizes m and the OT plans of regularized variants (entropic + L2).

Transportation plan between different OT problems for 1D measures [5]

We see a similar effect between the minibatch Wasserstein distance and the regularized Wasserstein variants: both produce non-optimal connections between samples. For the minibatch Wasserstein distance, the number of connections increases when m decreases, much as it does for the entropic OT variant when the regularization coefficient gets bigger. One can also note that the highest intensity of connections decreases when the batch size decreases, which is due to the constraints. Now that we have seen the effect of minibatches on the transportation plan, let us review the properties of our loss.
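The growth in the number of connections can be reproduced with a small sketch of the incomplete plan estimator in 1D. It assumes both measures have the same number of points, uses the sorted matching as the minibatch OT plan, and the names `minibatch_plan_1d` and `count_connections` are hypothetical helpers, not taken from the paper.

```python
import random


def minibatch_plan_1d(xs, ys, m, k, rng):
    """Average k minibatch OT plans into one n x n matrix (1D, sorted matching)."""
    n = len(xs)
    plan = [[0.0] * n for _ in range(n)]
    for _ in range(k):
        # Draw minibatch indices, then order them by the value of their point.
        idx_x = sorted(rng.sample(range(n), m), key=lambda i: xs[i])
        idx_y = sorted(rng.sample(range(n), m), key=lambda j: ys[j])
        # Each matched pair carries mass 1/m in its minibatch plan,
        # and each minibatch plan has weight 1/k in the average.
        for i, j in zip(idx_x, idx_y):
            plan[i][j] += 1.0 / (m * k)
    return plan


def count_connections(plan):
    """Number of strictly positive entries in the plan."""
    return sum(1 for row in plan for value in row if value > 0.0)


rng = random.Random(0)
n = 20
xs = [rng.random() for _ in range(n)]
ys = [rng.random() for _ in range(n)]

for m in (20, 10, 5):
    plan = minibatch_plan_1d(xs, ys, m, k=500, rng=rng)
    peak = max(max(row) for row in plan)
    print("m =", m, "connections:", count_connections(plan), "peak mass:", peak)
```

With m equal to n the averaged plan is the exact OT plan with n connections; smaller batches spread the same total mass over more entries.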

Basic properties

As we have a new loss function, it is necessary to review its strengths and weaknesses for comparing probability distributions. It has the following properties:

  • For iid data, U and Ũ are unbiased estimators of Eq.(2)
  • U and Ũ are symmetric in their arguments
  • U and Ũ are strictly positive
  • U(α,α) and Ũ(α,α) are strictly positive

The interesting property here is the last one. For non-trivial measures, we break the separability axiom of a distance. Hence, the minibatch Wasserstein distance IS NOT a distance. It is the price to pay to gain numerical speed. We will highlight this effect in a gradient flow experiment.
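The loss of separability is easy to observe numerically. The sketch below compares a 1D sample with itself: the full Wasserstein distance is zero, while the minibatch estimate stays strictly positive because two independent minibatches of the same measure rarely coincide. The squared-distance cost and the toy data are assumptions for illustration.

```python
import random


def wasserstein_1d(xs, ys):
    """Exact 1D OT cost for uniform samples of equal size (squared-distance cost)."""
    assert len(xs) == len(ys)
    return sum((a - b) ** 2 for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


rng = random.Random(0)
alpha = [rng.gauss(0.0, 1.0) for _ in range(100)]

# Full distance between the measure and itself: exactly zero.
full_self_distance = wasserstein_1d(alpha, alpha)

# Incomplete minibatch estimate between the measure and itself.
m, k = 10, 200
minibatch_self_distance = sum(
    wasserstein_1d(rng.sample(alpha, m), rng.sample(alpha, m)) for _ in range(k)
) / k

print("W(alpha, alpha):", full_self_distance)
print("minibatch estimate:", minibatch_self_distance)
```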

As we do not know the distributions α and β, we want to know whether Eq.(2) can be efficiently estimated with Eq.(4).

Statistical properties

Our incomplete estimator defines an incomplete two-sample U-statistic. U-statistics were developed by Hoeffding, and his 1963 inequality [6] makes it possible to get a deviation bound of our estimator around its expectation, with probability 1-δ:

Where M is the size of the support of α and β. This deviation bound shows that if we increase the number of data n and batches k while keeping the minibatch size fixed, the error converges to 0 exponentially fast. Remarkably, the bound does not depend on the dimension of data, which is an appealing property when optimizing in high dimension.

For generative modelling we found that a small k was enough to get meaningful results but that using a small batch size resulted in a longer training time.

A similar property holds for the transportation plan and the marginals. The deviation between the marginals of the estimated transportation plan and 1/n is, with probability 1-δ:

Distance to marginals

As we know that our loss has good statistical properties, we now want to study whether we can minimize it with modern optimization frameworks.

Unbiased gradients for SGD

It is well known that the empirical Wasserstein distance has biased gradients with respect to the Wasserstein distance between continuous measures [7]. Because of this bias, minimizing the empirical Wasserstein distance does not lead to the minimum of the Wasserstein distance between continuous measures.

Unlike the Wasserstein distance, the minibatch Wasserstein distance has the nice property of having unbiased gradients, hence we can use SGD with our incomplete estimator to minimize the loss between continuous measures.

This result was proved for entropic OT variants. Entropic OT is differentiable everywhere, unlike the Wasserstein distance, and this is a fundamental element of our proof: it allows us to use an unbiased estimator and the differentiation lemma to prove that the gradients are unbiased. However, we did not encounter any problems while using the minibatch Wasserstein distance in practice.

Generative Models

We illustrate the use of the minibatch Wasserstein loss for generative modelling. The goal is to learn a generative model that generates data close to the target data. We draw 8000 points which follow 8 different Gaussian modes (1000 points per mode) in 2D, where the modes form a circle. After generating the data, we use the minibatch Wasserstein distance and the minibatch Sinkhorn divergence as loss functions with a squared Euclidean cost and compare them to WGAN [1] and its variant with gradient penalty, WGAN-GP [8].

Gaussian modes generation [5]

We see that we are able to generate data following the different modes. More extensive results using the minibatch Sinkhorn divergence on the CIFAR-10 dataset are available in [3].

Gradient Flow

In this section, we describe the use of a minibatch Wasserstein gradient flow. We consider 5000 male and 5000 female images from the Celeb-A dataset and want to apply a gradient flow between male and female images. The purpose of gradient flows is to model a distribution that, at each iteration, follows the gradient direction minimizing the minibatch Wasserstein distance. Formally, we integrate the following equation:

For this experiment, we set the batch size m to 500 and the number of batch couples k to 10.

Gradient flow experiments between 5000 male images and 5000 female images [5]

We see a natural evolution in the images along the gradient flow. However, the final result is a bit blurred. This is because the minibatch Wasserstein distance is not a distance, so we do not exactly match the target distribution.
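The mechanics of such a flow can be sketched on 1D particles instead of images. The snippet assumes the flow has the form ẋ(t) = −∇ₓ Ũ(x(t), β), integrated with an explicit Euler scheme, with a squared-distance cost and the sorted matching as the minibatch plan. The function names, step size and toy data are illustrative assumptions, not the settings of the experiment above.

```python
import random


def minibatch_gradient_1d(particles, target, m, k, rng):
    """Gradient of the incomplete minibatch loss w.r.t. each particle (1D, squared cost)."""
    n = len(particles)
    grad = [0.0] * n
    for _ in range(k):
        # Sorted matching between a minibatch of particles and a minibatch of targets.
        idx = sorted(rng.sample(range(n), m), key=lambda i: particles[i])
        batch = sorted(rng.sample(target, m))
        for i, y in zip(idx, batch):
            # d/dx of (1/m) * (x - y)^2, averaged over the k couples.
            grad[i] += 2.0 * (particles[i] - y) / (m * k)
    return grad


def gradient_flow_1d(particles, target, m, k, step, n_steps, rng):
    """Explicit Euler integration of the minibatch gradient flow."""
    particles = list(particles)
    for _ in range(n_steps):
        grad = minibatch_gradient_1d(particles, target, m, k, rng)
        particles = [x - step * g for x, g in zip(particles, grad)]
    return particles


def mean(values):
    return sum(values) / len(values)


rng = random.Random(0)
source = [rng.gauss(0.0, 1.0) for _ in range(100)]
target = [rng.gauss(4.0, 1.0) for _ in range(100)]

# The step is scaled up because each particle only appears in a fraction m/n of the batches.
flowed = gradient_flow_1d(source, target, m=10, k=10, step=5.0, n_steps=200, rng=rng)
print("source mean:", mean(source))
print("flowed mean:", mean(flowed))
print("target mean:", mean(target))
```

Because each update pulls a particle toward several different target points, the flowed particles tend to contract toward averaged positions, which is a 1D analogue of the blur seen in the images.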

Large scale color transfer

We also use our method for color transfer experiments. The purpose of color transfer is to transform the colors of a source image so that they follow the colors of a target image. Optimal transport is a well-known method to solve this problem [9]. The transportation plan between the two images, seen as point clouds, gives a color transfer mapping through a barycentric projection. The idea is to use the minibatch transportation plan estimator developed above, which enjoys both smaller memory and computation costs. We used two images of 1M pixels each, with several batch sizes and numbers of batches. To the best of our knowledge, it is the first time that a method is able to handle large-scale color transfer.

Large scale color transfer [5]

We can see that the diversity of colors falls when the batch size is too small, as it would with an entropic solver for a large regularization parameter. This is due to the high number of connections between source and target samples. However, even for 1M pixels, a batch size of 1000 is enough to keep a good diversity of colors.
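The barycentric projection with a minibatch plan can be sketched without ever storing the n × n plan. The example below is a simplification: it works on 1D intensities (grayscale values rather than RGB colors) so that the sorted matching gives the exact minibatch plan, and `minibatch_barycentric_map_1d` is a hypothetical helper name.

```python
import random


def minibatch_barycentric_map_1d(xs, ys, m, k, rng):
    """Map each source value to the plan-weighted average of its matched targets.

    Every matched pair has the same weight 1/(m*k) in the averaged plan,
    so the barycentric projection is a plain average of the matched targets.
    """
    n = len(xs)
    matched_sum = [0.0] * n
    matched_count = [0] * n
    for _ in range(k):
        idx = sorted(rng.sample(range(n), m), key=lambda i: xs[i])
        batch = sorted(rng.sample(ys, m))
        for i, y in zip(idx, batch):
            matched_sum[i] += y
            matched_count[i] += 1
    # A source value that was never sampled keeps its original intensity.
    return [
        matched_sum[i] / matched_count[i] if matched_count[i] > 0 else xs[i]
        for i in range(n)
    ]


rng = random.Random(0)
source_pixels = [rng.uniform(0.0, 0.5) for _ in range(1000)]  # dark image
target_pixels = [rng.uniform(0.5, 1.0) for _ in range(1000)]  # bright image

mapped = minibatch_barycentric_map_1d(source_pixels, target_pixels, m=100, k=200, rng=rng)
print("source range:", min(source_pixels), max(source_pixels))
print("mapped range:", min(mapped), max(mapped))
```

Memory stays linear in the number of pixels because only two accumulators per source pixel are kept. The mapped range is narrower than the target range when m is small, which mirrors the loss of color diversity described above.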

Conclusion

In this article, we described the minibatch Wasserstein distance, a variant in which one computes the primal Wasserstein distance over minibatches. We presented a formalism, the basic properties of this loss function, its asymptotic properties and, finally, the optimization procedure. We then studied its use in three different experiments. Many questions remain about how to improve the minibatch Wasserstein distance, and we will detail them in a future blog post.

Bibliography

[1] Martin Arjovsky, Soumith Chintala, Léon Bottou. Wasserstein GAN, ICML 2017

[2] B. B. Damodaran, B. Kellenberger, R. Flamary, D. Tuia, N. Courty. DeepJDOT: Deep Joint Distribution Optimal Transport for Unsupervised Domain Adaptation, ECCV 2018

[3] Aude Genevay, Gabriel Peyré, Marco Cuturi. Learning Generative Models with Sinkhorn Divergences, AISTATS 2018

[4] Gabriel Peyré, Marco Cuturi. Computational Optimal Transport, Foundations and Trends in Machine Learning

[5] Kilian Fatras, Younes Zine, Rémi Flamary, Rémi Gribonval, Nicolas Courty. Learning with minibatch Wasserstein: asymptotic and gradient properties, AISTATS 2020

[6] Wassily Hoeffding. Probability Inequalities for Sums of Bounded Random Variables, Journal of the American Statistical Association, 1963

[7] Marc G. Bellemare, Ivo Danihelka, Will Dabney, Shakir Mohamed, Balaji Lakshminarayanan, Stephan Hoyer, Rémi Munos. The Cramer Distance as a Solution to Biased Wasserstein Gradients

[8] Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, Aaron Courville. Improved Training of Wasserstein GANs, NIPS 2017

[9] Sira Ferradans, Nicolas Papadakis, Julien Rabin, Gabriel Peyré, Jean-François Aujol. Regularized Discrete Optimal Transport, Scale Space and Variational Methods in Computer Vision
