Skip to content

Instantly share code, notes, and snippets.

@vene
Created April 6, 2020 19:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vene/882e5221031c054e43d8a07e4e54de63 to your computer and use it in GitHub Desktop.
Save vene/882e5221031c054e43d8a07e4e54de63 to your computer and use it in GitHub Desktop.
barycenter on sphere
# author: vlad niculae <vlad@vene.ro>
# license: mit
import numpy as np
from scipy.special import softmax
from mayavi import mlab
def barycenter(X, w, *, n_iter=50, lr=0.1, verbose=True):
    """Compare two barycenter estimates for weighted points on the unit sphere.

    Two estimates are produced:

    * the "mean and project" estimate: the weighted Euclidean mean of the
      rows of ``X``, rescaled to unit norm;
    * the iterates of Riemannian gradient descent on
      ``0.5 * sum_k w[k] * arccos(<X[k], xbar>) ** 2``, started from the
      projected mean.

    Parameters
    ----------
    X : array-like, shape (n_points, n_dims)
        Points, one per row. The code treats the rows as unit vectors; they
        are not re-normalised here.
    w : array-like, shape (n_points,)
        Weight of each point.
    n_iter : int, default 50
        Number of gradient steps.
    lr : float, default 0.1
        Step size applied to the Riemannian gradient.
    verbose : bool, default True
        If true, print the objective and iterate before every step, and the
        objective of the projected mean at the end.

    Returns
    -------
    xbar_proj : ndarray, shape (n_dims,)
        The normalised weighted Euclidean mean. Its entries are NaN when the
        weighted mean is exactly the zero vector (nothing to project).
    xbars : list of ndarray
        The ``n_iter + 1`` gradient-descent iterates; ``xbars[0]`` is
        ``xbar_proj``.
    """
    X = np.asarray(X, dtype=np.double)
    w = np.asarray(w, dtype=np.double)

    def s_exp(x, p):
        """Exponential map: walk from ``x`` along tangent vector ``p``."""
        nrm = np.linalg.norm(p)
        if nrm == 0.0:
            # Zero step (e.g. vanishing gradient): stay put instead of 0 / 0.
            return x.copy()
        return np.cos(nrm) * x + np.sin(nrm) * (p / nrm)

    def s_log(x, y):
        """Log map: tangent vector at ``x`` pointing towards ``y``."""
        cos = np.dot(x, y)
        p = y - cos * x  # component of y orthogonal to x
        nrm = np.linalg.norm(p)
        if nrm == 0.0:
            # y coincides with x (log is zero) or is antipodal to it (the
            # direction is undefined); either way contribute nothing rather
            # than propagating NaN.
            return np.zeros_like(x)
        # For unit vectors ||p|| = sin(angle) and cos = cos(angle), so arctan2
        # recovers the angle without arccos failing when rounding pushes the
        # dot product slightly outside [-1, 1].
        return np.arctan2(nrm, cos) * (p / nrm)

    def obj(xbar):
        """Half the weighted sum of squared geodesic distances to ``xbar``."""
        cosines = np.clip(np.dot(X, xbar), -1.0, 1.0)
        return 0.5 * np.dot(w, np.arccos(cosines) ** 2)

    def grad(xbar):
        """Riemannian gradient of ``obj`` at ``xbar``."""
        return -sum(wk * s_log(xbar, xk) for wk, xk in zip(w, X))

    # mean and project
    xbar_proj = np.dot(X.T, w)
    xbar_proj /= np.linalg.norm(xbar_proj)

    # riemannian opt, initialised at the projected mean
    xbar = xbar_proj
    xbars = [xbar]
    for _ in range(n_iter):
        if verbose:
            print(obj(xbar), xbar)
        xbar = s_exp(xbar, -lr * grad(xbar))
        xbars.append(xbar)

    if verbose:
        print("compared to")
        print(obj(xbar_proj), xbar_proj)

    return xbar_proj, xbars
def main():
    """Plot four points on the unit sphere together with both barycenter estimates.

    Draws the sphere, the input points, every gradient-descent iterate
    returned by ``barycenter`` (green) and the projected Euclidean mean (red)
    in a mayavi figure, then opens the interactive window.
    """
    points = np.array(
        [
            [1, 0, 0],
            [-1, 0, 0],
            [1, 1, 1],
            [1, -2, 1],
        ],
        dtype=np.double,
    )
    # Scale every row to unit length so the points lie on the sphere.
    points /= np.linalg.norm(points, axis=1, keepdims=True)

    # Create a sphere
    radius = 1.0
    phi, theta = np.mgrid[0:np.pi:101j, 0:2 * np.pi:101j]
    sphere_x = radius * np.sin(phi) * np.cos(theta)
    sphere_y = radius * np.sin(phi) * np.sin(theta)
    sphere_z = radius * np.cos(phi)

    mlab.figure(1, bgcolor=(1, 1, 1), fgcolor=(0, 0, 0), size=(400, 300))
    mlab.clf()
    mlab.mesh(sphere_x, sphere_y, sphere_z, color=(0.0, 0.5, 0.5))
    mlab.points3d(points[:, 0], points[:, 1], points[:, 2], scale_factor=0.1)

    # Softmax of all-zero scores gives uniform weights.
    weights = softmax(np.zeros(points.shape[0]))

    projected, iterates = barycenter(points, weights)

    green, red = (0, 1, 0), (1, 0, 0)
    for step in iterates:
        mlab.points3d(step[0], step[1], step[2], color=green, scale_factor=.1)
    mlab.points3d(projected[0], projected[1], projected[2], color=red,
                  scale_factor=.1)

    mlab.show()
def check_high_dim():
    """Run ``barycenter`` on ten random unit vectors in 50 dimensions.

    Uses uniform weights and discards the result; the only visible output
    is whatever ``barycenter`` prints.
    """
    n_points, n_dims = 10, 50
    samples = np.random.randn(n_points, n_dims)
    # Scale every row to unit length.
    samples /= np.linalg.norm(samples, axis=1, keepdims=True)

    # Softmax of all-zero scores gives uniform weights.
    uniform = softmax(np.zeros(n_points))

    barycenter(samples, uniform)
if __name__ == "__main__":
    # Default entry point is the 3-D visualisation; call check_high_dim()
    # here instead to run the 50-dimensional check without plotting.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment