Skip to content

Instantly share code, notes, and snippets.

@vene
Created November 3, 2021 14:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vene/c62262904355b7d39356b95c829d4745 to your computer and use it in GitHub Desktop.
Save vene/c62262904355b7d39356b95c829d4745 to your computer and use it in GitHub Desktop.
"""
Approximating the cross-entropy between two Power Sphericals.
Uses a second-order Taylor expansion to approximate E[log(1+z)].
"""
# author: vlad n <vlad@vene.ro>
# license: mit
# documentation: https://hackmd.io/@vladn/SJ93wMevK
import numpy as np
import torch
from power_spherical import PowerSpherical, HypersphericalUniform
def ps_variance_dot(ps, y):
    """Compute Var[dot(x, y)] in closed form for x ~ ps.

    Per the Power Spherical construction (De Cao & Aziz, 2020), a sample is
    ``x = t * mu + sqrt(1 - t**2) * v``, where:

    - ``mu = ps.loc``;
    - ``t`` is the marginal along ``mu`` (``ps.base_dist.marginal_t``);
    - ``v`` is uniform on the unit sphere orthogonal to ``mu`` and is
      independent of ``t``.

    Let ``dp = dot(mu, y)`` and ``yy = dot(y, y)``. Expanding the second
    moment then gives::

        Var[dot(x, y)] = Var[t] * ((1 - ratio) * dp**2 + ratio * yy)

    Here ``ratio = (alpha + beta) / (2 * beta)``, and ``alpha`` / ``beta``
    are the concentrations of the Beta distribution underlying ``t``. The
    expression is exact for any ``y``. For unit-norm ``y``, ``yy`` is 1.

    Parameters
    ----------
    ps : PowerSpherical
        Distribution of ``x``. Assumes ``ps.loc`` is unit-norm (TODO confirm
        for hand-built instances).
    y : torch.Tensor
        Vector(s) whose last axis is the ambient dimension. Leading axes
        broadcast against ``ps.loc``. ``y`` need not be unit-norm.

    Returns
    -------
    torch.Tensor
        Var[dot(x, y)], with the broadcast batch shape of ``ps`` and ``y``.
    """
    alpha = ps.base_dist.marginal_t.base_dist.concentration1
    beta = ps.base_dist.marginal_t.base_dist.concentration0
    ratio = (alpha + beta) / (2 * beta)
    t_var = ps.base_dist.marginal_t.variance
    # Reduce over the last axis (instead of `@`) so batched loc and/or y
    # broadcast.
    dp = (ps.loc * y).sum(-1)
    # Use the actual squared norm rather than assuming 1. With yy fixed at 1,
    # a non-unit y gives a wrong (possibly negative) variance.
    yy = (y * y).sum(-1)
    return t_var * ((1 - ratio) * dp ** 2 + ratio * yy)
def check_ps_variance_dot(
        dim=10,
        k=20,
        n_samples=1000,
        seed=None):
    """Compare the closed-form mean/variance of dot(x, mu_q) with Monte Carlo.

    Procedure:

    1. Draw two directions ``mu_p`` and ``mu_q`` from HypersphericalUniform.
    2. Sample ``n_samples`` points ``x ~ PowerSpherical(loc=mu_p, scale=k)``.
    3. Print the mean and variance of ``z = dot(x, mu_q)``, both analytically
       (``p.mean @ mu_q`` and ``ps_variance_dot``) and empirically.

    Parameters
    ----------
    dim : int
        Dimension passed to HypersphericalUniform when drawing mu_p and mu_q.
    k : int or float
        Concentration, passed as ``scale`` to PowerSpherical.
    n_samples : int
        Number of Monte Carlo samples.
    seed : int, optional
        If not None, ``torch.manual_seed(seed)`` is called first so the
        printed numbers are reproducible. The default (None) leaves the
        global RNG untouched, as before.
    """
    if seed is not None:
        torch.manual_seed(seed)
    # as_tensor (unlike torch.tensor) neither warns nor copies when a tensor
    # is passed in.
    dim = torch.as_tensor(dim)
    k = torch.as_tensor(k)
    unif = HypersphericalUniform(dim=dim)
    mu_p = unif.rsample()
    mu_q = unif.rsample()
    p = PowerSpherical(loc=mu_p, scale=k)
    xp = p.rsample((n_samples,))
    z = xp @ mu_q
    # true mean
    z_mean = p.mean @ mu_q
    print("mean z true:", z_mean.item())
    print("mean z num: ", torch.mean(z).item())
    print("V[z] tru: ", ps_variance_dot(p, mu_q).item())
    # torch.var applies Bessel's (n - 1) correction by default.
    print("V[z] num: ", torch.var(z).item())
def check_taylor(
        dim=10,
        k=3,
        n_samples=10000,
        seed=None):
    """Compare a second-order Taylor estimate of E[log(1 + z)] with Monte Carlo.

    Here ``z = dot(x, mu_q)`` with ``x ~ PowerSpherical(loc=mu_p, scale=k)``,
    and ``mu_p``, ``mu_q`` are drawn from HypersphericalUniform. Expanding
    ``log(1 + z)`` to second order around ``E[z]`` gives::

        E[log(1 + z)] ~= log(1 + E[z]) - Var[z] / (2 * (1 + E[z])**2)

    Both the Monte Carlo average and this estimate are printed.

    Parameters
    ----------
    dim : int
        Dimension passed to HypersphericalUniform when drawing mu_p and mu_q.
    k : int or float
        Concentration, passed as ``scale`` to PowerSpherical.
    n_samples : int
        Number of Monte Carlo samples.
    seed : int, optional
        If not None, ``torch.manual_seed(seed)`` is called first so the
        printed numbers are reproducible. The default (None) leaves the
        global RNG untouched, as before.
    """
    if seed is not None:
        torch.manual_seed(seed)
    # as_tensor (unlike torch.tensor) neither warns nor copies when a tensor
    # is passed in.
    dim = torch.as_tensor(dim)
    k = torch.as_tensor(k)
    unif = HypersphericalUniform(dim=dim)
    mu_p = unif.rsample()
    mu_q = unif.rsample()
    p = PowerSpherical(loc=mu_p, scale=k)
    xp = p.rsample((n_samples,))
    # approximate E[ log(1+z) ] where z = dot(x,y), x~ps
    z = xp @ mu_q
    z_mean = p.mean @ mu_q
    # Zeroth-order term: f(E[z]) with f = log1p.
    taylor_first = torch.log1p(z_mean)
    # Second-order term: f''(E[z]) / 2 * Var[z], where f''(u) = -1 / (1 + u)**2.
    taylor_second = ps_variance_dot(p, mu_q) / (2 * (1 + z_mean) ** 2)
    taylor = taylor_first - taylor_second
    print("MC: ", torch.mean(torch.log1p(z)).item())
    print("Tay: ", taylor.item())
def main():
    """Run every numerical sanity check with its default arguments, in order."""
    for run_check in (check_ps_variance_dot, check_taylor):
        run_check()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment