Skip to content

Instantly share code, notes, and snippets.

@hichamjanati
Created February 9, 2020 17:34
Show Gist options
  • Save hichamjanati/fbcc096e9c5763c8d64091f2d78fe8d3 to your computer and use it in GitHub Desktop.
Save hichamjanati/fbcc096e9c5763c8d64091f2d78fe8d3 to your computer and use it in GitHub Desktop.
truncated cost in Sinkhorn
import numpy as np
import warnings
from scipy import sparse
from scipy.stats import norm as gaussian
def make_gaussians(grid, n_hists, loc=None, scale=None, normed=True,
                   mass=None):
    """Evaluate Gaussian densities on a grid, one histogram per column.

    Parameters
    ----------
    grid : array-like, 1-D
        Points at which each Gaussian density is evaluated.
    n_hists : int
        Number of histograms (columns) to generate.
    loc : array-like or float, optional
        Mean of each Gaussian. A scalar is shared by all histograms.
        Defaults to 0 for every histogram.
    scale : array-like or float, optional
        Standard deviation of each Gaussian. A scalar is shared by all
        histograms. Defaults to 1 for every histogram.
    normed : bool
        If True, every column is divided by its sum.
    mass : array-like or float, optional
        Per-histogram multiplicative factor applied to the columns.
        Defaults to 1 for every histogram.

    Returns
    -------
    coefs : ndarray, shape (len(grid), n_hists)
        Column ``i`` holds the i-th Gaussian evaluated on ``grid``.

    Raises
    ------
    ValueError
        If ``loc``, ``scale`` or ``mass`` cannot be broadcast to
        ``(n_hists,)``.
    """
    grid = np.asarray(grid, dtype=float)
    loc = 0. if loc is None else loc
    scale = 1. if scale is None else scale
    mass = 1. if mass is None else mass

    # Broadcasting to (n_hists,) accepts scalars and rejects vectors of the
    # wrong length up front, so no column can be left uninitialised.
    shape = (n_hists,)
    loc = np.broadcast_to(np.asarray(loc, dtype=float), shape)
    scale = np.broadcast_to(np.asarray(scale, dtype=float), shape)
    mass = np.broadcast_to(np.asarray(mass, dtype=float), shape)

    # (len(grid), 1) against (n_hists,) evaluates every density in one call.
    coefs = gaussian.pdf(grid[:, None], loc=loc, scale=scale)
    if normed:
        coefs /= coefs.sum(axis=0)
    # NOTE(review): the original nesting of this line was lost; it is assumed
    # to apply whether or not `normed` is set -- confirm against the author's
    # intent.
    coefs *= mass
    return coefs
def sinkhorn(p, q, K, epsilon, maxiter=5000, tol=1e-8):
    """Run Sinkhorn scaling iterations and return the resulting objective.

    Alternately updates two scaling vectors ``a = p / K.dot(b)`` and
    ``b = q / K.T.dot(a)`` until the largest absolute violation of the
    first marginal, ``max |p - a * K.dot(b)|``, drops below ``tol``.

    Parameters
    ----------
    p, q : ndarray, 1-D
        Input histograms.
    K : ndarray or scipy sparse matrix
        Kernel matrix; only ``K.dot`` and ``K.T.dot`` are used.
    epsilon : float
        Multiplicative factor applied to the returned objective.
    maxiter : int
        Maximum number of iterations (must be at least 1).
    tol : float
        Convergence threshold on the marginal violation.

    Returns
    -------
    f : float
        ``epsilon * (sum(p * log(a)) + sum(q * log(b)))`` evaluated at the
        last numerically stable scalings.
    i : int
        Zero-based index of the last iteration that was run.

    Raises
    ------
    ValueError
        If ``maxiter`` is smaller than 1.

    Warns
    -----
    UserWarning
        When a non-finite value appears, or when ``maxiter`` is reached
        without convergence.
    """
    if maxiter < 1:
        raise ValueError("maxiter must be at least 1, got %s" % maxiter)

    a = np.ones_like(p, dtype=float)
    b = np.ones_like(p, dtype=float)
    Kb = K.dot(b)
    for i in range(maxiter):
        # Work on candidates so a failing iteration cannot clobber the last
        # stable scalings.
        a_new = p / Kb
        b_new = q / K.T.dot(a_new)
        Kb_new = K.dot(b_new)
        err = abs(p - a_new * Kb_new).max()
        if not (np.isfinite(err) and np.isfinite(Kb_new).all()):
            warnings.warn("Numerical Errors ! Stopped at last stable "
                          "iteration.")
            break
        a, b, Kb = a_new, b_new, Kb_new
        if err < tol:
            break
        if i == maxiter - 1:
            warnings.warn("*** Maxiter reached ! err = {} ***".format(err))

    # The tiny offset keeps log() finite where a scaling is exactly zero.
    f = (np.log(a + 1e-100) * p).sum()
    f += (np.log(b + 1e-100) * q).sum()
    f *= epsilon
    return f, i
