@kks32
Last active April 8, 2024 20:22
Spherical Linear Parametrization
  1. Spectral Normalization

Spectral normalization divides a weight matrix by its spectral norm (its largest singular value), which is exactly the Lipschitz constant of the linear map x ↦ Wx. This limits how much the transformation can stretch input signals. After normalization, ||Wx₁ − Wx₂|| ≤ ||x₁ − x₂||, so the possible change in output is bounded by the change in input.

Advantages:

  • Gradient Stability: Makes training more stable, particularly for deep networks and for GAN discriminators, where the technique is most widely used.
  • Output Bounds: Indirectly bounds the output by limiting how far the layer can stretch the input space.
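Here is a minimal sketch of the idea. It divides a weight matrix by its largest singular value, computed exactly with `torch.linalg.svdvals`. In practice PyTorch's `torch.nn.utils.parametrizations.spectral_norm` estimates that value by power iteration instead. The helper name `spectrally_normalized` is hypothetical. The check shows that the normalized map never increases the distance between two inputs.

```python
import torch

def spectrally_normalized(weight: torch.Tensor) -> torch.Tensor:
    # Largest singular value = spectral norm = Lipschitz constant of x -> W x
    sigma = torch.linalg.svdvals(weight)[0]
    return weight / sigma

torch.manual_seed(0)
W = torch.randn(8, 5) * 3.0          # deliberately large weights
W_sn = spectrally_normalized(W)

x1, x2 = torch.randn(5), torch.randn(5)
gap_in = torch.linalg.vector_norm(x1 - x2)
gap_out = torch.linalg.vector_norm(W_sn @ x1 - W_sn @ x2)
print(f"input gap {gap_in.item():.3f}, output gap {gap_out.item():.3f}")
```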

  2. Convex Hull Constraints

Treat the training inputs as points in input space and construct their convex hull, the smallest convex set enclosing all the data. Then restrict each weight vector to lie within a scaled copy of this hull.

Advantages:

  • Adaptation to data distribution: The output bound is naturally tied to the spread of the input data.
  • Guarantees: With exact constraint enforcement, this method can guarantee bounded outputs for bounded inputs.

Challenges:

  • Computational Complexity: Computing the hull exactly becomes very expensive in high-dimensional input spaces, because its number of facets can grow rapidly with dimension.
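One way to enforce the constraint without ever computing the hull explicitly is sketched below. This is an assumption, not a method from these notes. Each weight vector is parameterized as a softmax-weighted (that is, convex) combination of training points, scaled by `alpha`. Any such vector lies in the scaled hull by construction. For an input x, every output then obeys |w · x| ≤ alpha · max_j |x_j · x|. The class name `ConvexHullLinear` and its arguments are hypothetical.

```python
import torch
import torch.nn as nn

class ConvexHullLinear(nn.Module):
    """Hypothetical layer whose weight rows are convex combinations of data points."""

    def __init__(self, anchors: torch.Tensor, out_features: int, alpha: float = 1.0):
        super().__init__()
        # anchors: (num_points, in_features) training inputs that define the hull
        self.register_buffer("anchors", anchors)
        self.logits = nn.Parameter(torch.zeros(out_features, anchors.shape[0]))
        self.alpha = alpha

    def hull_weight(self) -> torch.Tensor:
        # Softmax rows are non-negative and sum to 1, i.e. a convex combination
        coeffs = torch.softmax(self.logits, dim=1)
        return self.alpha * coeffs @ self.anchors   # (out_features, in_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.hull_weight().t()

torch.manual_seed(0)
data = torch.randn(100, 3)
layer = ConvexHullLinear(data, out_features=4, alpha=2.0)
x = torch.randn(10, 3)
y = layer(x)
# Bound: |w . x| <= alpha * max_j |x_j . x| for every input row
bound = layer.alpha * (x @ data.t()).abs().max(dim=1, keepdim=True).values
print(bool((y.abs() <= bound + 1e-5).all()))
```

The cost now grows with the number of training points rather than with the hull's facet count. With many points, a subsample of anchors would keep it manageable.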

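  3. GeoDef. The key idea is to represent the linear layer as a geometric transformation rather than a traditional matrix multiplication. The attribution of this idea is uncertain. "Implicit Geometric Regularization for Learning Shapes" (2020) is by Gropp et al., not Sitzmann et al. The name "Geometric Vector Perceptron" refers to a different layer by Jing et al. (2021). Neither paper appears to define this exact construction.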

Here's how it works:

  • Instead of representing the linear layer as a weight matrix W and bias b, it is represented as a pair of vectors (v, c), where v is a unit-norm direction vector and c is a translation vector.
  • The input x is first projected onto the direction vector v using the dot product: x' = (x · v) v.
  • The projected vector x' is then translated by the vector c: y = x' + c.

This approach has several advantages:

  • It avoids the issue of unbounded gradients when the angle approaches 90 degrees. For a unit-norm v, |x · v| ≤ ||x|| by the Cauchy–Schwarz inequality, so the projection can never amplify its input. The bound is relative to ||x||, not absolute.
  • A single (v, c) pair needs only two vectors instead of a full weight matrix and bias vector.
  • It preserves the relative magnitudes of the input vectors across layers, as no projection onto a ball of fixed radius is performed.
  • It extends to multiple outputs by using one direction vector and one translation per output. The code below keeps one scalar output y_i = x · v_i + c_i per direction, so its parameter count matches that of a standard linear layer.
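The three steps for a single (v, c) pair can be sketched as follows. The sketch assumes v is normalized to unit length first; the notes do not say how v is kept at unit norm. The function name `geodef_single` is hypothetical. With c = 0, the check confirms that the projection never has a larger norm than its input, and that projecting twice changes nothing.

```python
import torch

def geodef_single(x: torch.Tensor, v: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    # Normalize the direction so the projection cannot amplify x
    v = v / torch.linalg.vector_norm(v)
    # x' = (x . v) v for each row of x, shape (batch, n)
    coeff = x @ v                        # (batch,)
    x_proj = coeff.unsqueeze(1) * v      # (batch, n)
    # y = x' + c
    return x_proj + c

torch.manual_seed(0)
x = torch.randn(4, 3)
v = torch.randn(3)
c = torch.zeros(3)
y = geodef_single(x, v, c)
# With c = 0, ||x'|| = |x . v| <= ||x|| (Cauchy-Schwarz)
print(torch.linalg.vector_norm(y, dim=1) <= torch.linalg.vector_norm(x, dim=1) + 1e-6)
```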
# GeoDef: represent a linear layer as a geometric transformation (projection onto
# unit-norm direction vectors plus a translation) rather than a plain matrix
# multiplication. Each output i computes y_i = x . v_i + c_i with ||v_i|| = 1.
import torch
import torch.nn as nn
class GeoDef(nn.Module):
    def __init__(self, in_features, out_features):
        super(GeoDef, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.direction_vectors = nn.Parameter(torch.empty(out_features, in_features))
        self.translation_vectors = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()
    def reset_parameters(self):
        # Initialize the direction vectors with random values (normalized in forward)
        nn.init.normal_(self.direction_vectors)
        # Initialize the translation vectors with zeros
        nn.init.zeros_(self.translation_vectors)
    def forward(self, x):
        # Re-normalize on every call so the directions stay unit-norm during training
        norms = self.direction_vectors.norm(dim=1, keepdim=True).clamp_min(1e-12)
        directions = self.direction_vectors / norms
        # Project the input (batch, in_features) onto each direction vector
        projected = torch.matmul(x, directions.t())
        # Apply the translation vectors
        output = projected + self.translation_vectors.unsqueeze(0)
        return output
import torch
import torch.nn as nn
import torch.nn.functional as F
class GeoDefConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(GeoDefConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride = (stride, stride) if isinstance(stride, int) else stride
        self.padding = (padding, padding) if isinstance(padding, int) else padding
        self.direction_vectors = nn.Parameter(torch.empty(out_channels, in_channels, *self.kernel_size))
        if bias:
            self.translation_vectors = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('translation_vectors', None)
        self.reset_parameters()
    def reset_parameters(self):
        # Initialize the direction vectors with random values (normalized in forward)
        nn.init.kaiming_uniform_(self.direction_vectors, a=5**0.5)
        # Initialize the translation vectors with zeros
        if self.translation_vectors is not None:
            nn.init.zeros_(self.translation_vectors)
    def forward(self, x):
        # Give each output filter unit L2 norm over (in_channels, kH, kW) on every call
        norms = self.direction_vectors.flatten(1).norm(dim=1).clamp_min(1e-12)
        directions = self.direction_vectors / norms.view(-1, 1, 1, 1)
        # Convolve with the unit-norm filters; the translation vectors act as the conv bias
        return F.conv2d(x, directions, bias=self.translation_vectors, stride=self.stride, padding=self.padding)
import torch
import torch.nn as nn
class HouseholderLinear(nn.Module):
    def __init__(self, in_features, out_features, num_reflections=None):
        super(HouseholderLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        if num_reflections is None:
            num_reflections = min(in_features, out_features)
        self.num_reflections = num_reflections
        # Initialize the Householder reflection vectors
        self.reflections = nn.Parameter(torch.randn(num_reflections, in_features))
        # Initialize the output projection matrix
        self.projection = nn.Parameter(torch.randn(out_features, in_features))
        # Initialize the bias terms
        self.bias = nn.Parameter(torch.zeros(out_features))
    def forward(self, input):
        # Apply the Householder reflections to the input
        x = input
        for i in range(self.num_reflections):
            v = self.reflections[i]
            x = x - 2 * torch.matmul(x, v.unsqueeze(1)).matmul(v.unsqueeze(0)) / torch.dot(v, v)
        # Project the transformed input to the output space
        output = torch.matmul(x, self.projection.transpose(0, 1))
        # Add the bias terms
        output = output + self.bias
        return output
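Each step in `HouseholderLinear.forward` applies H = I − 2vvᵀ/(vᵀv), which is an orthogonal matrix. The reflection stage alone therefore preserves the norm of every input, and any growth comes from the unconstrained `projection` matrix. Here is a quick standalone check of that claim; the helper `reflect` is hypothetical.

```python
import torch

def reflect(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Householder reflection of each row of x across the hyperplane orthogonal to v
    return x - 2 * (x @ v).unsqueeze(1) * v / torch.dot(v, v)

torch.manual_seed(0)
x = torch.randn(6, 5)
out = x
for v in torch.randn(3, 5):   # three successive reflections
    out = reflect(out, v)

print(torch.allclose(torch.linalg.vector_norm(out, dim=1),
                     torch.linalg.vector_norm(x, dim=1), atol=1e-5))
```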
import torch
import torch.nn as nn
# Assumes SphericalLinear (defined elsewhere in these notes) is in scope
class MyModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MyModel, self).__init__()
        self.fc1 = SphericalLinear(input_size, hidden_size)
        self.fc2 = SphericalLinear(hidden_size, output_size)
    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x
import torch
import torch.nn as nn
class NormPreservingLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(NormPreservingLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Initialize the weight matrix
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        # Initialize the bias terms
        self.bias = nn.Parameter(torch.zeros(out_features))
    def forward(self, input):
        # Normalize each row of the weight matrix (one row per output) to unit norm,
        # so every output satisfies |y_i - b_i| <= ||x||
        weight_norm = torch.norm(self.weight, dim=1, keepdim=True).clamp_min(1e-12)
        normalized_weight = self.weight / weight_norm
        # Compute the linear transformation
        output = torch.matmul(input, normalized_weight.transpose(0, 1))
        # Add the bias terms
        output = output + self.bias
        return output
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
# Generate training data
x_data, y_data = np.meshgrid(np.linspace(-np.pi, np.pi, 50), np.linspace(-np.pi, np.pi, 50))
f_data = np.sin(x_data) * np.sin(y_data)
x_data, y_data, f_data = torch.tensor(x_data.flatten()).float().view(-1, 1), torch.tensor(y_data.flatten()).float().view(-1, 1), torch.tensor(f_data.flatten()).float().view(-1, 1)
# plot training data in 2D with matplotlib as a contour plot
plt.contourf(x_data.view(50, 50), y_data.view(50, 50), f_data.view(50, 50), 100, cmap='seismic')
plt.colorbar()
plt.show()
"""We will use a four-layer NN (three hidden layers of 10 neurons each plus a linear output layer) with a hyperbolic-tangent activation after each hidden layer"""
class StandardNN(nn.Module):
    def __init__(self):
        super(StandardNN, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features=2, out_features=10),
            nn.Tanh(),
            nn.Linear(in_features=10, out_features=10),
            nn.Tanh(),
            nn.Linear(in_features=10, out_features=10),
            nn.Tanh(),
            nn.Linear(10, 1)
        )
    def forward(self, x, y):
        xy = torch.cat((x, y), dim=1)
        return self.layers(xy)
# Define model
model = StandardNN()
"""Once the networks are initialized, we set up the optimization problem and train the network by minimizing an objective function, i.e. solving the optimization problem for $W$ and $b$. The optimization problem for a data-driven curve-fitting is defined as:
$$
\text{arg min}_{W,b} \mathcal{L}(W, b) := \left\| f(x^*, y^*) - \mathcal{N}_f(x^*, y^*; W, b) \right\|
$$
where $x^*, y^*$ is the set of all discrete points where $f$ is given. For the loss-function $\left\| \circ \right\|$, we use the mean squared-error norm, where $\hat{f}(x^*, y^*) = \mathcal{N}_f(x^*, y^*; W, b)$ is the network prediction and $N$ is the number of points:
$$
\left\| \circ \right\| = \frac{1}{N} \sum_{x^*, y^* \in I} \left( f(x^*, y^*) - \hat{f}(x^*, y^*) \right)^2.
$$
"""
# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
losses = []
# Training loop
for epoch in tqdm(range(5000), desc='Data-driven model training progress'):
    optimizer.zero_grad()
    f_pred = model(x_data, y_data)
    loss = criterion(f_pred, f_data)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
# Plot the loss on a semilog scale
plt.figure()
plt.semilogy(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, which="both", ls="--", linewidth=0.5)
plt.show()
"""> 💡 When the loss curve is noisy and oscillating like the one shown here, the learning rate is probably too high. Try reducing it to 0.005 and 0.001 to see the effect on the loss evolution.
### Data-driven Neural Network implementation
"""
import torch
import torch.nn as nn
class AdaptiveSphericalLinear(nn.Module):
    def __init__(self, in_features, out_features, radius_init=1.0, learnable_radius=True):
        super(AdaptiveSphericalLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # n-dimensional spherical coordinates: n - 1 angles per weight vector
        self.angles = nn.Parameter(torch.randn(out_features, in_features - 1))
        # Initialize the bias terms
        self.bias = nn.Parameter(torch.zeros(out_features))
        # Initialize the radius parameter
        if learnable_radius:
            self.radius = nn.Parameter(torch.tensor(radius_init))
        else:
            self.register_buffer('radius', torch.tensor(radius_init))
    def forward(self, input):
        # Unit vectors from the angles: w_k = sin(t_1)...sin(t_{k-1}) cos(t_k),
        # with the last component sin(t_1)...sin(t_{n-1})
        ones = torch.ones(self.out_features, 1, dtype=self.angles.dtype, device=self.angles.device)
        sin_prod = torch.cumprod(torch.cat([ones, torch.sin(self.angles)], dim=1), dim=1)
        cos_ext = torch.cat([torch.cos(self.angles), ones], dim=1)
        # Scale by the (possibly learnable) radius so every row has norm |radius|
        weights = self.radius * sin_prod * cos_ext
        # Compute the linear transformation
        output = torch.matmul(input, weights.transpose(0, 1))
        # Add the bias terms
        output = output + self.bias
        return output
    def extra_repr(self):
        return f'in_features={self.in_features}, out_features={self.out_features}, ' \
               f'radius={self.radius.item():.4f}'
import torch
import torch.nn as nn
class OrthogonalLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(OrthogonalLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Initialize orthogonal matrix Q and diagonal scaling matrix S (stored as a vector).
        # Q stays orthogonal only if it is updated with update_Q, not a standard optimizer.
        self.Q = nn.Parameter(torch.eye(in_features))
        self.S = nn.Parameter(0.01 * torch.ones(in_features))
        # Initialize the weight matrix and bias vector
        self.weight = nn.Parameter(torch.zeros(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
    def forward(self, input):
        # x Q diag(S): broadcasting S scales column j of Q by S[j]
        output = torch.matmul(input, self.Q * self.S)
        # Project the transformed input to the output dimension
        output = torch.matmul(output, self.weight.t())
        # Add the bias term
        output = output + self.bias
        return output
    def update_Q(self, grad_Q, learning_rate):
        # Pull the Euclidean gradient back to the Lie algebra at Q
        A = torch.matmul(self.Q.data.t(), grad_Q)
        # Keep the skew-symmetric part; the exponential of a skew-symmetric matrix is orthogonal
        skew = 0.5 * (A - A.t())
        # Geodesic descent step Q <- Q exp(-lr * skew), which keeps Q orthogonal
        self.Q.data = torch.matmul(self.Q.data, torch.matrix_exp(-learning_rate * skew))
    def update_S(self, grad_S, learning_rate):
        # Update diagonal scaling matrix S using gradient descent
        self.S.data -= learning_rate * grad_S
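Gradient descent on the orthogonal group usually follows this recipe. Pull the Euclidean gradient G back with Qᵀ, keep the skew-symmetric part of QᵀG, and step along the matrix exponential. The exponential of a skew-symmetric matrix is orthogonal, so every step stays on O(n). The standalone sketch below uses the hypothetical name `orthogonal_step` and a toy objective 0.5·||Q − target||² chosen only for illustration. It checks that Q remains orthogonal after many steps.

```python
import torch

def orthogonal_step(Q: torch.Tensor, grad: torch.Tensor, lr: float) -> torch.Tensor:
    # Pull the Euclidean gradient back to the tangent space at the identity
    A = Q.t() @ grad
    # Keep only the skew-symmetric part (the Lie algebra so(n))
    skew = 0.5 * (A - A.t())
    # exp of a skew-symmetric matrix is orthogonal, so Q stays on O(n)
    return Q @ torch.matrix_exp(-lr * skew)

torch.manual_seed(0)
n = 4
Q = torch.eye(n)
target = torch.linalg.qr(torch.randn(n, n)).Q   # a random orthogonal target
for _ in range(200):
    grad = Q - target            # gradient of 0.5 * ||Q - target||_F^2
    Q = orthogonal_step(Q, grad, lr=0.1)

print(torch.allclose(Q.t() @ Q, torch.eye(n), atol=1e-4))
```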
import torch
import torch.nn as nn
import math
class ParametricPolarTransformation(nn.Module):
    def __init__(self, in_features, out_features):
        super(ParametricPolarTransformation, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Only the first two input features are used, so at least two are required
        if in_features < 2:
            raise ValueError("in_features must be at least 2.")
        # Ensure out_features is even since we're dealing with pairs representing complex numbers or 2D vectors
        if out_features % 2 != 0:
            raise ValueError("out_features must be an even number.")
        # Number of transformations needed
        self.num_transforms = out_features // 2
        # Initialize the magnitude (radius) and angle parameters for each transformation
        self.radius = nn.Parameter(torch.ones(self.num_transforms))
        self.angle = nn.Parameter(torch.zeros(self.num_transforms))
    def forward(self, x):
        batch_size = x.size(0)
        # Initialize an output tensor
        y = torch.zeros(batch_size, self.out_features, device=x.device, dtype=x.dtype)
        # Apply each transformation to all input vectors
        for i in range(self.num_transforms):
            # Apply transformation
            transformed = self.apply_transformation(x, self.radius[i], self.angle[i])
            # Place the result in the output tensor
            y[:, 2*i:2*i+2] = transformed
        return y
    def apply_transformation(self, x, radius, angle):
        # x has shape [batch_size, in_features] with in_features >= 2.
        # For simplicity, only the first 2 features are used.
        r_x, theta_x = self.cartesian_to_polar(x[:, :2])
        # Apply the transformation
        r_prime = radius * r_x
        theta_prime = theta_x + angle
        # Convert back to Cartesian coordinates
        return self.polar_to_cartesian(r_prime, theta_prime)
    @staticmethod
    def cartesian_to_polar(x):
        r_x = torch.sqrt(x[:, 0]**2 + x[:, 1]**2)
        theta_x = torch.atan2(x[:, 1], x[:, 0])
        return r_x, theta_x
    @staticmethod
    def polar_to_cartesian(r, theta):
        x = r * torch.cos(theta)
        y = r * torch.sin(theta)
        return torch.stack((x, y), dim=-1)
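Scaling the radius by `radius` and shifting the angle by `angle`, as `apply_transformation` does, is equivalent to multiplying the 2D point by `radius` times a rotation matrix. In complex terms, it multiplies x₀ + i·x₁ by radius·e^{i·angle}. The matrix form needs no square root or `atan2` of the input, so it also sidesteps their ill-defined gradients at the origin. The standalone check below compares the two routes; the function names `polar_route` and `matrix_route` are hypothetical.

```python
import math
import torch

def polar_route(xy, radius, angle):
    # Cartesian -> polar, scale and rotate, polar -> Cartesian
    r = torch.sqrt(xy[:, 0] ** 2 + xy[:, 1] ** 2)
    theta = torch.atan2(xy[:, 1], xy[:, 0])
    return torch.stack((radius * r * torch.cos(theta + angle),
                        radius * r * torch.sin(theta + angle)), dim=-1)

def matrix_route(xy, radius, angle):
    # Same map as a scaled 2x2 rotation matrix applied to each row
    c, s = math.cos(angle), math.sin(angle)
    rot = torch.tensor([[c, -s], [s, c]])
    return radius * xy @ rot.t()

torch.manual_seed(0)
xy = torch.randn(8, 2)
print(torch.allclose(polar_route(xy, 1.5, 0.7), matrix_route(xy, 1.5, 0.7), atol=1e-5))
```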
import torch
import torch.nn as nn
import torch.nn.functional as F
class GeometricLinearLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GeometricLinearLayer, self).__init__()
        self.phi = nn.Parameter(torch.randn(in_features))  # Direction parameters
        self.radius = nn.Parameter(torch.randn(1))  # Radius parameter
        self.bias = nn.Parameter(torch.randn(out_features))  # Bias
    def forward(self, x):
        u = self.phi / torch.norm(self.phi)  # Normalize to get unit vector
        r = torch.sigmoid(self.radius)  # Ensure radius is bounded [0, 1]
        y_prime = r * torch.matmul(x, u) + self.bias  # Linear transformation
        return y_prime
# Example usage
in_features, out_features = 10, 1
model = GeometricLinearLayer(in_features, out_features)
input_tensor = torch.randn(5, in_features)  # Batch size of 5
output = model(input_tensor)
print(output)
import torch
import torch.nn as nn
import torch.nn.functional as F
class GeometricLinearLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GeometricLinearLayer, self).__init__()
        self.phi = nn.Parameter(torch.randn(in_features))  # Direction parameters
        self.radius = nn.Parameter(torch.randn(1))  # Radius parameter
        self.bias = nn.Parameter(torch.randn(out_features))  # Bias
    def forward(self, x):
        u = self.phi / torch.norm(self.phi)  # Normalize to get unit vector
        u = u.unsqueeze(0)  # Add batch dimension for broadcasting
        r = torch.sigmoid(self.radius)  # Ensure radius is bounded [0, 1]
        # Perform element-wise multiplication and sum across in_features dimension
        # This is effectively a dot product, scaling each input by the radius
        y_prime = r * torch.sum(x * u, dim=1, keepdim=True) + self.bias
        return y_prime
import torch
import torch.nn as nn
import torch.nn.functional as F
class GeometricLinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, radius=1.0):
        super(GeometricLinearLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.radius = radius
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
        if self.bias is not None:
            nn.init.zeros_(self.bias)
    def forward(self, input):
        # Project the input onto the hypersphere
        input_norm = input.norm(dim=1, keepdim=True)
        input_normed = input / (input_norm + 1e-8)
        # Project the weights onto the hypersphere
        weight_norm = self.weight.norm(dim=1, keepdim=True)
        weight_normed = self.weight / (weight_norm + 1e-8)
        # Compute the output
        output = F.linear(input_normed, weight_normed, self.bias)
        # Bound the output to the specified radius (clamp avoids inf/NaN when the output is zero)
        output_norm = output.norm(dim=1, keepdim=True).clamp_min(1e-8)
        output_bounded = output * torch.minimum(self.radius / output_norm, torch.ones_like(output_norm))
        return output_bounded
import torch
import torch.nn as nn
class CauchyKernel(nn.Module):
    def __init__(self, in_features, out_features):
        super(CauchyKernel, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Learnable per-feature scales a and centers b, one row per output
        self.a = nn.Parameter(torch.randn(out_features, in_features))
        self.b = nn.Parameter(torch.randn(out_features, in_features))  # b matches the shape of a
    def forward(self, x):
        # Ensure x is 2D (batch_size, in_features)
        if x.dim() == 1:
            x = x.unsqueeze(0)
        # Scaled squared distance sum_k (a_jk * (x_k - b_jk))^2 for each output j
        # (using a as a per-feature scale; revise its role for other kernel variants)
        diff = x.unsqueeze(1) - self.b.unsqueeze(0)  # (batch, out_features, in_features)
        squared_distances = torch.sum((self.a.unsqueeze(0) * diff) ** 2, dim=2)
        # Cauchy kernel: the result always lies in (0, 1], so the output is bounded
        y = 1 / (1 + squared_distances)
        return y
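Because d² ≥ 0, the Cauchy kernel 1/(1 + d²) always lies in (0, 1]. A layer built from it therefore has bounded outputs regardless of input scale. The small standalone check below uses a hypothetical `cauchy_features` function that mirrors the distance computation in `CauchyKernel.forward` without any learnable scaling. It feeds very large inputs and also confirms that a point sitting exactly on a center maps to 1.

```python
import torch

def cauchy_features(x: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    # Squared Euclidean distance from every input row to every center: (batch, out)
    d2 = torch.cdist(x, centers) ** 2
    return 1.0 / (1.0 + d2)

torch.manual_seed(0)
centers = torch.randn(4, 3)
x = 1e3 * torch.randn(16, 3)          # very large inputs
y = cauchy_features(x, centers)
print(y.min().item() > 0, y.max().item() <= 1)
```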
import math
import torch
import torch.nn as nn
class BLO(nn.Module):
    def __init__(self, in_features, out_features, bias=True, bound=1.0):
        super(BLO, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bound = bound
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
    def forward(self, input):
        # Rescale every weight row to norm `bound` on each call, so |w_i . x| <= bound * ||x||
        # holds throughout training (normalizing only at initialization would not persist)
        weight = self.bound * self.weight / self.weight.norm(dim=1, keepdim=True).clamp_min(1e-12)
        output = torch.matmul(input, weight.t())
        if self.bias is not None:
            output = output + self.bias
        return output
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
# Define the standard neural network model
class StandardNN(nn.Module):
    def __init__(self):
        super(StandardNN, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features=2, out_features=10),
            nn.Tanh(),
            nn.Linear(in_features=10, out_features=10),
            nn.Tanh(),
            nn.Linear(in_features=10, out_features=10),
            nn.Tanh(),
            nn.Linear(10, 1)
        )
    def forward(self, x, y):
        xy = torch.cat((x, y), dim=1)
        return self.layers(xy)
# Generate training data
x_data, y_data = np.meshgrid(np.linspace(-np.pi, np.pi, 50), np.linspace(-np.pi, np.pi, 50))
f_data = np.sin(x_data) * np.sin(y_data)
x_data, y_data, f_data = torch.tensor(x_data.flatten()).float().view(-1, 1), torch.tensor(y_data.flatten()).float().view(-1, 1), torch.tensor(f_data.flatten()).float().view(-1, 1)
# Define model
model = StandardNN()
# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
losses = []
# Training loop
for epoch in tqdm(range(5000), desc='Data-driven model training progress'):
    optimizer.zero_grad()
    f_pred = model(x_data, y_data)
    loss = criterion(f_pred, f_data)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
# Plot the loss on a semilog scale
plt.figure()
plt.semilogy(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, which="both", ls="--", linewidth=0.5)
plt.show()
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
# Define the same network with BLO layers in place of nn.Linear
# (assumes the BLO class defined above is in scope)
class StandardNN(nn.Module):
    def __init__(self):
        super(StandardNN, self).__init__()
        nfeatures = 8
        self.layers = nn.Sequential(
            BLO(in_features=2, out_features=nfeatures),
            nn.Tanh(),
            BLO(in_features=nfeatures, out_features=nfeatures),
            nn.Tanh(),
            BLO(in_features=nfeatures, out_features=nfeatures),
            nn.Tanh(),
            BLO(nfeatures, 1)
        )
    def forward(self, x, y):
        xy = torch.cat((x, y), dim=1)
        return self.layers(xy)
# Generate training data
x_data, y_data = np.meshgrid(np.linspace(-np.pi, np.pi, 50), np.linspace(-np.pi, np.pi, 50))
f_data = np.sin(x_data) * np.sin(y_data)
x_data, y_data, f_data = torch.tensor(x_data.flatten()).float().view(-1, 1), torch.tensor(y_data.flatten()).float().view(-1, 1), torch.tensor(f_data.flatten()).float().view(-1, 1)
# Define model
model = StandardNN()
# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
losses = []
# Training loop
for epoch in tqdm(range(5000), desc='Data-driven model training progress'):
    optimizer.zero_grad()
    f_pred = model(x_data, y_data)
    loss = criterion(f_pred, f_data)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
# Plot the loss on a semilog scale
plt.figure()
plt.semilogy(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, which="both", ls="--", linewidth=0.5)
plt.show()
"""## Test the data-driven model"""
# Test the model
xyrange = 1 * np.pi
x_test, y_test = np.meshgrid(np.linspace(-xyrange, xyrange, 100), np.linspace(-xyrange, xyrange, 100))
f_test = np.sin(x_test) * np.sin(y_test)
x_test, y_test = torch.tensor(x_test).float().view(-1, 1), torch.tensor(y_test).float().view(-1, 1)
f_pred = model(x_test, y_test).detach().numpy().reshape(100, 100)
# Plotting
fig, ax = plt.subplots(1, 2)
im = ax[0].pcolor(x_test.numpy().reshape(100, 100), y_test.numpy().reshape(100, 100), f_test, cmap='seismic')
plt.colorbar(im, ax=ax[0])
im = ax[1].pcolor(x_test.numpy().reshape(100, 100), y_test.numpy().reshape(100, 100), f_pred, cmap='seismic')
plt.colorbar(im, ax=ax[1])
plt.show()
"""### Plot error"""
import plotly.graph_objects as go
error = f_test - f_pred
fig = go.Figure(data=[go.Surface(z=error, x=x_test.numpy().reshape(100, 100), y=y_test.numpy().reshape(100, 100))])
fig.update_layout(title='Error between prediction and original data',
                  autosize=False, width=800, height=700,
                  margin=dict(l=65, r=50, b=65, t=90))
fig.show()
"""### Extrapolate outside training range"""
# Test the model
xyrange = 2 * np.pi
x_test, y_test = np.meshgrid(np.linspace(-xyrange, xyrange, 100), np.linspace(-xyrange, xyrange, 100))
f_test = np.sin(x_test) * np.sin(y_test)
x_test, y_test = torch.tensor(x_test).float().view(-1, 1), torch.tensor(y_test).float().view(-1, 1)
f_pred = model(x_test, y_test).detach().numpy().reshape(100, 100)
# Plotting
fig, ax = plt.subplots(1, 2)
im = ax[0].pcolor(x_test.numpy().reshape(100, 100), y_test.numpy().reshape(100, 100), f_test, cmap='seismic')
plt.colorbar(im, ax=ax[0])
im = ax[1].pcolor(x_test.numpy().reshape(100, 100), y_test.numpy().reshape(100, 100), f_pred, cmap='seismic')
plt.colorbar(im, ax=ax[1])
plt.show()
"""We notice that the NN gives a reasonable approximation within the training range. However, when we extrapolate, it loses its predictive ability. This is typical of most NNs. Let's now explore how embedding physics in the NN helps reduce this lack of generalizability outside the training range."""

In this approach, instead of representing the weights (W) as a standard matrix, we reparameterize each weight vector as a point on a high-dimensional sphere with a fixed radius (R).

Let's consider a fully connected layer with input x ∈ R^n and output y ∈ R^m. The traditional linear transformation is:

y = Wx + b

Where W ∈ R^(m×n) is the weight matrix, and b ∈ R^m is the bias vector.

In the spherical weight parameterization, we represent the weight matrix W as a collection of m weight vectors, each lying on an n-dimensional sphere of radius R:

W = [w_1^T, w_2^T, ..., w_m^T]^T

Where:

w_i ∈ R^n is the i-th weight vector
||w_i|| = R for all i (i.e., each weight vector has a norm or length of R)

This parameterization bounds the linear transformation Wx relative to its input. Since each weight vector has norm R, the Cauchy–Schwarz inequality gives |w_i^T x| ≤ R ||x|| for every output component.
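By the Cauchy–Schwarz inequality, each pre-bias output satisfies |w_iᵀx| ≤ ||w_i|| ||x|| = R ||x||, so the bound scales with the input norm rather than being absolute. The quick numerical check below puts the rows on the sphere by plain normalization. This is an assumption made for brevity, not the angular parameterization described here.

```python
import torch

torch.manual_seed(0)
R = 2.0
raw = torch.randn(5, 8)
W = R * raw / torch.linalg.vector_norm(raw, dim=1, keepdim=True)   # every row has norm R
x = 10.0 * torch.randn(3, 8)

pre_bias = x @ W.t()                                   # (3, 5): w_i^T x for each input
bound = R * torch.linalg.vector_norm(x, dim=1, keepdim=True)
print(bool((pre_bias.abs() <= bound + 1e-3).all()))
```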

Mathematically, the linear transformation can be expressed as:

y_i = w_i^T x + b_i

Where y_i is the i-th output component, and b_i is the corresponding bias term.

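During training, instead of directly learning the weight matrix W, we learn the angular coordinates (θ_i,1, θ_i,2, ..., θ_i,n−1) that determine the direction of each weight vector w_i on the n-dimensional sphere. Because the radius is fixed, n − 1 angles suffice. The weight vectors are reconstructed with standard n-dimensional spherical coordinates:

w_i = R * [cos(θ_i,1), sin(θ_i,1) cos(θ_i,2), ..., sin(θ_i,1) ⋯ sin(θ_i,n−2) cos(θ_i,n−1), sin(θ_i,1) ⋯ sin(θ_i,n−1)]^T

This gives ||w_i|| = R for any choice of angles. The simpler form R * [cos(θ_i,1), ..., cos(θ_i,n)]^T does not lie on the sphere: it only bounds each component by R.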
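Standard n-dimensional spherical coordinates use n − 1 angles per vector: w₁ = R cos θ₁, w_k = R sin θ₁ ⋯ sin θ_{k−1} cos θ_k, and w_n = R sin θ₁ ⋯ sin θ_{n−1}. The squared components telescope to R², so the vector has norm exactly R whatever the angles. The sketch below, with the hypothetical function name `spherical_to_cartesian`, builds the components with a cumulative product of sines and checks the norms.

```python
import torch

def spherical_to_cartesian(angles: torch.Tensor, radius: float = 1.0) -> torch.Tensor:
    """Map (m, n-1) angles to (m, n) vectors of norm `radius`."""
    m = angles.shape[0]
    ones = torch.ones(m, 1, dtype=angles.dtype)
    # [1, s1, s1*s2, ..., s1*...*s_{n-1}]
    sin_prod = torch.cumprod(torch.cat([ones, torch.sin(angles)], dim=1), dim=1)
    # [c1, c2, ..., c_{n-1}, 1]
    cos_ext = torch.cat([torch.cos(angles), ones], dim=1)
    return radius * sin_prod * cos_ext

torch.manual_seed(0)
W = spherical_to_cartesian(torch.randn(4, 5), radius=3.0)   # 4 weight vectors in R^6
print(W.shape, torch.linalg.vector_norm(W, dim=1))
```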

By learning the angular coordinates and the radius R (which can be a learnable parameter or a fixed constant), we effectively constrain the weights to lie on the high-dimensional sphere, ensuring that the linear transformation remains bounded.

The bias terms b_i can be learned normally, as they do not affect the boundedness of the linear transformation.

This spherical weight parameterization separates the direction of each weight vector from its magnitude. The angles set the directions, and the single radius R carries the magnitude. Even when R is learnable and grows large, the layer's gain stays explicit: each output satisfies |y_i − b_i| ≤ R ||x||. This controls the weights geometrically and keeps the outputs bounded relative to the inputs, without relying on techniques like Highway Networks or explicit normalization.

import torch
import torch.nn as nn
class SphericalLinear(nn.Module):
    def __init__(self, in_features, out_features, radius=1.0):
        super(SphericalLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.radius = radius
        # n-dimensional spherical coordinates: n - 1 angles per weight vector
        self.angles = nn.Parameter(torch.randn(out_features, in_features - 1))
        # Initialize the bias terms
        self.bias = nn.Parameter(torch.zeros(out_features))
    def forward(self, input):
        # Convert the angles to unit vectors:
        # [cos t1, sin t1 cos t2, ..., sin t1 ... sin t_{n-1}]
        ones = torch.ones(self.out_features, 1, dtype=self.angles.dtype, device=self.angles.device)
        sin_prod = torch.cumprod(torch.cat([ones, torch.sin(self.angles)], dim=1), dim=1)
        cos_ext = torch.cat([torch.cos(self.angles), ones], dim=1)
        # Every weight vector now has norm exactly `radius`
        weights = self.radius * sin_prod * cos_ext
        # Compute the linear transformation (batch, in_features) -> (batch, out_features)
        return torch.matmul(input, weights.t()) + self.bias