Skip to content

Instantly share code, notes, and snippets.

@ferrine
Created July 4, 2019 13:56
Show Gist options
  • Save ferrine/bdab3b2d34e1494421ac262b4be499df to your computer and use it in GitHub Desktop.
Save ferrine/bdab3b2d34e1494421ac262b4be499df to your computer and use it in GitHub Desktop.
Möbius linear layer example (geoopt Poincaré ball)
import torch.nn
import geoopt
# --- intended location: package/nn/modules.py ---
def create_ball(ball=None, c=None):
    """Return a Poincare ball manifold, constructing one from ``c`` if needed.

    Args:
        ball: an existing ``geoopt.PoincareBall`` to reuse, or None.
        c: curvature used to build a new ``geoopt.PoincareBall`` when ``ball``
            is None. Ignored when ``ball`` is given.

    Returns:
        ``ball`` itself when provided, otherwise ``geoopt.PoincareBall(c)``.

    Raises:
        ValueError: if both ``ball`` and ``c`` are None, or if ``ball`` is not
            a ``geoopt.PoincareBall`` instance.
    """
    if ball is None:
        # Explicit raise rather than ``assert``: assertions are stripped under
        # ``python -O``, which would let ``PoincareBall(None)`` through silently.
        if c is None:
            raise ValueError("curvature of the ball should be explicitly specified")
        return geoopt.PoincareBall(c)
    if not isinstance(ball, geoopt.PoincareBall):
        raise ValueError("ball should be an instance of PoincareBall")
    return ball
class MobiusLinear(torch.nn.Linear):
    """Hyperbolic counterpart of ``torch.nn.Linear`` on a Poincare ball.

    The forward pass delegates to ``mobius_linear``: a Mobius matrix-vector
    product with ``weight``, optional Mobius addition of ``bias``, and an
    optional ``nonlin`` applied through ``ball.mobius_fn_apply``.

    Args:
        *args, **kwargs: forwarded unchanged to ``torch.nn.Linear``.
        nonlin: optional callable applied via ``ball.mobius_fn_apply``.
        ball: optional ``geoopt.PoincareBall`` to operate on.
        c: curvature used to build a ball when ``ball`` is not given.
    """

    def __init__(self, *args, nonlin=None, ball=None, c=1.0, **kwargs):
        # NOTE: torch.nn.Linear.__init__ calls self.reset_parameters(), which
        # dispatches to the override below before self.ball exists. That is
        # safe because the override never reads self.ball.
        super().__init__(*args, **kwargs)
        self.ball = create_ball(ball, c)
        if self.bias is not None:
            # Re-wrap the bias as a ManifoldParameter bound to the ball,
            # presumably so geoopt's Riemannian optimizers treat it as a point
            # on the manifold.
            self.bias = geoopt.ManifoldParameter(self.bias, manifold=self.ball)
        self.nonlin = nonlin
        self.reset_parameters()

    def forward(self, input):
        """Map ``input`` through the Mobius linear transform of this layer."""
        layer_kwargs = dict(
            weight=self.weight,
            bias=self.bias,
            nonlin=self.nonlin,
            ball=self.ball,
        )
        return mobius_linear(input, **layer_kwargs)

    def reset_parameters(self):
        """Initialize weight to identity plus small noise and bias to zero.

        The noise is drawn from ``torch.rand_like`` (uniform in [0, 1)) and
        scaled by 1e-3. A zero bias is the origin of the ball.
        """
        with torch.no_grad():
            torch.nn.init.eye_(self.weight)
            jitter = torch.rand_like(self.weight) * 1e-3
            self.weight.add_(jitter)
            if self.bias is not None:
                self.bias.zero_()
# --- intended location: package/nn/functional.py ---
def mobius_linear(input, weight, bias=None, nonlin=None, *, ball: "geoopt.PoincareBall"):
    """Functional Mobius linear layer on a Poincare ball.

    Computes ``ball.mobius_matvec(weight, input)``, then Mobius-adds ``bias``
    (if given), then applies ``nonlin`` through ``ball.mobius_fn_apply``
    (if given).

    Args:
        input: point(s) on the ball to transform.
        weight: matrix passed to ``ball.mobius_matvec``.
        bias: optional point on the ball, combined via ``ball.mobius_add``.
        nonlin: optional callable, applied via ``ball.mobius_fn_apply``.
        ball: manifold object providing the Mobius operations (keyword-only).

    Returns:
        The transformed point(s), as produced by the last ball operation.
    """
    projected = ball.mobius_matvec(weight, input)
    shifted = projected if bias is None else ball.mobius_add(projected, bias)
    if nonlin is None:
        return shifted
    return ball.mobius_fn_apply(nonlin, shifted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment