Hands-on guide to Python Optimal Transport toolbox: Part 1
First steps with Optimal Transport
As a follow-up to the introductory article on optimal transport by Ievgen Redko, I will show below how to solve Optimal Transport (OT) problems in practice using the Python Optimal Transport (POT) toolbox.
To start with, let us install POT using pip from the terminal by simply running
pip3 install pot
Or with conda
conda install -c conda-forge pot
If everything went well, you now have POT installed and ready to use on your computer.
POT Python Optimal Transport Toolbox
Import the toolbox
import numpy as np # always need it
import scipy as sp # often use it
import pylab as pl # do the plots
import ot # ot
Getting help
The online documentation of POT is available at http://pot.readthedocs.io, or you can check the inline help with help(ot.dist).
We are now ready to start our example.
Simple OT Problem
We will solve the Bakery/Cafés problem of transporting croissants from a number of bakeries to cafés in a city (in this case Manhattan). We did a quick Google Maps search in Manhattan for bakeries and cafés.
We extracted their positions from this search and generated fictional production and sales figures (both of which sum to the same total).
We have access to the positions of the bakeries, bakery_pos, and their respective production, bakery_prod, which together describe the source distribution. The cafés where the croissants are sold are likewise defined by their positions, cafe_pos, and their sales, cafe_prod, which describe the target distribution.
Now we load the data
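import io
import urllib.request

# np.load cannot open a URL directly, so download the archive into memory first
url = 'https://github.com/PythonOT/POT/raw/master/data/manhattan.npz'
with urllib.request.urlopen(url) as response:
    data = np.load(io.BytesIO(response.read()))

bakery_pos = data['bakery_pos']
bakery_prod = data['bakery_prod']
cafe_pos = data['cafe_pos']
cafe_prod = data['cafe_prod']
Imap = data['Imap']

print('Bakery production: {}'.format(bakery_prod))
print('Cafe sale: {}'.format(cafe_prod))
print('Total croissants : {}'.format(cafe_prod.sum()))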
This gives:
Bakery production: [31. 48. 82. 30. 40. 48. 89. 73.]
Cafe sale: [82. 88. 92. 88. 91.]
Total croissants : 441.0
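Since the exact OT problem solved later assumes that source and target carry the same total mass, it can be worth checking this explicitly before going further. The short sketch below does so using the values printed above, copied by hand rather than loaded from the file.

```python
import numpy as np

# Values copied from the printed output above
bakery_prod = np.array([31., 48., 82., 30., 40., 48., 89., 73.])
cafe_prod = np.array([82., 88., 92., 88., 91.])

# Balanced OT needs the same total mass on both sides
total_source = bakery_prod.sum()
total_target = cafe_prod.sum()
same_mass = bool(np.isclose(total_source, total_target))
print(total_source, total_target, same_mass)
```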
Plotting bakeries in the city
Next, we plot the positions of the bakeries and cafés on the map. The size of each circle is proportional to the corresponding production or sales.
Cost matrix
We can now compute a cost matrix between the bakeries and the cafés, which will be the transport cost matrix. This can be done using the ot.dist function, which defaults to the squared Euclidean distance but can also compute other metrics, such as cityblock (the Manhattan distance), through its metric argument.
M = ot.dist(bakery_pos, cafe_pos)
The red cells in the matrix image show the bakeries and cafés that are further away, and thus more costly to transport from one to the other, while the blue ones show those that are very close to each other, with respect to the squared Euclidean distance.
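To make the content of the cost matrix concrete, here is a minimal NumPy sketch that computes the squared Euclidean and cityblock costs by hand. The positions src and dst are made-up toy points, not the Manhattan data. On the same inputs, ot.dist(src, dst) and ot.dist(src, dst, metric='cityblock') are expected to give the same matrices.

```python
import numpy as np

# Hypothetical 2-D positions (toy values, not the Manhattan data)
src = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 1.0]])  # 3 "bakeries"
dst = np.array([[0.0, 1.0], [2.0, 2.0]])              # 2 "cafés"

# Pairwise coordinate differences, shape (3, 2, 2)
diff = src[:, None, :] - dst[None, :, :]

# One row per source point, one column per target point
M_sqeuclidean = (diff ** 2).sum(axis=-1)
M_cityblock = np.abs(diff).sum(axis=-1)

print(M_sqeuclidean)
print(M_cityblock)
```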
Solving the OT problem with Earth Mover’s distance
We now come to the problem itself, which is to find an optimal solution to the problem of transporting croissants from bakeries to cafés. In order to do that, let’s see a little bit of maths.
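The aim is to find the transport matrix gamma such that

$$\gamma^* = \arg\min_{\gamma} \sum_{i,j} \gamma_{i,j} M_{i,j} \quad \text{s.t.} \quad \gamma \mathbf{1} = a, \quad \gamma^T \mathbf{1} = b, \quad \gamma \geq 0$$

where M is the cost matrix, and a and b are respectively the sample weights for source and target.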
In other words, we take into account the cost of transporting croissants from a bakery to a café through M, and we want the sum of each row of gamma to be the number of croissants the corresponding bakery has to sell, and the sum of each column to be the number of croissants the corresponding café needs. Hence, each element of the transport matrix corresponds to the number of croissants that a bakery has to send to a café.
This problem is called Earth Mover’s distance, or EMD, also known as discrete Wasserstein distance.
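A tiny hand-made instance can help make the constraints concrete. The sketch below uses two hypothetical bakeries and two hypothetical cafés (the values of a, b and M are invented for illustration) and simply enumerates the integer transport plans that satisfy both marginal constraints, keeping the cheapest one. This brute-force search is only meant as an illustration of the problem, not as how a solver such as ot.emd works internally.

```python
import numpy as np

a = np.array([3, 2])      # hypothetical production of 2 bakeries
b = np.array([1, 4])      # hypothetical demand of 2 cafés
M = np.array([[1, 2],
              [3, 1]])    # hypothetical transport costs

best_cost, best_plan = None, None
# In the 2x2 case the marginals fix the whole plan once gamma[0, 0] = t is chosen
for t in range(0, min(a[0], b[0]) + 1):
    gamma = np.array([[t, a[0] - t],
                      [b[0] - t, a[1] - (b[0] - t)]])
    if (gamma < 0).any():
        continue  # not a valid plan
    # Rows must sum to the production, columns to the demand
    assert (gamma.sum(axis=1) == a).all()
    assert (gamma.sum(axis=0) == b).all()
    cost = int((gamma * M).sum())
    if best_cost is None or cost < best_cost:
        best_cost, best_plan = cost, gamma

print(best_plan)
print(best_cost)
```

Here the cost is linear in t, so the optimum sits at one end of the feasible range, which is consistent with the sparse plans that exact OT tends to produce.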
Let’s see what it gives on our example.
gamma_emd = ot.emd(bakery_prod, cafe_prod, M)
The graph below (left) shows the transport from each bakery to each café, with the width of the line proportional to the number of croissants to be transported. On the right, we can see the transport matrix with the exact values. We can see that each bakery only needs to transport croissants to one or two cafés, the transport matrix being very sparse.
Regularized OT with Sinkhorn
One issue with EMD is that its algorithmic complexity is in O(n³log(n)), where n is the larger of the number of source points and the number of target points. In our example, n is small, so it is OK to use EMD, but for larger values of n we might want to look into other options.
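As is often the case when a problem is slow to solve, we can regularize it in order to obtain a problem that is simpler or faster to solve. The Sinkhorn algorithm does that by adding an entropic regularization term and thus solves the following problem:

$$\gamma^* = \arg\min_{\gamma} \sum_{i,j} \gamma_{i,j} M_{i,j} + reg \cdot \Omega(\gamma) \quad \text{s.t.} \quad \gamma \mathbf{1} = a, \quad \gamma^T \mathbf{1} = b, \quad \gamma \geq 0$$

where reg is a hyperparameter and Omega is the entropic regularization term defined by:

$$\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j})$$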
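The Sinkhorn algorithm is very simple to code. Starting from the kernel K = exp(-M / reg) and a vector u of ones, it alternates the updates v = b / (Kᵀu) and u = a / (Kv), where the divisions are element-wise, and the transport matrix is then diag(u) K diag(v). You can implement it directly as follows.
Be careful of numerical problems. A good pre-processing step for Sinkhorn is to divide the cost matrix M by its maximum value.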
reg = 0.1
K = np.exp(-M / M.max() / reg)  # kernel built from the normalized cost matrix
nit = 100
u = np.ones((len(bakery_prod), ))
for i in range(nit):
    v = cafe_prod / np.dot(K.T, u)    # rescale to match the café marginals
    u = bakery_prod / (np.dot(K, v))  # rescale to match the bakery marginals
gamma_sink_algo = np.atleast_2d(u).T * (K * v.T) # Equivalent to np.dot(np.diag(u), np.dot(K, np.diag(v)))
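For readers wondering where the two updates in the loop come from, here is a short derivation sketch. It assumes the entropic term is the sum of gamma_{i,j} log(gamma_{i,j}) over all entries, writes M for the (normalized) cost matrix, and introduces dual variables f and g for the two marginal constraints. The names f and g are chosen here for illustration only. The Lagrangian of the regularized problem is

$$L(\gamma, f, g) = \sum_{i,j} \gamma_{i,j} M_{i,j} + reg \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) + \sum_i f_i \Big(a_i - \sum_j \gamma_{i,j}\Big) + \sum_j g_j \Big(b_j - \sum_i \gamma_{i,j}\Big)$$

Setting its derivative with respect to each entry of gamma to zero gives

$$M_{i,j} + reg \, (\log(\gamma_{i,j}) + 1) - f_i - g_j = 0 \quad \Longrightarrow \quad \gamma_{i,j} = \underbrace{e^{f_i / reg - 1/2}}_{u_i} \; \underbrace{e^{-M_{i,j} / reg}}_{K_{i,j}} \; \underbrace{e^{g_j / reg - 1/2}}_{v_j}$$

so the solution has the form diag(u) K diag(v). Plugging this form into the two marginal constraints yields

$$u \odot (K v) = a, \qquad v \odot (K^T u) = b$$

and solving each equation in turn for one vector while keeping the other fixed gives exactly the alternating updates used in the loop.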
An alternative is to use the POT toolbox with ot.sinkhorn:
gamma_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg, M=M/M.max())
When plotting the resulting transport matrix, we notice right away that it is not sparse at all with Sinkhorn, each bakery delivering croissants to all 5 cafés with that solution. Also, this solution gives a transport with fractions, which does not make sense in the case of croissants. This was not the case with EMD.
Varying the regularization parameter in Sinkhorn
Obviously, the regularization hyperparameter reg of Sinkhorn plays an important role. Let’s see with the following graphs how it impacts the transport matrix by looking at different values.
This series of graphs shows that the solution of Sinkhorn starts with something very similar to EMD (although not sparse) for very small values of the regularization parameter reg, and tends to a more uniform solution as reg increases.
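The effect of reg can also be observed numerically. The following is a minimal, self-contained sketch under stated assumptions: the production and sales figures are the ones printed earlier, but the positions are random points drawn for illustration (not the Manhattan coordinates), and sinkhorn_plan is a hypothetical helper that re-implements the plain iterations in NumPy rather than calling ot.sinkhorn. It reports the transport cost of the plan obtained for several values of reg.

```python
import numpy as np

def sinkhorn_plan(a, b, M, reg, n_iter=1000):
    """Plain Sinkhorn iterations (illustrative helper); returns diag(u) K diag(v)."""
    K = np.exp(-M / reg)
    u = np.ones(len(a))
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]

# Production and sales figures printed earlier in the article
a = np.array([31., 48., 82., 30., 40., 48., 89., 73.])
b = np.array([82., 88., 92., 88., 91.])

# Hypothetical random positions (NOT the Manhattan coordinates)
rng = np.random.default_rng(0)
src = rng.random((len(a), 2))
dst = rng.random((len(b), 2))
M = ((src[:, None, :] - dst[None, :, :]) ** 2).sum(axis=-1)
M = M / M.max()  # normalize the cost, as advised for Sinkhorn

results = {}
for reg in [0.05, 0.1, 0.5, 1.0]:
    gamma = sinkhorn_plan(a, b, M, reg)
    results[reg] = {"plan": gamma, "cost": float((gamma * M).sum())}
    print("reg = {:<5} transport cost = {:.2f}".format(reg, results[reg]["cost"]))
```

With small reg the plan concentrates its mass on cheap routes, while with large reg the mass is spread more evenly and the transport cost goes up. That is the trade-off the graphs illustrate.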
Conclusion
This first part showed a simple example of applying Optimal Transport with the POT library. Optimal Transport is a powerful tool that can be applied in many ways, as discussed in Part 2 by Ievgen Redko.