Skip to content

Instantly share code, notes, and snippets.

@tljstewart
Last active September 16, 2023 02:15
Show Gist options
  • Save tljstewart/29803625455a4a6df3ae760fe655cef7 to your computer and use it in GitHub Desktop.
Save tljstewart/29803625455a4a6df3ae760fe655cef7 to your computer and use it in GitHub Desktop.
Testing the concept of zero-initialized CNN layers ("zero convolutions") from the ControlNet paper
import torch
import torch.nn as nn
import torch.optim as optim
# Synthetic regression data: a batch of random 3-channel 32x32 inputs paired
# with random 10-dimensional targets. No seed is set, so values differ per run.
batch_size, channels, height, width = 32, 3, 32, 32
x_train = torch.randn((batch_size, channels, height, width))
y_train = torch.randn((batch_size, 10))
# Define the original model
class OriginalCNN(nn.Module):
def __init__(self):
super(OriginalCNN, self).__init__()
self.conv1 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(16*32*32, 10)
def forward(self, x):
x = self.conv1(x)
x = x.view(x.size(0), -1)
x = self.fc1(x)
return x
def insert_optimizer_params(optimizer, new_layer, **group_options):
    """Register ``new_layer``'s parameters with an existing optimizer.

    The parameters are added as one new parameter group through
    ``Optimizer.add_param_group``. Appending them to every existing group's
    ``params`` list would, for an optimizer with several groups, put the same
    tensors in each group, so they would be updated once per group on every
    ``step()``. ``add_param_group`` also rejects parameters the optimizer
    already manages.

    The new group copies the hyperparameters (``lr``, ``momentum``, ...) of
    the optimizer's first group, so the new layer trains like the existing
    parameters. Any ``group_options`` override those copied values.

    Args:
        optimizer: A ``torch.optim.Optimizer`` (always has >= 1 param group).
        new_layer: Module whose ``parameters()`` should start being optimized.
        **group_options: Optional hyperparameter overrides for the new group.

    Raises:
        ValueError: If a parameter is already in the optimizer
            (raised by ``add_param_group``).
    """
    params = list(new_layer.parameters())
    if not params:
        # Nothing to optimize (e.g. an activation module): keep this a no-op.
        return
    group = {
        key: value
        for key, value in optimizer.param_groups[0].items()
        if key != "params"
    }
    group.update(group_options)
    group["params"] = params
    optimizer.add_param_group(group)
# Phase 1: fit the baseline model on the synthetic data with plain SGD.
model = OriginalCNN()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

for _ in range(100):
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer.step()

print("Weights of the original convolutional layer:")
# First output filter of conv1 after phase 1.
print(model.conv1.weight.detach()[0])
# Phase 2 setup: attach a 16->16 3x3 convolution whose weight and bias start
# at exactly zero (the zero-initialized layer being tested).
model.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
nn.init.zeros_(model.conv2.weight)
nn.init.zeros_(model.conv2.bias)
# Forward pass used once conv2 has been attached to the model.
def new_forward(self, x):
    """Run conv1, a zero-initialized residual conv2, then the linear head.

    ``conv2`` is applied as a residual branch (``h + conv2(h)``) rather than
    in series. While its weight and bias are zero the branch contributes
    nothing, so right after insertion the model computes exactly what the
    trained single-conv model did. Gradients still reach ``conv2``'s weight,
    because they are driven by the nonzero conv1 features, so the layer
    learns from there.

    Placed in series, a zero ``conv2`` would output all zeros. The model's
    output would collapse to ``fc1``'s bias and the pretrained behaviour
    would be discarded, which is the opposite of the zero-convolution intent.

    Requires ``conv2`` to preserve the shape of conv1's output (same channel
    count, 3x3 kernel with stride 1 and padding 1).
    """
    h = self.conv1(x)
    h = h + self.conv2(h)
    h = h.view(h.size(0), -1)
    return self.fc1(h)
# Swap in the new forward pass on this instance only. The optimizer must also
# learn about conv2's parameters; otherwise its weights would stay at zero.
model.forward = new_forward.__get__(model)

REUSEOPTIMIZER = True
if not REUSEOPTIMIZER:
    # Start over with a fresh optimizer covering every parameter, conv2 included.
    optimizer = optim.SGD(model.parameters(), lr=0.01)
else:
    # Keep the phase-1 optimizer and register conv2's parameters with it.
    insert_optimizer_params(optimizer, model.conv2)
# Phase 2: keep training with the new forward pass, then show one filter from
# both the original and the newly added convolution.
for _ in range(100):
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer.step()

for heading, layer in (
    ("Weights of the original convolutional layer AFTER train:\n", model.conv1),
    ("Weights of the new convolutional layer AFTER train:\n", model.conv2),
):
    print(heading)
    print(layer.weight.detach()[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment