@Algomancer
Last active June 19, 2024 02:13
import torch
import torch.nn as nn
import torch.nn.functional as F

class SubspaceLinear(nn.Linear):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the SubspaceLinear layer. Calls `subspace_weights` to sample from the
        subspace and uses the resulting weight and bias.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.

        Returns
        -------
        torch.Tensor
            The output tensor after applying the linear transformation.
        """
        w, b = self.subspace_weights()
        return F.linear(x, w, b)


class LineLinear(SubspaceLinear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Second endpoint of the line segment in weight space.
        self.weight_alt = nn.Parameter(torch.zeros_like(self.weight))
        self.bias_alt = nn.Parameter(torch.zeros_like(self.bias))
        # Interpolation coefficient; default 0.0, expected to be reset externally
        # (e.g. resampled each training step, see the note below).
        self.alpha = 0.0

    def initialize_parameters(self, init_fn) -> None:
        """
        Initialize the additional weights using a provided function.

        Parameters
        ----------
        init_fn : Callable
            The function to initialize the weights.
        """
        init_fn(self.weight_alt)
        init_fn(self.bias_alt)

    def subspace_weights(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the weight and bias as a linear combination of the two sets of parameters.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The combined weight and bias tensors.
        """
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight_alt
        b = (1 - self.alpha) * self.bias + self.alpha * self.bias_alt
        return w, b


# Then, during the training loop, randomly set alpha ~ Uniform[0, 1] each step.
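
A minimal training-loop sketch of that last comment, not part of the original gist: it assumes a model built from LineLinear layers, and the helper `set_alpha`, the `model` architecture, and `data_loader` are all hypothetical names introduced here for illustration.

# Sketch only: shares one freshly sampled alpha across all subspace layers per step.
def set_alpha(model: nn.Module, alpha: float) -> None:
    for module in model.modules():
        if isinstance(module, SubspaceLinear):
            module.alpha = alpha


model = nn.Sequential(
    LineLinear(784, 256),
    nn.ReLU(),
    LineLinear(256, 10),
)

# Give the second endpoint its own random initialization so the two ends of the
# line start at different points in weight space.
for module in model.modules():
    if isinstance(module, LineLinear):
        module.initialize_parameters(nn.init.normal_)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for x, y in data_loader:  # data_loader is assumed to yield (input, target) batches
    set_alpha(model, torch.rand(1).item())  # alpha ~ Uniform[0, 1] per step
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

At evaluation time a fixed alpha (for example 0.5, the midpoint of the line) picks out a single network from the learned subspace.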