truncated cost in Sinkhorn
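This script benchmarks Sinkhorn iterations with a truncated ground cost: entries of the squared distance matrix above a cut-off are set to +inf, so the corresponding entries of the Gibbs kernel K = exp(-M / epsilon) are exactly zero and the kernel can be stored as a scipy.sparse matrix. It compares the full kernel with the truncated kernel, in sparse and dense form, in terms of loss and runtime. A minimal sketch of the truncation step (grid size, cut-off and epsilon here are illustrative, not the benchmark settings):

import numpy as np
from scipy import sparse

grid = np.linspace(-5., 5., 1000)
M = (grid[:, None] - grid[None, :]) ** 2  # squared ground metric
M[M > 0.05 * M.max()] = np.inf            # truncate large costs
K = sparse.csr_matrix(np.exp(-M))         # exp(-inf) = 0, with epsilon = 1
print("kernel nnz fraction:", K.nnz / np.prod(K.shape))
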
import numpy as np
import warnings
from scipy import sparse
from scipy.stats import norm as gaussian


def make_gaussians(grid, n_hists, loc=None, scale=None, normed=True,
                   mass=None):
    """Generate Gaussian histograms on a grid, one histogram per column."""
    if loc is None:
        loc = np.zeros(n_hists)
    if scale is None:
        scale = np.ones(n_hists)
    if mass is None:
        mass = np.ones(n_hists)
    coefs = np.empty((len(grid), n_hists))
    for i, (l, s) in enumerate(zip(loc, scale)):
        coefs[:, i] = gaussian.pdf(grid, loc=l, scale=s)
    if normed:
        # normalize each column to sum to 1, then rescale to the target mass
        coefs /= coefs.sum(axis=0)
        coefs *= mass
    return coefs
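

# Usage sketch for make_gaussians (illustrative values, not the benchmark
# settings): two histograms on a common grid, one per column, each summing
# to one when `normed=True`:
#
#     grid = np.linspace(-5., 5., 200)
#     hists = make_gaussians(grid, 2, loc=[-1., 1.], scale=[0.5, 0.5])
#     hists.shape          # (200, 2)
#     hists.sum(axis=0)    # array([1., 1.])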


def sinkhorn(p, q, K, epsilon, maxiter=5000, tol=1e-8):
    """Run Sinkhorn iterations for entropy-regularized OT between p and q.

    `K` is the Gibbs kernel exp(-M / epsilon); it can be a dense ndarray or
    a scipy.sparse matrix.
    """
    err = 10
    log = dict(err=[])
    b = np.ones_like(p)
    Kb = K.dot(b)
    for i in range(maxiter):
        # alternate scaling updates: a = p / (K b), b = q / (K^T a)
        a = p / Kb
        Ka = K.T.dot(a)
        b = q / Ka
        Kb = K.dot(b)
        # marginal violation of the current transport plan
        err = abs(p - a * Kb).max()
        log["err"].append(err)
        if err < tol:
            break
        if np.isnan(Kb).any():
            warnings.warn("Numerical errors! Stopped at last stable "
                          "iteration.")
            break
    if i == maxiter - 1:
        warnings.warn("*** Maxiter reached! err = {} ***".format(err))
        log['flag'] = 3
    # regularized loss from the scalings: epsilon * (<log a, p> + <log b, q>)
    f = (np.log(a + 1e-100) * p).sum()
    f += (np.log(b + 1e-100) * q).sum()
    f *= epsilon
    return f, i
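

# Usage sketch for sinkhorn (illustrative, continuing the sketch above):
# build the Gibbs kernel from the squared distances on the grid and compute
# the regularized loss between the two histograms:
#
#     M = (grid[:, None] - grid[None, :]) ** 2
#     K = np.exp(-M / 1.)                         # epsilon = 1
#     f, n_iter = sinkhorn(hists[:, 0], hists[:, 1], K, epsilon=1.)
#     # the same call accepts a truncated kernel stored as a sparse matrix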


if __name__ == "__main__":
    from matplotlib import pyplot as plt
    from time import time

    epsilon = 1.
    n_dim = 5
    dimensions = np.linspace(10, 100, n_dim).astype(int) * 100
    means = np.array([-0.05, 0.05])
    sigma = 0.4
    std = sigma * np.ones(2)

    # relative cut-offs for the squared ground metric (fractions of M.max())
    thresholds = np.array([0.1, 0.05, 0.002])
    sparsity = []
    n_thresh = len(thresholds)
    times, times_sparse, loss, loss_sparse = np.zeros((4, n_thresh, n_dim))
    times_np, loss_np = np.zeros((2, n_thresh, n_dim))
    for ii, threshold in enumerate(thresholds):
        print(">> Doing threshold %s / %s" % (ii + 1, n_thresh))
        for jj, n_features in enumerate(dimensions):
            print("Doing n_features %s / %s" % (jj + 1, n_dim))
            grid = np.linspace(-5., 5., n_features)
            M = (grid[:, None] - grid[None, :]) ** 2
            Ms = M.copy()
            K = np.exp(- M / epsilon)
            # truncate: costs above the cut-off become +inf, so the
            # corresponding kernel entries are exactly 0
            Ms[Ms > threshold * M.max()] = float("inf")
            Ks = np.exp(- Ms / epsilon)
            Ks = sparse.csr_matrix(Ks)
            x, y = make_gaussians(grid, 2, loc=means, scale=std).T

            # 1) full dense kernel
            t = time()
            f, i = sinkhorn(x, y, K, epsilon)
            t = time() - t
            times[ii, jj] = t
            loss[ii, jj] = f
            print("niter of full M", i)

            # 2) truncated kernel stored as a sparse matrix
            ts = time()
            fs, i = sinkhorn(x, y, Ks, epsilon)
            ts = time() - ts
            times_sparse[ii, jj] = ts
            loss_sparse[ii, jj] = fs
            print("niter of truncated-sparse", i)

            # 3) truncated kernel stored as a dense array
            ts = time()
            fs, i = sinkhorn(x, y, Ks.toarray(), epsilon)
            ts = time() - ts
            times_np[ii, jj] = ts
            loss_np[ii, jj] = fs
            print("niter of truncated-numpy", i)
        # fraction of non-zero kernel entries for this threshold
        sparsity.append((Ks > 0).mean())
    titles = ["Loss", "Time"]
    labels = ["Full M", "Truncated M (scipy.sparse)",
              "Truncated M (numpy)"]
    fig, axes = plt.subplots(2, n_thresh, figsize=(18, 6), sharex=True)
    all_data = [[loss, loss_sparse, loss_np], [times, times_sparse, times_np]]
    for ii, (ax_group, title, data) in enumerate(zip(axes, titles, all_data)):
        tmp = np.stack(data)
        tmp = np.swapaxes(tmp, 0, 1)
        for ax, curves, sp in zip(ax_group.flatten(), tmp, sparsity):
            for curve, label in zip(curves, labels):
                ax.plot(dimensions, curve, lw=2, label=label)
            ax.set_xlabel("# bins (dimension)")
            ax.set_ylabel(title)
            if ii == 0:
                ax.set_title("nnz = %.2f %%" % (sp * 100))
    plt.legend()
    plt.show()