Skip to content

Instantly share code, notes, and snippets.

@vene
Last active November 9, 2021 14:06
Show Gist options
  • Save vene/73e7ee0dbefe7f91ca0af092f5aa0055 to your computer and use it in GitHub Desktop.
Save vene/73e7ee0dbefe7f91ca0af092f5aa0055 to your computer and use it in GitHub Desktop.
"""wrapped hyperbolic distributions
following https://arxiv.org/abs/1902.02992
"""
# author: vlad niculae <v.niculae@uva.nl>
# license: bsd 3-clause
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def hyperboloid_to_circles(x):
    """Project hyperboloid points onto the Beltrami-Klein and Poincare disks.

    Args:
        x: points on H^d in R^{d+1} coordinates, time-like coordinate first;
           shape (..., d+1).

    Returns:
        (x_bk, x_pc): Beltrami-Klein and Poincare-ball coordinates, each of
        shape (..., d).
    """
    x0 = x[..., :1]
    x_rest = x[..., 1:]
    # Beltrami-Klein: central projection onto the plane x0 = 1
    x_bk = x_rest / x0
    # Poincare: projection from (-1, 0, ..., 0). On the hyperboloid this equals
    # x_bk / (1 + sqrt(1 - |x_bk|^2)), but it cannot produce NaN when rounding
    # puts |x_bk| slightly above 1.
    x_pc = x_rest / (1 + x0)
    return x_bk, x_pc
def lorenz_inner(x1, x2):
    """Lorentzian (Minkowski) inner product <x1, x2>_L over the last axis.

    Points of H^d are stored in R^{d+1} coordinates with the time-like
    coordinate first, so the result is -a_0 b_0 + sum_{i>=1} a_i b_i.
    Leading dimensions broadcast.
    """
    time_term = x1[..., 0] * x2[..., 0]
    space_term = (x1[..., 1:] * x2[..., 1:]).sum(dim=-1)
    return space_term - time_term
def lorenz_exp(u, x):
    """Exponential map exp_x(u) on the hyperboloid.

    Args:
        u: tangent vector at x, shape (..., d+1).
        x: base point on H^d, shape broadcastable with u.

    Returns:
        cosh(r) * x + sinh(r) * u / r with r = sqrt(<u, u>_L).
        A zero tangent vector maps to x itself.
    """
    # clamp: rounding can make <u, u>_L slightly negative, and sqrt would give NaN
    r = torch.sqrt(torch.clamp(lorenz_inner(u, u), min=0)).unsqueeze(-1)
    # same zero-norm guard as lorenz_exp0: avoid 0/0; sinh(0) zeroes the term anyway
    r_denom = torch.where(r == 0, torch.ones_like(r), r)
    return torch.cosh(r) * x + torch.sinh(r) * (u / r_denom)
def lorenz_log(z, x):
    """Logarithmic map log_x(z): the tangent vector at x pointing towards z.

    Args:
        z: target point(s) on H^d, shape (..., d+1).
        x: base point(s) on H^d, shape broadcastable with z.

    Returns:
        arccosh(a) / sqrt(a^2 - 1) * (z - a x) with a = -<x, z>_L.
        When z == x (a == 1) the result is the zero vector rather than NaN.
    """
    # a >= 1 for points on the hyperboloid; clamp rounding error below 1
    alpha = torch.clamp(-lorenz_inner(x, z), min=1).unsqueeze(-1)
    denom = torch.sqrt(alpha ** 2 - 1)
    is_pos = denom > 0
    safe_denom = torch.where(is_pos, denom, torch.ones_like(denom))
    # arccosh(a) / sqrt(a^2 - 1) -> 1 as a -> 1
    coef = torch.where(is_pos, torch.arccosh(alpha) / safe_denom, torch.ones_like(denom))
    return coef * (z - alpha * x)
def lorenz_exp0(u_tilde):
    """Exponential map at the origin o = (1, 0, ..., 0) of u = [0, u_tilde].

    u_tilde holds only the spatial coordinates of the tangent vector,
    shape (..., d). The result lives in R^{d+1}, shape (..., d+1).
    """
    norm = u_tilde.norm(dim=-1, keepdim=True)
    # zero vector: divide by 1 instead of 0; sinh(0) = 0 zeroes the spatial part
    safe_norm = torch.where(norm == 0, torch.ones_like(norm), norm)
    time_coord = torch.cosh(norm)
    space_coords = torch.sinh(norm) * (u_tilde / safe_norm)
    return torch.cat((time_coord, space_coords), dim=-1)
def transp(v, x_from, x_to):
    """Parallel transport of tangent vector v from T_{x_from} H to T_{x_to} H.

    PT(v) = v + <x_to - a x_from, v>_L / (a + 1) * (x_from + x_to),
    where a = -<x_from, x_to>_L.

    Args:
        v: tangent vector(s) at x_from, shape (..., d+1).
        x_from, x_to: points on H^d, shapes broadcastable with v.

    Returns:
        Transported vector(s), shape (..., d+1).
    """
    # alpha keeps the reduced shape (...). Dividing the (...)-shaped inner
    # product by an (..., 1)-shaped alpha would broadcast to (N, N) when
    # x_from / x_to are batched.
    alpha = -lorenz_inner(x_from, x_to)
    coef = lorenz_inner(x_to - alpha.unsqueeze(-1) * x_from, v) / (alpha + 1)
    return v + coef.unsqueeze(-1) * (x_from + x_to)
def transp0(v, x_to):
    """Parallel transport from the origin o = (1, 0, ..., 0) to x_to.

    This is the general transport formula specialised to x_from = o. Here
    a = x_to[0], and the tangent vector at o is [0, v].

    Args:
        v: spatial coordinates of the tangent vector at the origin, shape (..., d).
        x_to: target point(s) on H^d, shape (..., d+1), broadcastable with v.

    Returns:
        The transported tangent vector at x_to, shape (..., d+1).
    """
    # keep alpha at shape (...) so a batched x_to does not broadcast to (N, N)
    alpha = x_to[..., 0]
    coef = ((x_to[..., 1:] * v).sum(-1) / (alpha + 1)).unsqueeze(-1)
    # [0, v] + coef * (o + x_to) == [coef, v] + coef * x_to
    return torch.cat([coef, v], dim=-1) + coef * x_to
def log_density(x, loc, std):
    """Log-density log p(x | loc, std) of the wrapped normal on H^d.

    A sample is a tangent vector v ~ N(0, std^2 I_d) at the origin. It is
    parallel transported to loc and then mapped onto H^d with exp_loc. The
    log-density is

        log N(v; 0, std^2 I_d) - (d - 1) * log(sinh(r) / r),   r = |v|.

    The log map and parallel transport both preserve the Lorentz norm, so r
    equals the geodesic distance arccosh(-<x, loc>_L). For an isotropic
    std, that distance is all the density depends on.

    Args:
        x: points on H^d in R^{d+1} coordinates, shape (..., d+1).
        loc: mean point on H^d, shape broadcastable with x.
        std: scalar standard deviation (float or 0-dim tensor).

    Returns:
        Log-densities with the broadcast batch shape of x and loc.
    """
    d = x.shape[-1] - 1
    std = torch.as_tensor(std, dtype=x.dtype, device=x.device)
    # alpha = -<x, loc>_L = cosh(r); clamp rounding error below 1 (NaN in arccosh)
    prod = x * loc
    alpha = torch.clamp(prod[..., 0] - prod[..., 1:].sum(dim=-1), min=1)
    r = torch.arccosh(alpha)
    # isotropic Gaussian in d dims: -r^2/(2 s^2) - d log s - (d/2) log(2 pi)
    logp = -r ** 2 / (2 * std ** 2)
    logp = logp - d * torch.log(std) - 0.5 * d * float(np.log(2 * np.pi))
    # log-det of the exp map; sinh(r)/r -> 1 (log -> 0) as r -> 0
    has_r = r > 0
    r_safe = torch.where(has_r, r, torch.ones_like(r))
    log_ratio = torch.where(has_r, torch.log(torch.sinh(r_safe) / r_safe), torch.zeros_like(r))
    return logp - (d - 1) * log_ratio
def contours(loc, std, ax_bk, ax_pc, n=500):
    """Draw log-density contours of the wrapped normal on both disk plots.

    An n x n grid over [-5 std, 5 std]^2 in the tangent space at the origin
    is mapped onto the hyperboloid with lorenz_exp0. The log-density around
    loc is evaluated at those points. The grid is then drawn in
    Beltrami-Klein coordinates on ax_bk and Poincare coordinates on ax_pc.
    """
    half_width = 5 * std
    ticks = np.linspace(-half_width, half_width, n)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    tangent = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    on_manifold = lorenz_exp0(torch.from_numpy(tangent))
    logp_grid = log_density(on_manifold, loc, std).reshape(n, n)
    proj_bk, proj_pc = hyperboloid_to_circles(on_manifold)
    for ax, proj in ((ax_bk, proj_bk), (ax_pc, proj_pc)):
        ax.contour(proj[:, 0].reshape(n, n), proj[:, 1].reshape(n, n), logp_grid)
def main():
    """Sample a wrapped normal on H^2 and plot it in three views.

    Gaussian samples drawn in the tangent space at the origin are transported
    to a random location and mapped onto the hyperboloid. They are then shown
    in ambient 3-D coordinates and on the Beltrami-Klein and Poincare disks,
    with log-density contours drawn on the two disks.
    """
    d = 2
    n_pts = 500
    std = torch.tensor(.5)

    # random mean: exponential map (at the origin) of a standard-normal tangent vector
    loc = lorenz_exp0(torch.randn(d))

    # Gaussian samples in the tangent space at the origin, transported to
    # loc's tangent space, then mapped onto the manifold
    tangent_samples = std * torch.randn(n_pts, d)
    x = lorenz_exp(transp0(tangent_samples, x_to=loc), loc)

    fig = plt.figure(figsize=(8, 3), constrained_layout=True)
    ax_3d = fig.add_subplot(131, projection='3d')
    ax_bk = fig.add_subplot(132)
    ax_pc = fig.add_subplot(133)

    # time-like coordinate x[:, 0] is plotted on the vertical axis
    ax_3d.scatter(x[:, 1], x[:, 2], x[:, 0], marker='.')
    ax_3d.set_title("ambient space")

    contours(loc, std, ax_bk, ax_pc)

    x_bk, x_pc = hyperboloid_to_circles(x)
    panels = ((ax_bk, x_bk, "Beltrami-Klein disk"), (ax_pc, x_pc, "Poincare disk"))
    for ax, pts, title in panels:
        ax.scatter(pts[:, 0], pts[:, 1], marker='.')
        ax.set_title(title)
        ax.add_patch(plt.Circle((0, 0), radius=1, edgecolor='k', facecolor='None'))
        ax.set_aspect("equal")
    plt.show()
if __name__ == "__main__":
    # run the plotting demo only when executed as a script, not on import
    main()
@vene
Copy link
Author

vene commented Nov 9, 2021

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment