Skip to content

Instantly share code, notes, and snippets.

@crazyoscarchang
Last active January 26, 2023 10:18
Show Gist options
  • Star 7 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save crazyoscarchang/c9a11b67c420202da1f26e0d20786750 to your computer and use it in GitHub Desktop.
Save crazyoscarchang/c9a11b67c420202da1f26e0d20786750 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import math
import numpy as np
# Hard-coded lookup tables for the hyperfan initialisers.
# Both lists are indexed by the same slice index ``i`` that is passed to
# hyperfaninWi_init / hyperfanoutWi_init.
# NOTE(review): the values look like the fan-in / fan-out of the main-network
# layer each slice feeds (3, 96, 192 channels, n_classes) - confirm.
hardcoded_input_size = 3
hardcoded_n_classes = 10

hardcoded_hyperfanin = (
    [hardcoded_input_size] * hardcoded_input_size
    + [96] * (96 * 4)
    + [192] * (192 * 8 + 2 * hardcoded_n_classes)
)

hardcoded_hyperfanout = (
    [96] * (hardcoded_input_size + 96 * 2)
    + [192] * (192 * 9)
    + [hardcoded_n_classes] * (2 * hardcoded_n_classes)
)


def hardcoded_receptive(i):
    """Return 9 for slice indices below ``input_size + 192*8``, otherwise 1.

    NOTE(review): presumably 9 is the 3x3 kernel area and 1 the effective
    1x1 kernel of the last two layers - confirm.
    """
    threshold = hardcoded_input_size + 192 * 8
    if i < threshold:
        return 9
    return 1
def hyperfaninWi_init(i):
    """Return the in-place initialiser for hypernet slice ``i``.

    The returned callable fills a tensor ``Wi`` (at least 2-D) with samples
    from U(-b, b), where
    ``b = sqrt(6 / (Wi.size(1) * hardcoded_hyperfanin[i]) / hardcoded_receptive(i))``,
    and returns that same tensor.
    """
    def hyperfanin_init(Wi):
        columns = Wi.size(1)
        scaled = 6 / (columns * hardcoded_hyperfanin[i])
        limit = math.sqrt(scaled / hardcoded_receptive(i))
        # uniform_ mutates Wi and returns it, so the caller gets Wi back.
        return Wi.uniform_(-limit, limit)

    return hyperfanin_init
def hyperfanoutWi_init(i):
    """Return the in-place initialiser for hypernet slice ``i`` (fan-out table).

    The returned callable fills a tensor ``Wi`` (at least 2-D) with samples
    from U(-b, b), where
    ``b = sqrt(6 / (Wi.size(1) * hardcoded_hyperfanout[i]) / hardcoded_receptive(i))``,
    and returns that same tensor.
    """
    def hyperfanout_init(Wi):
        columns = Wi.size(1)
        scaled = 6 / (columns * hardcoded_hyperfanout[i])
        limit = math.sqrt(scaled / hardcoded_receptive(i))
        # uniform_ mutates Wi and returns it, so the caller gets Wi back.
        return Wi.uniform_(-limit, limit)

    return hyperfanout_init
def fanin_uniform(W):
    """Fill ``W`` in place from U(-b, b) with ``b = sqrt(6 / W.size(1))``.

    Returns the same tensor object that was passed in.
    """
    limit = math.sqrt(6 / W.size(1))
    return W.uniform_(-limit, limit)
def embed_uniform(e):
    """Fill ``e`` in place from U(-sqrt(3), sqrt(3)) and return it."""
    limit = math.sqrt(3)
    return e.uniform_(-limit, limit)
# Adapted from https://github.com/StefOe/all-conv-pytorch/blob/master/allconv.py
class AllConvNet(nn.Module):
    """All-convolutional classifier whose weights are read from attributes.

    The class creates no convolution parameters itself. ``forward`` reads
    ``conv1_weight`` ... ``conv8_weight``, ``class_conv_weight`` and the
    matching ``*_bias`` attributes from ``self``, so they must be set (by a
    subclass or by the caller) before ``forward`` is invoked.
    """

    def __init__(self, input_size, n_classes):
        super().__init__()
        self.input_size = input_size
        self.n_classes = n_classes

    def forward(self, x):
        """Run the nine convolutions and return the pooled class scores.

        For a batched 4-D input the last two (1x1 pooled) dimensions are
        squeezed away, giving one score per class per sample.
        """
        # F.dropout defaults to training=True; pass the module's mode
        # explicitly so that dropout is really switched off by model.eval().
        training = self.training

        out = F.dropout(x, .2, training=training)
        out = F.relu(F.conv2d(out, self.conv1_weight, self.conv1_bias, padding=1))
        out = F.relu(F.conv2d(out, self.conv2_weight, self.conv2_bias, padding=1))
        out = F.relu(F.conv2d(out, self.conv3_weight, self.conv3_bias, padding=1, stride=2))

        out = F.dropout(out, .5, training=training)
        out = F.relu(F.conv2d(out, self.conv4_weight, self.conv4_bias, padding=1))
        out = F.relu(F.conv2d(out, self.conv5_weight, self.conv5_bias, padding=1))
        out = F.relu(F.conv2d(out, self.conv6_weight, self.conv6_bias, padding=1, stride=2))

        out = F.dropout(out, .5, training=training)
        out = F.relu(F.conv2d(out, self.conv7_weight, self.conv7_bias, padding=1))
        # The last two convolutions use no padding.
        out = F.relu(F.conv2d(out, self.conv8_weight, self.conv8_bias))
        out = F.relu(F.conv2d(out, self.class_conv_weight, self.class_conv_bias))

        # Global average pool to 1x1, then drop the two singleton dimensions.
        pooled = F.adaptive_avg_pool2d(out, 1)
        return pooled.squeeze(-1).squeeze(-1)
class HyperNN(AllConvNet):
    """AllConvNet whose convolution weights are generated by a hypernetwork.

    Every forward pass recomputes all nine convolution weight tensors from
    fixed per-slice embeddings and the trainable tensors ``Wi``, ``Bi``,
    ``Wout`` and ``Bout``. Each slice yields ``96*9`` values. The
    convolution biases are ordinary trainable parameters.

    Args:
        input_size: number of input channels of the first convolution.
        n_classes: number of output classes.
        embed_size: length of each slice embedding.
        embedW_init_scheme: callable that initialises the
            ``(num_kernels, embed_size)`` embedding tensor and returns it.
        hyperWi_init_scheme: callable ``i -> init_fn``; ``init_fn`` initialises
            the ``(embed_size, embed_size)`` matrix of slice ``i``.
        hyperWout_init_scheme: callable that initialises the
            ``(96*9, embed_size)`` output matrix and returns it.
        device: device on which the embeddings are created.
    """

    def __init__(self, input_size, n_classes, embed_size, embedW_init_scheme,
                 hyperWi_init_scheme, hyperWout_init_scheme, device):
        super().__init__(input_size, n_classes)

        # Initialize the fixed parameters
        self.num_kernels = input_size + 2*n_classes + 1920  # 96*2 + 192 + (192*2)*4
        # The embeddings are random but never trained. Registering them as a
        # buffer keeps them in state_dict() (so a checkpoint can reproduce
        # the generated weights) and lets model.to(...) move them.
        self.register_buffer(
            'weight_embeddings',
            embedW_init_scheme(torch.zeros(self.num_kernels, embed_size).to(device)),
        )

        # Initialize the trainable weight parameters
        Wi = torch.zeros(self.num_kernels, embed_size, embed_size)
        for i in range(self.num_kernels):
            Wi[i] = hyperWi_init_scheme(i)(Wi[i])
        Bi = torch.zeros(self.num_kernels, embed_size)
        Wout = hyperWout_init_scheme(torch.zeros(96*9, embed_size))
        Bout = torch.zeros(96*9)

        # Register the trainable weight parameters
        self.Wi = nn.Parameter(Wi)
        self.Bi = nn.Parameter(Bi)
        self.Wout = nn.Parameter(Wout)
        self.Bout = nn.Parameter(Bout)

        # Initialize and register the trainable bias parameters
        self.conv1_bias = nn.Parameter(torch.zeros(96))
        self.conv2_bias = nn.Parameter(torch.zeros(96))
        self.conv3_bias = nn.Parameter(torch.zeros(96))
        self.conv4_bias = nn.Parameter(torch.zeros(192))
        self.conv5_bias = nn.Parameter(torch.zeros(192))
        self.conv6_bias = nn.Parameter(torch.zeros(192))
        self.conv7_bias = nn.Parameter(torch.zeros(192))
        self.conv8_bias = nn.Parameter(torch.zeros(192))
        self.class_conv_bias = nn.Parameter(torch.zeros(n_classes))

    def _generate_weight(self, start, count, shape):
        """Generate ``count`` slices beginning at ``start``, viewed as ``shape``."""
        stop = start + count
        # (count, E, E) @ (count, E, 1) + (count, E, 1) -> (count, E, 1)
        hidden = (self.Wi[start:stop] @ self.weight_embeddings[start:stop].unsqueeze(2)
                  + self.Bi[start:stop].unsqueeze(2))
        # (96*9, E) @ (count, E, 1) -> (count, 96*9, 1) -> (count, 96*9)
        flat = (self.Wout @ hidden).squeeze(2) + self.Bout.unsqueeze(0)
        return flat.view(*shape)

    def forward(self, x):
        """Regenerate every convolution weight, then run AllConvNet.forward."""
        n_in, n_cls = self.input_size, self.n_classes
        # (attribute name, number of slices, weight shape, keep only 1x1 corner)
        layout = (
            ('conv1_weight', n_in, (96, n_in, 3, 3), False),
            ('conv2_weight', 96, (96, 96, 3, 3), False),
            ('conv3_weight', 96, (96, 96, 3, 3), False),
            ('conv4_weight', 192, (192, 96, 3, 3), False),
            ('conv5_weight', 192*2, (192, 192, 3, 3), False),
            ('conv6_weight', 192*2, (192, 192, 3, 3), False),
            ('conv7_weight', 192*2, (192, 192, 3, 3), False),
            ('conv8_weight', 192*2, (192, 192, 3, 3), True),
            ('class_conv_weight', n_cls * 2, (n_cls, 192, 3, 3), True),
        )

        # Generate main weights from HyperNet's parameters
        idx = 0
        for name, count, shape, one_by_one in layout:
            weight = self._generate_weight(idx, count, shape)
            if one_by_one:
                # Only the top-left entry of each generated 3x3 kernel is
                # used, turning it into a 1x1 convolution.
                weight = weight[:, :, :1, :1]
            setattr(self, name, weight)
            idx += count
        return super().forward(x)
