@V0XNIHILI
Last active June 8, 2023 17:29
Initial take at a variable-size linear output layer: a PyTorch sketch in which the final classification layer grows by a fixed number of outputs after each training step while keeping its already-trained weights.
import torch
import torch.nn as nn

class Net(nn.Module):
    """Small MLP whose final linear layer can be expanded over time."""

    def __init__(self, input_size=16, hidden_size=32, initial_output_size=5):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.initial_output_size = initial_output_size

        self.embedder = nn.Linear(self.input_size, self.hidden_size)
        self.linear = nn.Linear(self.hidden_size, initial_output_size)

    def forward(self, x: torch.Tensor):
        x = self.embedder(x)
        x = self.linear(x)

        return x
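
# Illustrative shape check (added; not part of the original gist): with the
# default sizes, Net maps (batch, 16) -> (batch, 5).
_shape_net = Net()
assert _shape_net(torch.randn(2, 16)).shape == (2, 5)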

def expand_linear_layer(linear_layer: nn.Linear, output_size_increment: int):
    """Expand the output size of a linear layer in place.

    Args:
        linear_layer (nn.Linear): Linear layer to be expanded.
        output_size_increment (int): Number of outputs to add.
    """
    input_size = linear_layer.in_features
    output_size = linear_layer.out_features
    new_output_size = output_size + output_size_increment

    # New rows are drawn from a standard normal, which is coarser than
    # nn.Linear's default (Kaiming uniform) initialization.
    new_weight = torch.randn(new_output_size, input_size).to(linear_layer.weight.device)
    new_bias = torch.randn(new_output_size).to(linear_layer.bias.device)

    # Keep the already-trained parameters in the first `output_size` rows
    new_weight[:output_size] = linear_layer.weight.data
    new_bias[:output_size] = linear_layer.bias.data

    linear_layer.out_features = new_output_size
    linear_layer.weight = nn.Parameter(new_weight)
    linear_layer.bias = nn.Parameter(new_bias)
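
# Illustrative sanity check (added; not part of the original gist): expanding
# a layer keeps the trained rows intact and only appends new random ones.
_layer = nn.Linear(3, 2)
_old_weight = _layer.weight.data.clone()
_old_bias = _layer.bias.data.clone()
expand_linear_layer(_layer, 3)
assert _layer.out_features == 5
assert torch.equal(_layer.weight.data[:2], _old_weight)
assert torch.equal(_layer.bias.data[:2], _old_bias)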

def compute_restricted_outputs(outputs: torch.Tensor, output_size_increment: int):
    """Mask out all but the last `output_size_increment` outputs.

    Args:
        outputs (torch.Tensor): Output tensor from a linear layer.
        output_size_increment (int): Number of trailing outputs to keep.

    Returns:
        torch.Tensor: Masked output tensor.
    """
    num_classes = outputs.size(1)

    restricted_mask = torch.zeros(num_classes)
    restricted_mask[-output_size_increment:] = 1.0
    restricted_mask = restricted_mask.to(outputs.device)

    # Apply the restricted mask to the outputs. Masked logits are pinned at 0
    # rather than -inf, so older classes still get softmax probability under
    # cross-entropy, but their weights receive no gradient through the mask.
    restricted_outputs = outputs * restricted_mask

    return restricted_outputs
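
# Illustrative example (added): keeping the last 2 of 4 logits zeroes out the
# first two entries.
_logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
assert torch.equal(compute_restricted_outputs(_logits, 2),
                   torch.tensor([[0.0, 0.0, 3.0, 4.0]]))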

# -------------------------------------------------------------------------------

batch_size = 32
initial_linear_layer_size = 5
output_size_increment = 5
steps = 4
input_size = 16
hidden_size = 4

net = Net(input_size, hidden_size, initial_linear_layer_size)
criterion = nn.CrossEntropyLoss()

for step in range(steps):
    # Re-create the optimizer every step: expansion replaces the linear
    # layer's parameters, so a stale optimizer would update the old tensors.
    opt = torch.optim.Adam(net.parameters(), lr=0.01)

    inputs = torch.randn(batch_size, input_size)
    targets = torch.randint(0, output_size_increment, (batch_size,))

    outputs = net(inputs)
    outputs = compute_restricted_outputs(outputs, output_size_increment)

    # Shift the targets into the newest block of classes
    loss = criterion(outputs, targets + net.linear.out_features - output_size_increment)

    opt.zero_grad()
    loss.backward()
    opt.step()

    # Grow the output layer before every step except the last
    if step != steps - 1:
        expand_linear_layer(net.linear, output_size_increment)
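
# Illustrative final check (added): the layer was expanded (steps - 1) times,
# so it should end with
# initial_linear_layer_size + (steps - 1) * output_size_increment = 20 outputs.
assert net.linear.out_features == initial_linear_layer_size + (steps - 1) * output_size_increment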