Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Created April 26, 2021 22:51
Show Gist options
  • Save crowsonkb/08580a64b52cacf80712b2ffb99ca98a to your computer and use it in GitHub Desktop.
Save crowsonkb/08580a64b52cacf80712b2ffb99ca98a to your computer and use it in GitHub Desktop.
Spherical weighted average
import geoopt
def spherical_avg(p, w=None, tol=1e-6, max_iter=1000):
    """Return the weighted average of points on the unit sphere.

    The rows of ``p`` are projected onto the sphere, an initial estimate is
    taken as the projected weighted Euclidean sum, and the estimate is then
    refined by a fixed-point iteration: map every point into the tangent
    space at the current estimate (``logmap``), take the weighted sum of
    those tangent vectors, and move the estimate along it (``retr``).

    Args:
        p: 2-D tensor; each row (dim 0) is one point. Rows need not be
            normalized, they are projected onto the sphere first.
        w: Optional 1-D tensor of weights, one per row of ``p``. Defaults to
            equal weights. The weights are rescaled to sum to one.
        tol: Stop once the norm of the difference between two successive
            estimates is at most this value.
        max_iter: Upper bound on the number of refinement steps. If it is
            reached before ``tol`` is met, the latest estimate is returned
            as-is.

    Returns:
        1-D tensor holding the averaged point.

    Raises:
        AssertionError: If ``p`` is not 2-D, ``w`` is not 1-D, or their
            lengths differ.
        ValueError: If the iteration produces non-finite values, e.g. when
            the weighted Euclidean sum of the points is the zero vector
            (antipodal points with equal weights), the weights sum to zero,
            or a row of ``p`` is all zeros.
    """
    sphere = geoopt.Sphere()
    if w is None:
        w = p.new_ones([p.shape[0]])
    assert p.ndim == 2 and w.ndim == 1 and len(p) == len(w), (
        "p must be 2-D and w 1-D with one weight per row of p"
    )
    w = w / w.sum()
    p = sphere.projx(p)
    # Column view of the weights so they broadcast across each row; it does
    # not change between iterations, so build it once.
    w_col = w.unsqueeze(1)
    q = sphere.projx(p.mul(w_col).sum(dim=0))
    # Bounded loop: an unbounded one never terminates if the step size is
    # NaN or if `tol` is below what the dtype can resolve.
    for _ in range(max_iter):
        tangent = sphere.logmap(q, p).mul(w_col).sum(dim=0)
        q_new = sphere.retr(q, tangent)
        step = q.sub(q_new).norm()
        # A non-finite step means q or q_new contains NaN/inf; further
        # iterations cannot recover from that, so fail loudly.
        if not step.isfinite():
            raise ValueError(
                "spherical_avg produced non-finite values; the average is "
                "undefined for this input"
            )
        q = q_new
        if step <= tol:
            break
    return q
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment