@sergeyprokudin
Created June 25, 2024 10:31
Siren model
# @title Define SIREN deformation model
# https://github.com/vsitzmann/siren
# MIT License
# Copyright (c) 2020 Vincent Sitzmann
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
from torch import nn
import numpy as np
from collections import OrderedDict
class SineLayer(nn.Module):
    # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for a discussion of omega_0.
    # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the
    # nonlinearity. Different signals may require different omega_0 in the first layer - this is a
    # hyperparameter.
    # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of
    # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5).

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features,
                                             1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                             np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))

    def forward_with_intermediate(self, input):
        # For visualization of activation distributions
        intermediate = self.omega_0 * self.linear(input)
        return torch.sin(intermediate), intermediate
class Siren(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        self.net = []
        self.net.append(SineLayer(in_features, hidden_features,
                                  is_first=True, omega_0=first_omega_0))

        for i in range(hidden_layers):
            self.net.append(SineLayer(hidden_features, hidden_features,
                                      is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)

            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
                                              np.sqrt(6 / hidden_features) / hidden_omega_0)

            self.net.append(final_linear)
        else:
            self.net.append(SineLayer(hidden_features, out_features,
                                      is_first=False, omega_0=hidden_omega_0))

        self.net = nn.Sequential(*self.net)

    def forward(self, coords):
        # coords = coords.requires_grad_(True)  # uncomment to allow taking derivatives w.r.t. the input
        output = self.net(coords)
        return output  # , coords
    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.
        Only used for visualizing activations later!'''
        activations = OrderedDict()

        activation_count = 0
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for i, layer in enumerate(self.net):
            if isinstance(layer, SineLayer):
                x, intermed = layer.forward_with_intermediate(x)

                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()

                activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
                activation_count += 1
            else:
                x = layer(x)

                if retain_grad:
                    x.retain_grad()

            activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
            activation_count += 1

        return activations
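
A minimal usage sketch (not part of the original gist): it assumes the network is used as a deformation field mapping 3D coordinates to 3D offsets, with hidden sizes and the offset application chosen only for illustration; adjust in_features, out_features, and how the output is applied to your setup. It is meant to run in the same file or notebook cell sequence as the definitions above.

# @title Example usage (assumed configuration, for illustration only)
model = Siren(in_features=3, hidden_features=256, hidden_layers=3,
              out_features=3, outermost_linear=True)

coords = torch.rand(1024, 3) * 2 - 1   # sample coordinates in [-1, 1)
offsets = model(coords)                # predicted per-point offsets, shape (1024, 3)
deformed = coords + offsets            # one assumed way of applying the deformation

# Inspect intermediate activations (e.g. to visualize their distributions)
acts = model.forward_with_activations(coords, retain_grad=False)
print(list(acts.keys())[:3], offsets.shape)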