# Configuration
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the script still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
embed_size = 50
embedW_init_scheme = embed_uniform
hyperWi_init_scheme = hyperfaninWi_init  # hyperfanoutWi_init
hyperWout_init_scheme = fanin_uniform
lr = 0.0005
training_batch_size = 100
test_batch_size = 1000
epochs = 500
log_interval = 100
seed = 123
torch.manual_seed(seed)
# The training criterion averages over each batch; the test criterion sums
# so that test() can divide the total by the dataset size.
train_criterion = nn.CrossEntropyLoss(reduction='mean')
test_criterion = nn.CrossEntropyLoss(reduction='sum')
# Data for Loss Plots (one entry appended per epoch by train() / test())
train_loss_list = []
test_loss_list = []
test_acc_list = []
# Training/Testing Functions
def train(model, device, train_loader, optimizer, epoch, log_interval, lr_scheduler):
    """Train ``model`` for one epoch and record the mean per-sample loss.

    Args:
        model: network to optimise; switched to training mode.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches exposing
            ``.dataset`` and ``len()``.
        optimizer: optimiser stepped once per batch.
        epoch: epoch number, used only in the progress message.
        log_interval: print progress every ``log_interval`` batches.
        lr_scheduler: scheduler stepped once at the end of the epoch.

    Side effects:
        Appends the epoch's mean per-sample loss to the module-level
        ``train_loss_list``.
    """
    model.train()
    weighted_loss = 0.0
    samples_seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = train_criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = len(data)
        # train_criterion is configured with reduction='mean', so each batch
        # loss is already a per-sample average. Weight it by the batch size
        # before accumulating; dividing a plain sum of batch means by the
        # sample count would understate the loss by roughly the batch size.
        weighted_loss += batch_loss * batch_size
        samples_seen += batch_size

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), batch_loss))

    # Guard against an empty loader so the epoch still records a value.
    epoch_loss = weighted_loss / samples_seen if samples_seen else 0.0
    train_loss_list.append(epoch_loss)
    lr_scheduler.step()
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and log loss and accuracy.

    The model is put in eval mode and run without gradient tracking. The
    summed loss and the correct-prediction count are divided by
    ``len(test_loader.dataset)``; the resulting average loss and accuracy
    (in percent) are appended to the module-level ``test_loss_list`` and
    ``test_acc_list`` and printed.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs)
            loss_sum += test_criterion(scores, labels).item()
            # Predicted class = index of the largest score, kept as a column.
            predicted = scores.argmax(dim=1, keepdim=True)
            hits = predicted.eq(labels.view_as(predicted))
            n_correct += hits.sum().item()

    n_samples = len(test_loader.dataset)
    mean_loss = loss_sum / n_samples
    test_loss_list.append(mean_loss)
    accuracy = 100. * n_correct / n_samples
    test_acc_list.append(accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        mean_loss, n_correct, n_samples,
        accuracy))
# The DataLoaders below use worker processes (num_workers=2). On platforms
# that start workers with "spawn", each worker re-imports this module, so the
# driver code must sit behind a __main__ guard or it would run again in every
# worker. Names assigned here are still module-level globals.
if __name__ == '__main__':
    # CIFAR10 Data Loaders
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=training_batch_size, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, num_workers=2)

    # Model and Optimizer
    model = HyperNN(hardcoded_input_size, hardcoded_n_classes, embed_size, embedW_init_scheme,
                    hyperWi_init_scheme, hyperWout_init_scheme, device).to(device)
    num_params = sum(param.numel() for param in model.parameters())
    print("number of params:", num_params)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # The learning rate is multiplied by 0.1 at epochs 350 and 450.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[350, 450], gamma=0.1)

    # Actual HyperNet Training
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval, lr_scheduler)
        test(model, device, test_loader)

    # Save Experiment
    result_dict = {
        'train_loss_list': np.array(train_loss_list),
        'test_loss_list': np.array(test_loss_list),
        'test_acc_list': np.array(test_acc_list),
    }
    torch.save(result_dict, 'results.dict')
@OhadRubin
Copy link

OhadRubin commented Oct 16, 2020

Can you publish a version without hardcoded values?
Thanks!

@Lovegood-1
Copy link

Hi,
I have been interested in hypernetworks. The work you and your team have done is useful. However, there is a line in your code that is difficult for me to understand, even though I have read your paper several times.

My question is :
What does 'hardcoded_receptive' mean?

It would be helpful if you could give me a hint!

Thanks!

@crazyoscarchang
Copy link
Author

Sorry for the late reply. We hardcode values depending on the sizes of the layers in the All Convolutional Net. In general, it might be difficult to have a hypernet initialization abstract enough that it can work off-the-shelf for different kinds of architectures in both the mainnet and the hypernet. We recommend adopting the general principle of preserving variance through the mainnet, and applying it to your specific neural network architecture.

@OhadRubin
Copy link

OhadRubin commented Aug 1, 2021

Can you give an example for a FeedForward network?
A simple one hidden layer network.
Your paper is very difficult to understand.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment