@Totoro97
Created December 10, 2021 07:06
import numpy as np
import torch
import torch.nn as nn


class SineLayer(nn.Module):
    # See paper Sec. 3.2, final paragraph, and supplement Sec. 1.5 for a discussion of omega_0.
    # If is_first=True, omega_0 is a frequency factor that simply multiplies the activations
    # before the nonlinearity. Different signals may require a different omega_0 in the first
    # layer - this is a hyperparameter.
    # If is_first=False, the weights are divided by omega_0 so as to keep the magnitude of
    # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5).
    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30, weight_norm=False):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()
        if weight_norm:
            self.linear = nn.utils.weight_norm(self.linear)

    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                # First layer: uniform in [-1 / in_features, 1 / in_features].
                self.linear.weight.uniform_(-1. / self.in_features,
                                            1. / self.in_features)
            else:
                # Hidden layers: uniform in +/- sqrt(6 / in_features) / omega_0.
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                            np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))
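

# Illustrative sanity check for SineLayer (the helper name, layer sizes, and batch shape
# below are assumptions, not part of the original gist): sine activations are bounded to
# [-1, 1], and the first layer scales its pre-activations by omega_0 before applying sin.
def _sine_layer_demo():
    layer = SineLayer(in_features=2, out_features=64, is_first=True, omega_0=30)
    coords = torch.rand(16, 2) * 2 - 1   # 16 random points in [-1, 1]^2
    out = layer(coords)                  # shape (16, 64), values in [-1, 1]
    return out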


class Siren(nn.Module):
    def __init__(self,
                 in_features,
                 hidden_features,
                 hidden_layers,
                 out_features,
                 first_omega_0=30,
                 hidden_omega_0=30,
                 squeeze_out=False,
                 weight_norm=True,
                 skip=()):
        super().__init__()
        self.squeeze_out = squeeze_out
        Layer = SineLayer
        self.first_layer = Layer(in_features, hidden_features,
                                 is_first=True, omega_0=first_omega_0, weight_norm=weight_norm)
        self.hidden_layers = hidden_layers
        self.skip = skip
        for i in range(hidden_layers):
            if i in skip:
                # Layers listed in `skip` receive the hidden state concatenated with the
                # first-layer output, so their input width is doubled.
                setattr(self, "h_layer_{}".format(i), Layer(hidden_features + hidden_features,
                                                            hidden_features,
                                                            is_first=False,
                                                            omega_0=hidden_omega_0,
                                                            weight_norm=weight_norm))
            else:
                setattr(self, "h_layer_{}".format(i), Layer(hidden_features,
                                                            hidden_features,
                                                            is_first=False,
                                                            omega_0=hidden_omega_0,
                                                            weight_norm=weight_norm))
        self.final_linear = nn.Linear(hidden_features, out_features)
        with torch.no_grad():
            self.final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
                                              np.sqrt(6 / hidden_features) / hidden_omega_0)
        if weight_norm:
            self.final_linear = nn.utils.weight_norm(self.final_linear)

    def forward(self, coords):
        input = self.first_layer(coords)
        x = input
        for i in range(self.hidden_layers):
            if i in self.skip:
                # Re-inject the first-layer output before this hidden layer.
                x = torch.cat([x, input], dim=-1)
            layer = getattr(self, "h_layer_{}".format(i))
            x = layer(x)
        output = self.final_linear(x)
        if self.squeeze_out:
            output = torch.sigmoid(output)
        return output
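

# Illustrative usage (hypothetical layer sizes, batch, and skip index; assumes the imports
# at the top of this file): map 3-D coordinates to a scalar in (0, 1) via squeeze_out.
if __name__ == "__main__":
    net = Siren(in_features=3, hidden_features=256, hidden_layers=4, out_features=1,
                squeeze_out=True, weight_norm=True, skip=(2,))
    coords = torch.rand(1024, 3) * 2 - 1   # 1024 query points in [-1, 1]^3
    values = net(coords)                   # shape (1024, 1)
    print(values.shape, values.min().item(), values.max().item())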