Skip to content

Instantly share code, notes, and snippets.

@jcowles
Last active May 13, 2021 19:21
Show Gist options
  • Save jcowles/1de8b12c38603ce932b0154bc6d59d60 to your computer and use it in GitHub Desktop.
Save jcowles/1de8b12c38603ce932b0154bc6d59d60 to your computer and use it in GitHub Desktop.
Convolutional Wasserstein Distance & Barycenter implementation in PyTorch
# CC0, 2021 Jeremy Cowles, no rights reserved.
#
# The following is a 1-dimensional implementation of two core algorithms from the paper
#
# Convolutional Wasserstein Distances: Efficient Optimal Transportation on Geometric Domains
# https://people.csail.mit.edu/jsolomon/assets/convolutional_w2.compressed.pdf
#
import torch
from torch import nn
import torch.nn.functional as F
import math
# The following sets up a 1D gaussian, but can be extended to multiple dimensions
# by modifying _create_window().
def _gauss(window_size, sigma) -> torch.Tensor:
    """
    Create weights for a discrete gaussian kernel.

    window_size: Number of taps in the kernel (must be >= 1).
    sigma: Standard deviation of the gaussian, measured in taps (must be non-zero).
    Returns: A 1D float32 tensor of length window_size, centered on
             (window_size - 1) / 2 and normalized so the weights sum to 1.
    Raises: ValueError if window_size < 1 or sigma == 0.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if sigma == 0:
        raise ValueError("sigma must be non-zero")

    center = 0.5 * (window_size - 1)
    offsets = torch.arange(window_size, dtype=torch.float64) - center
    exponents = -(offsets ** 2) / (2.0 * float(sigma) ** 2)
    # softmax(e) == exp(e) / sum(exp(e)), but it subtracts max(e) first. Exponentiating
    # and normalizing separately underflows to 0 / 0 = NaN when every tap is far from the
    # center relative to sigma (e.g. an even window with a very small sigma).
    return torch.softmax(exponents, dim=0).to(torch.float)
def _create_window(window_size, sigma=None) -> "tuple[torch.Tensor, float]":
    """
    Create 1D weights for a gaussian convolution.

    window_size: Number of taps in the kernel.
    sigma: Standard deviation in taps. If it is not provided (None, or any falsy value
           such as 0), it is computed from the window size.
    Returns: A (window, sigma) pair. window is the normalized kernel as a column tensor
             of shape (window_size, 1); sigma is the value that was actually used.

    Note: with sigma omitted and window_size == 1 the derived sigma is 0, which _gauss
    cannot use.
    """
    if not sigma:
        # Compute sigma in terms of pixels.
        # Auto-sigma is great for image-based convolutions, but makes less sense in this context,
        # since sigma directly controls the error of the wasserstein approximations.
        #
        # Expect this default sigma value to be blurry.
        sigmas_per_pixel = 2.5
        sigma = 0.5 * (window_size - 1) / sigmas_per_pixel

    window = _gauss(window_size, sigma).unsqueeze(1)
    # 2D window could be constructed here.
    return window, sigma
def wasserstein_distance(mu0, mu1, window_size=5, sigma=0.2, iterations=40):
    """
    Approximate the convolutional wasserstein distance between two 1D distributions.

    mu0, mu1: The source and target distributions. Floating-point tensors of identical
              shape with a single channel, laid out the way F.conv1d expects
              (e.g. (1, 1, n)). Entries should be strictly positive: the result takes
              log() of quantities proportional to them, so zeros produce inf/NaN.
    window_size: Convolutional window size; must be odd so that padding preserves length.
    sigma: Gaussian kernel width. Smaller sigma = more accurate distance, but requires
           more iterations.
    iterations: Number of alternating scaling updates to run.
    Returns: Convolutional wasserstein distance between the two distributions, as a
             0-dim tensor (summed over every element of the inputs).
    Raises: ValueError if window_size is even.
    """
    if window_size % 2 == 0:
        raise ValueError(f"window_size must be odd, got {window_size}")

    kernel, sigma = _create_window(window_size, sigma=sigma)
    # conv1d wants weights shaped (out_channels, in_channels, width). The kernel is built
    # as float32 on the CPU, so move it to the inputs' dtype/device or conv1d rejects
    # float64 / CUDA inputs.
    kernel = kernel.reshape(1, 1, -1).to(device=mu0.device, dtype=mu0.dtype)

    def blur(x):
        # Keep the full conv1d output (no [0]) so every batch entry is blurred on its own.
        return F.conv1d(x, kernel, padding=window_size // 2)

    gamma = sigma * sigma
    a = 1.0 / mu0.shape[-1]

    v = torch.ones_like(mu0)
    w = torch.ones_like(mu0)

    for _ in range(iterations):
        v = mu0 / blur(a * w)
        w = mu1 / blur(a * v)

    return gamma * (a * (mu0 * v.log() + mu1 * w.log())).sum()
def wasserstein_barycenter(mu_s, weights, window_size=5, sigma=0.1, iterations=5):
    """
    Interpolate between 1D distributions with a convolutional wasserstein barycenter.

    mu_s: The distribution endpoints, a floating-point tensor whose first axis indexes
          the distributions and whose last axis indexes the bins (the demo in this file
          uses shape (m, 1, n)).
    weights: The weights for each distribution, one per entry of mu_s. They are used
             as given; no normalization is applied here.
    window_size: Convolutional window size; must be odd so that padding preserves length.
    sigma: Gaussian kernel width. Smaller sigma = sharper interpolation, but requires
           more iterations.
    iterations: Number of outer projection rounds to run.
    Returns: a new distribution of shape (1, 1, n), an interpolation of the endpoints
             according to given weights.
    Raises: ValueError if mu_s and weights differ in length or window_size is even.
    """
    if len(mu_s) != len(weights):
        raise ValueError(
            f"expected one weight per distribution, got {len(mu_s)} distributions "
            f"and {len(weights)} weights"
        )
    if window_size % 2 == 0:
        raise ValueError(f"window_size must be odd, got {window_size}")

    kernel, _ = _create_window(window_size, sigma=sigma)
    # conv1d wants weights shaped (out_channels, in_channels, width). The kernel is built
    # as float32 on the CPU, so move it to the inputs' dtype/device.
    kernel = kernel.reshape(1, 1, -1).to(device=mu_s.device, dtype=mu_s.dtype)

    def blur(x):
        # x is a single distribution; add a leading axis for conv1d and drop it again.
        return F.conv1d(x.unsqueeze(0), kernel, padding=window_size // 2)[0]

    bins = mu_s.shape[-1]
    a = 1.0 / bins

    v = torch.ones_like(mu_s)
    w = torch.ones_like(mu_s)
    d = torch.ones_like(mu_s)

    # Returned unchanged when iterations == 0.
    mu = mu_s.new_ones(1, 1, bins)

    for _ in range(iterations):
        mu = mu_s.new_ones(1, 1, bins)

        # Constraint 1: pi_i marginalizes to mu_i in one direction
        for i, weight in enumerate(weights):
            w[i] = mu_s[i] / blur(a * v[i])
            d[i] = v[i] * blur(a * w[i])
            mu = mu * torch.pow(d[i], weight)

        # NOTE: Entropic sharpening not implemented.

        # Constraint 2: all pi_s marginalize to the same mu in the other direction
        for i in range(len(weights)):
            v[i] = v[i] * mu / d[i]

    return mu
def plot(wd, lin, init, delay):
    """
    Draw one animation frame comparing two interpolated distributions as bar charts.

    wd: The wasserstein-interpolated distribution; any array-like with reshape()
        (a torch tensor in the demo). Drawn at odd x positions.
    lin: The linearly-interpolated distribution, same kind of object. Drawn at even
         x positions so the two sets of bars interleave.
    init: True on the first frame; switches matplotlib to interactive mode and shows
          the window.
    delay: Seconds to pause after drawing the frame.
    """
    import matplotlib.pyplot as plt

    if init:
        plt.ion()
        plt.show()

    plt.clf()
    plt.title("Wasserstein vs. Linear Interpolation")

    colors = {'Wasserstein': 'royalblue', 'Linear': 'darkorange'}
    labels = list(colors)
    handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in labels]
    plt.legend(handles, labels)
    plt.ylim(0, 1)

    # Flatten to 1D and derive the x positions from the actual bin count instead of
    # assuming 8 bins; for 8 bins this yields range(1, 17, 2) and range(2, 17, 2).
    wd_heights = wd.reshape(-1)
    lin_heights = lin.reshape(-1)
    plt.bar(range(1, 2 * len(wd_heights) + 1, 2), height=wd_heights, color=colors["Wasserstein"])
    plt.bar(range(2, 2 * len(lin_heights) + 1, 2), height=lin_heights, color=colors["Linear"])

    plt.pause(delay)
# Demo: the two example distributions, defined once and shared by both computations.
endpoints = torch.tensor([
    [[.50, .25, .01, .01, .01, .01, .20, .01]],
    [[.20, .01, .01, .01, .01, .74, .01, .01]],
])

# Compute wasserstein distance between two distributions.
# Slicing (rather than indexing) keeps the (1, 1, n) layout wasserstein_distance is given.
wd = wasserstein_distance(endpoints[0:1], endpoints[1:2])
print("Wasserstein Distance:", wd)

# Wasserstein interpolation: sweep the barycentric weights from (1, 0) to (0, 1) and
# draw each barycenter next to the plain linear blend. Replays forever.
# NOTE(review): this runs at import time; consider an `if __name__ == "__main__":` guard
# if the functions above are meant to be imported from elsewhere.
steps = 50
while True:
    init = True
    for i in range(steps):
        v = i / (steps - 1)
        weights = torch.tensor([1 - v, v])
        mu = wasserstein_barycenter(endpoints, weights)
        linear = endpoints[0] * weights[0] + endpoints[1] * weights[1]
        plot(mu, linear, init, (1.0 / steps) * 2)
        init = False
        # NOTE(review): the original indentation was lost, so the nesting of this print
        # is a guess. After the `while True:` loop it would be unreachable, so it is kept
        # per frame here - confirm against the author's intent.
        print(mu)
@jcowles
Copy link
Author

jcowles commented Feb 8, 2021

interp-fixed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment