# hypernet_cifar10.py
# Gist by @insujeon, forked from crazyoscarchang/hypernet_cifar10.py (created December 27, 2022).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import math
import numpy as np
# Hardcoded variables for hyperfan init
hardcoded_input_size = 3
hardcoded_n_classes = 10
hardcoded_hyperfanin = [hardcoded_input_size]*hardcoded_input_size + [96]*96*4 + [192]*(192*8 + 2*hardcoded_n_classes)
hardcoded_hyperfanout = [96]*(hardcoded_input_size + 96*2) + [192]*(192*9) + [hardcoded_n_classes]*2*hardcoded_n_classes
hardcoded_receptive = lambda i: 9 if i < hardcoded_input_size + 192*8 else 1
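# Sanity check (added; not in the original gist): every generated kernel slice needs a fan-in and a
# fan-out entry, so both lists should have the same length -- 1943 for the 3-channel, 10-class
# AllConvNet below (3 + 96 + 96 + 192 + 4*384 + 20 kernel chunks).
assert len(hardcoded_hyperfanin) == len(hardcoded_hyperfanout) == 1943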
def hyperfaninWi_init(i):
    def hyperfanin_init(Wi):
        fan_out, fan_in = Wi.size(0), Wi.size(1)
        bound = math.sqrt(3*2 / (fan_in * hardcoded_hyperfanin[i]) / hardcoded_receptive(i))
        Wi.uniform_(-bound, bound)
        return Wi
    return hyperfanin_init

def hyperfanoutWi_init(i):
    def hyperfanout_init(Wi):
        fan_out, fan_in = Wi.size(0), Wi.size(1)
        bound = math.sqrt(3*2 / (fan_in * hardcoded_hyperfanout[i]) / hardcoded_receptive(i))
        Wi.uniform_(-bound, bound)
        return Wi
    return hyperfanout_init

def fanin_uniform(W):
    fan_out, fan_in = W.size(0), W.size(1)
    bound = math.sqrt(3*2 / fan_in)
    W.uniform_(-bound, bound)
    return W

def embed_uniform(e):
    bound = math.sqrt(3)
    e.uniform_(-bound, bound)
    return e
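# Note on the bounds above (added for clarity): for W ~ Uniform(-b, b), Var(W) = b^2 / 3, so
# b = sqrt(3 * 2 / (fan_in * k * receptive)) targets Var(W) = 2 / (fan_in * k * receptive),
# where k is the main network's fan-in (hyperfan-in) or fan-out (hyperfan-out) for kernel i.
# This appears to follow the hyperfan initialization schemes of "Principled Weight
# Initialization for Hypernetworks" (Chang et al., ICLR 2020).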
# Adapted from https://github.com/StefOe/all-conv-pytorch/blob/master/allconv.py
class AllConvNet(nn.Module):
    def __init__(self, input_size, n_classes):
        super(AllConvNet, self).__init__()
        self.input_size = input_size
        self.n_classes = n_classes

    def forward(self, x):
        # training=self.training so dropout is only applied during training, not at evaluation
        x_drop = F.dropout(x, .2, training=self.training)
        conv1_out = F.relu(F.conv2d(x_drop, self.conv1_weight, self.conv1_bias, padding=1))
        conv2_out = F.relu(F.conv2d(conv1_out, self.conv2_weight, self.conv2_bias, padding=1))
        conv3_out = F.relu(F.conv2d(conv2_out, self.conv3_weight, self.conv3_bias, padding=1, stride=2))
        conv3_out_drop = F.dropout(conv3_out, .5, training=self.training)
        conv4_out = F.relu(F.conv2d(conv3_out_drop, self.conv4_weight, self.conv4_bias, padding=1))
        conv5_out = F.relu(F.conv2d(conv4_out, self.conv5_weight, self.conv5_bias, padding=1))
        conv6_out = F.relu(F.conv2d(conv5_out, self.conv6_weight, self.conv6_bias, padding=1, stride=2))
        conv6_out_drop = F.dropout(conv6_out, .5, training=self.training)
        conv7_out = F.relu(F.conv2d(conv6_out_drop, self.conv7_weight, self.conv7_bias, padding=1))
        conv8_out = F.relu(F.conv2d(conv7_out, self.conv8_weight, self.conv8_bias))
        class_out = F.relu(F.conv2d(conv8_out, self.class_conv_weight, self.class_conv_bias))
        pool_out = F.adaptive_avg_pool2d(class_out, 1)
        pool_out.squeeze_(-1)
        pool_out.squeeze_(-1)
        return pool_out
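# For reference (added), the generated weights AllConvNet.forward expects, supplied by HyperNN below:
#   conv1: (96, 3, 3, 3)     conv2: (96, 96, 3, 3)    conv3: (96, 96, 3, 3), stride 2
#   conv4: (192, 96, 3, 3)   conv5: (192, 192, 3, 3)  conv6: (192, 192, 3, 3), stride 2
#   conv7: (192, 192, 3, 3)  conv8: (192, 192, 1, 1)  class_conv: (n_classes, 192, 1, 1)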
class HyperNN(AllConvNet):
    def __init__(self, input_size, n_classes, embed_size, embedW_init_scheme,
                 hyperWi_init_scheme, hyperWout_init_scheme, device):
        super().__init__(input_size, n_classes)
        # Initialize the fixed (non-trainable) kernel embeddings
        self.num_kernels = input_size + 2*n_classes + 1920  # 96*2 + 192 + (192*2)*4
        self.weight_embeddings = embedW_init_scheme(torch.zeros(self.num_kernels, embed_size).to(device))
        # Initialize the trainable weight parameters
        Wi = torch.zeros(self.num_kernels, embed_size, embed_size)
        for i in range(self.num_kernels):
            Wi[i] = hyperWi_init_scheme(i)(Wi[i])
        Bi = torch.zeros(self.num_kernels, embed_size)
        Wout = hyperWout_init_scheme(torch.zeros(96*9, embed_size))
        Bout = torch.zeros(96*9)
        # Register the trainable weight parameters
        self.Wi = nn.Parameter(Wi)
        self.Bi = nn.Parameter(Bi)
        self.Wout = nn.Parameter(Wout)
        self.Bout = nn.Parameter(Bout)
        # Initialize and register the trainable bias parameters
        self.conv1_bias = nn.Parameter(torch.zeros(96))
        self.conv2_bias = nn.Parameter(torch.zeros(96))
        self.conv3_bias = nn.Parameter(torch.zeros(96))
        self.conv4_bias = nn.Parameter(torch.zeros(192))
        self.conv5_bias = nn.Parameter(torch.zeros(192))
        self.conv6_bias = nn.Parameter(torch.zeros(192))
        self.conv7_bias = nn.Parameter(torch.zeros(192))
        self.conv8_bias = nn.Parameter(torch.zeros(192))
        self.class_conv_bias = nn.Parameter(torch.zeros(n_classes))
    def forward(self, x):
        # Generate main weights from HyperNet's parameters
        idx = 0; jump = self.input_size
        self.conv1_weight = ((self.Wout @
            (self.Wi[idx:idx+jump] @ self.weight_embeddings[idx:idx+jump].unsqueeze(2) +
             self.Bi[idx:idx+jump].unsqueeze(2))).squeeze(2)
            + self.Bout.unsqueeze(0)).view(96, self.input_size, 3, 3)
        idx += jump; jump = 96
        self.conv2_weight = ((self.Wout @
            (self.Wi[idx:idx+jump] @ self.weight_embeddings[idx:idx+jump].unsqueeze(2) +
             self.Bi[idx:idx+jump].unsqueeze(2))).squeeze(2)
            + self.Bout.unsqueeze(0)).view(96, 96, 3, 3)
        idx += jump
        self.conv3_weight = ((self.Wout @
            (self.Wi[idx:idx+jump] @ self.weight_embeddings[idx:idx+jump].unsqueeze(2) +
             self.Bi[idx:idx+jump].unsqueeze(2))).squeeze(2)
            + self.Bout.unsqueeze(0)).view(96, 96, 3, 3)
        idx += jump; jump = 192
        self.conv4_weight = ((self.Wout @
            (self.Wi[idx:idx+jump] @ self.weight_embeddings[idx:idx+jump].unsqueeze(2) +
             self.Bi[idx:idx+jump].unsqueeze(2))).squeeze(2)
            + self.Bout.unsqueeze(0)).view(192, 96, 3, 3)
        idx += jump; jump = 192*2
        self.conv5_weight = ((self.Wout @
            (self.Wi[idx:idx+jump] @ self.weight_embeddings[idx:idx+jump].unsqueeze(2) +
             self.Bi[idx:idx+jump].unsqueeze(2))).squeeze(2)
            + self.Bout.unsqueeze(0)).view(192, 192, 3, 3)
        idx += jump
        self.conv6_weight = ((self.Wout @
            (self.Wi[idx:idx+jump] @ self.weight_embeddings[idx:idx+jump].unsqueeze(2) +
             self.Bi[idx:idx+jump].unsqueeze(2))).squeeze(2)
            + self.Bout.unsqueeze(0)).view(192, 192, 3, 3)
        idx += jump
        self.conv7_weight = ((self.Wout @
            (self.Wi[idx:idx+jump] @ self.weight_embeddings[idx:idx+jump].unsqueeze(2) +
             self.Bi[idx:idx+jump].unsqueeze(2))).squeeze(2)
            + self.Bout.unsqueeze(0)).view(192, 192, 3, 3)
        idx += jump
        self.conv8_weight = ((self.Wout @
            (self.Wi[idx:idx+jump] @ self.weight_embeddings[idx:idx+jump].unsqueeze(2) +
             self.Bi[idx:idx+jump].unsqueeze(2))).squeeze(2)
            + self.Bout.unsqueeze(0)).view(192, 192, 3, 3)[:, :, :1, :1]
        idx += jump; jump = self.n_classes * 2
        self.class_conv_weight = ((self.Wout @
            (self.Wi[idx:idx+jump] @ self.weight_embeddings[idx:idx+jump].unsqueeze(2) +
             self.Bi[idx:idx+jump].unsqueeze(2))).squeeze(2)
            + self.Bout.unsqueeze(0)).view(self.n_classes, 192, 3, 3)[:, :, :1, :1]
        return super().forward(x)
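# Shape walk-through of the generator above (added for clarity). For each layer:
#   embeddings z: (jump, embed_size, 1)      Wi[idx:idx+jump] @ z + Bi: (jump, embed_size, 1)
#   Wout @ (...) then squeeze + Bout:        (jump, 96*9)
# i.e. each of the `jump` kernel embeddings decodes to a 96-channel 3x3 slice, so a layer with
# `jump` embeddings yields jump*96*9 values, reshaped to its conv weight. conv8 and class_conv
# keep only the [:1, :1] spatial corner, making them effectively 1x1 convolutions.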
# Configuration
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to CPU if no GPU is available
embed_size = 50
embedW_init_scheme = embed_uniform
hyperWi_init_scheme = hyperfaninWi_init # hyperfanoutWi_init
hyperWout_init_scheme = fanin_uniform
lr = 0.0005
training_batch_size = 100
test_batch_size = 1000
epochs = 500
log_interval = 100
seed = 123
torch.manual_seed(seed)
train_criterion = nn.CrossEntropyLoss(reduction='mean')
test_criterion = nn.CrossEntropyLoss(reduction='sum')
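# 'mean' gives a per-sample loss for each training batch, while 'sum' lets the test loop
# accumulate the exact total loss over the test set before dividing by its size.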
# Data for Loss Plots
train_loss_list = []
test_loss_list = []
test_acc_list = []
# Training/Testing Functions
def train(model, device, train_loader, optimizer, epoch, log_interval, lr_scheduler):
    model.train()
    total_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = train_criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    total_loss /= len(train_loader)  # average per-sample loss (each batch loss is already a mean)
    train_loss_list.append(total_loss)
    lr_scheduler.step()
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += test_criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the max class score
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_loss_list.append(test_loss)
    test_acc = 100. * correct / len(test_loader.dataset)
    test_acc_list.append(test_acc)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        test_acc))
# CIFAR10 Data Loaders
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=training_batch_size, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, num_workers=2)
# Model and Optimizer
model = HyperNN(hardcoded_input_size, hardcoded_n_classes, embed_size, embedW_init_scheme,
                hyperWi_init_scheme, hyperWout_init_scheme, device).to(device)
num_params = sum([param.numel() for param in model.parameters()])
print("number of params:", num_params)
optimizer = optim.SGD(model.parameters(), lr=lr)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[350,450], gamma=0.1)
# Actual HyperNet Training
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval, lr_scheduler)
    test(model, device, test_loader)
# Save Experiment
result_dict = {'train_loss_list': np.array(train_loss_list),
               'test_loss_list': np.array(test_loss_list),
               'test_acc_list': np.array(test_acc_list)
               }
torch.save(result_dict, 'results.dict')
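# To inspect the saved metrics later (illustrative; run in a separate session):
#   results = torch.load('results.dict')
#   print(results['test_acc_list'][-1])  # final test accuracy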