if __name__ == "__main__":
    # Benchmark: run sinkhorn() with the full kernel, with a truncated kernel
    # stored as scipy.sparse CSR, and with the same truncated kernel stored
    # densely, for several truncation thresholds and grid sizes; then plot
    # the returned loss and the wall-clock time of each variant.
    from matplotlib import pyplot as plt
    from time import perf_counter

    epsilon = 1.
    n_dim = 5
    dimensions = np.linspace(10, 100, n_dim).astype(int) * 100
    means = np.array([-0.05, 0.05])
    sigma = 0.4
    std = sigma * np.ones(2)

    # cut-off for the squared ground metric, as a fraction of its maximum
    thresholds = np.array([0.1, 0.05, 0.002])
    sparsity = []
    n_thresh = len(thresholds)
    times, times_sparse, loss, loss_sparse = np.zeros((4, n_thresh, n_dim))
    times_np, loss_np = np.zeros((2, n_thresh, n_dim))

    for ii, threshold in enumerate(thresholds):
        print(">> Doing threshold %s / %s" % (ii + 1, n_thresh))
        for jj, n_features in enumerate(dimensions):
            print("Doing n_features %s / %s" % (jj + 1, n_dim))
            grid = np.linspace(-5., 5., n_features)
            M = (grid[:, None] - grid[None, :]) ** 2
            K = np.exp(- M / epsilon)
            # Entries whose cost exceeds the cut-off get an infinite cost,
            # i.e. a kernel value of exp(-inf) == 0.
            Ks_dense = np.where(M > threshold * M.max(), 0., K)
            Ks = sparse.csr_matrix(Ks_dense)
            x, y = make_gaussians(grid, 2, loc=means, scale=std).T

            start = perf_counter()
            f_full, n_iter = sinkhorn(x, y, K, epsilon)
            times[ii, jj] = perf_counter() - start
            loss[ii, jj] = f_full
            print("niter of full M", n_iter)

            start = perf_counter()
            f_sparse, n_iter = sinkhorn(x, y, Ks, epsilon)
            times_sparse[ii, jj] = perf_counter() - start
            loss_sparse[ii, jj] = f_sparse
            print("niter of truncated-sparse", n_iter)

            # The dense truncated kernel is built before the timer starts so
            # that only the solver is measured.
            start = perf_counter()
            f_dense, n_iter = sinkhorn(x, y, Ks_dense, epsilon)
            times_np[ii, jj] = perf_counter() - start
            loss_np[ii, jj] = f_dense
            print("niter of truncated-numpy", n_iter)

        # One entry per threshold (taken on the last, largest grid): the plot
        # below pairs `sparsity` with one column of axes per threshold.
        sparsity.append((Ks > 0).mean())

    titles = ["Loss", "Time"]
    labels = ["Full M", "Truncated M (scipy.sparse)",
              "Truncated M (numpy)"]
    fig, axes = plt.subplots(2, n_thresh, figsize=(18, 6), sharex=True)
    all_data = [[loss, loss_sparse, loss_np], [times, times_sparse, times_np]]
    for row, (ax_group, title, row_data) in enumerate(
            zip(axes, titles, all_data)):
        # (n_methods, n_thresh, n_dim) -> (n_thresh, n_methods, n_dim)
        per_threshold = np.swapaxes(np.stack(row_data), 0, 1)
        for ax, curves, sp in zip(ax_group.flatten(), per_threshold,
                                  sparsity):
            for curve, label in zip(curves, labels):
                ax.plot(dimensions, curve, lw=2, label=label)
            ax.set_xlabel("# bins (dimension)")
            ax.set_ylabel(title)
            if row == 0:
                ax.set_title("nnz = %.2f %%" % (sp * 100))
    plt.legend()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment