Skip to content

Instantly share code, notes, and snippets.

@ctralie
Created November 14, 2018 01:43
Show Gist options
  • Save ctralie/66352ae6ab06c009f02c705385a446f3 to your computer and use it in GitHub Desktop.
Save ctralie/66352ae6ab06c009f02c705385a446f3 to your computer and use it in GitHub Desktop.
2D Histogram Wasserstein Distance via POT Library
"""
Programmer: Chris Tralie
Purpose: To use the POT library (https://github.com/rflamary/POT)
to compute the entropically regularized Wasserstein distance
between histograms (images) defined on a 2D grid
"""
import numpy as np
import matplotlib.pyplot as plt
import ot
def _euclidean_cost_matrix(coords):
    """
    Return the dense matrix of Euclidean distances between all pairs of rows
    of ``coords`` (an (N, d) array), used as the optimal-transport ground cost.
    """
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for every pair at once
    sqr = np.sum(coords**2, 1)
    D = sqr[:, None] + sqr[None, :] - 2*coords.dot(coords.T)
    # Floating-point round-off can make tiny squared distances slightly negative
    D[D < 0] = 0
    return np.sqrt(D)


def _disc_histogram(X, Y, cx, cy, radius, eps=1e-5):
    """
    Return a histogram on the grid (X, Y) that sums to 1: cells strictly
    inside the disc of ``radius`` centred at (cx, cy) get weight 1 + eps,
    all other cells get eps, and the result is normalized.
    """
    # The eps floor keeps every cell strictly positive (presumably so that
    # Sinkhorn never sees empty bins)
    I = eps + np.array((X - cx)**2 + (Y - cy)**2 < radius**2, dtype=float)
    return I / np.sum(I)


def _overlay_rgb(I0, I):
    """
    Return a uint8 RGB image with ``I0`` in the red channel and ``I`` in the
    green channel, both scaled so that ``max(I0)`` maps to 255.
    """
    D = np.concatenate((I0[:, :, None], I[:, :, None], 0*I[:, :, None]), 2)
    D = D*255/np.max(I0)
    # At other positions the disc can cover fewer pixels, giving I a larger
    # peak than I0; NumPy's float->uint8 cast does not saturate, so clip first.
    return np.array(np.clip(D, 0, 255), dtype=np.uint8)


def testMovingDisc(res=80, n_frames=100, radius=0.2, reg=1.0):
    """
    Show optimal transport on a disc moving diagonally across a res x res grid.

    The grid samples [-1, 1] x [-1, 1] with ``res`` points per axis (80x80 by
    default).  The disc centre moves from (-0.8, -0.8) to (0.8, 0.8) over
    ``n_frames`` frames.  For every frame, the L2 distance and the entropically
    regularized Wasserstein distance (``ot.sinkhorn2``) to the first frame are
    computed.  A three-panel figure (overlay, L2 curve, Wasserstein curve) is
    saved as ``"<frame index>.png"`` in the current working directory.

    Parameters
    ----------
    res : int
        Grid samples per axis.  The cost matrix has res**4 entries
        (about 330 MB of float64 at res=80).
    n_frames : int
        Number of disc positions, and so of output images.  Must be >= 1.
    radius : float
        Disc radius in grid coordinates.
    reg : float
        Entropic regularization strength passed to ``ot.sinkhorn2``.

    Raises
    ------
    ValueError
        If ``n_frames`` is less than 1.
    """
    if n_frames < 1:
        raise ValueError("n_frames must be at least 1")

    ## Step 1: Setup problem
    pix = np.linspace(-1, 1, res)
    X, Y = np.meshgrid(pix, pix)
    # Pairwise distances between grid points define the transport cost
    coords = np.array([X.flatten(), Y.flatten()]).T
    M = _euclidean_cost_matrix(coords)
    ts = np.linspace(-0.8, 0.8, n_frames)

    ## Step 2: Compute L2 distances and Wasserstein to the first frame
    I0 = _disc_histogram(X, Y, ts[0], ts[0], radius)
    a = I0.flatten()
    Images = [I0]
    L2Dists = [0.0]
    WassDists = [0.0]
    for t in ts[1:]:
        I = _disc_histogram(X, Y, t, t, radius)
        Images.append(I)
        L2Dists.append(np.sqrt(np.sum((I - I0)**2)))
        wass = ot.sinkhorn2(a, I.flatten(), M, reg)
        # Some POT versions return a length-1 array rather than a scalar;
        # store a plain float so WassDists stays a flat 1-D sequence.
        wass = float(np.asarray(wass).item())
        print(wass)
        WassDists.append(wass)

    ## Step 3: Make Animation
    L2Dists = np.array(L2Dists)
    WassDists = np.array(WassDists)
    # The centre moves by (dt, dt), i.e. sqrt(2)*dt in Euclidean distance
    displacements = np.sqrt(2)*(ts - ts[0])
    extent = (pix[0], pix[-1], pix[-1], pix[0])
    plt.figure(figsize=(15, 5))
    for i, I in enumerate(Images):
        plt.clf()
        plt.subplot(131)
        plt.imshow(_overlay_rgb(I0, I), extent=extent)
        plt.subplot(132)
        plt.plot(displacements, L2Dists)
        plt.stem([displacements[i]], [L2Dists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("L2 Dist")
        plt.title("L2 Dist")
        plt.subplot(133)
        plt.plot(displacements, WassDists)
        plt.stem([displacements[i]], [WassDists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("Wasserstein Dist")
        plt.title("Wasserstein Dist")
        plt.savefig("%i.png" % i, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures
    plt.close()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    testMovingDisc()
@ctralie
Copy link
Author

ctralie commented Mar 23, 2020 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment