Normalization vs gradients: compare how batch normalization, instance normalization, and spectral normalization affect per-layer gradient norms while training a plain 10-layer CNN and a ResNet-like model on CIFAR-10. Two files: the training script below, then models.py.
# --- Training script (imports the two models from models.py, shown below) ---
import torch
import torchvision
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import os
import pickle

from models import TenLayersModel, ResNetLikeModel
def load_cifar():
    # CIFAR-10, scaled to [-1, 1]
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=trans)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
    testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=trans)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)
    return trainloader, testloader
def calc_gradients(model, input_x, target_y):
    # L2 norm of the gradient for each 4-D parameter tensor, i.e. each conv kernel
    model.zero_grad()
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(model(input_x), target_y)
    loss.backward()
    grads = [p.grad.norm().item() for p in model.parameters() if len(p.size()) == 4]
    return grads
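# Example (hypothetical, not in the gist): for a model whose only 4-D
# parameter is a single conv kernel, calc_gradients returns a one-element
# list, e.g.:
#
#   dummy = torch.nn.Sequential(
#       torch.nn.Conv2d(3, 8, 3, padding=1),
#       torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(),
#       torch.nn.Linear(8, 10))
#   calc_gradients(dummy, torch.randn(4, 3, 32, 32),
#                  torch.randint(0, 10, (4,)))  # -> [one float]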
def main(network, normalization):
    if network == "ten":
        model = TenLayersModel(normalization)
    elif network == "resnet":
        model = ResNetLikeModel(normalization)
    model_name = f"{network}_{normalization}"
    output_dir = "snapshot"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    device = "cuda"
    batch_size = 128
    model.to(device)
    model = torch.nn.DataParallel(model)
    trainloader, testloader = load_cifar()

    log_gradients = []
    log_loss = []
    log_val_acc = []

    # weight decay only for the ResNet-like model
    weight_decay = 0.0001 if network == "resnet" else 0
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.1)
    criterion = torch.nn.CrossEntropyLoss()
    max_val_acc = 0.0

    for epoch in tqdm(range(100)):
        # gradient check: per-layer gradient norms on the test set, before this epoch's updates
        gradients = []
        for X, y in testloader:
            if len(X) != batch_size:  # skip the last, smaller batch
                continue
            X, y = X.to(device), y.to(device)
            gradients.append(calc_gradients(model, X, y))
        model.zero_grad()
        layer_grads = np.mean(np.array(gradients), axis=0)  # batch-wise mean
        log_gradients.append(layer_grads)

        # train
        train_loss = 0.0
        for i, (X, y) in enumerate(trainloader):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(X)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        log_loss.append(train_loss / (i + 1))

        # validation
        with torch.no_grad():
            correct, total = 0, 0
            for X, y in testloader:
                X, y = X.to(device), y.to(device)
                outputs = model(X)
                _, pred = torch.max(outputs.data, 1)
                total += y.size(0)
                correct += (pred == y).sum().item()
            log_val_acc.append(correct / total)

        # save the best model so far
        if max_val_acc < log_val_acc[-1]:
            torch.save(model.state_dict(), f"{output_dir}/{model_name}.pytorch")
            max_val_acc = log_val_acc[-1]

        scheduler.step()
        print("Epoch =", epoch, "Loss =", log_loss[-1], "Val_acc =", log_val_acc[-1], "/", model_name)
        # print(log_gradients[-1])

    # save result
    with open(f"{output_dir}/log_{model_name}.pkl", "wb") as fp:
        result = {"gradient": log_gradients, "loss": log_loss, "val_acc": log_val_acc}
        pickle.dump(result, fp)
if __name__ == "__main__":
    for model in ["ten", "resnet"]:
        for norm in ["batch", "instance", "spectral"]:
            main(model, norm)
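Once all six runs finish, each log_{model}_{norm}.pkl holds per-epoch gradient norms, training loss, and validation accuracy. A minimal analysis sketch (not part of the original gist; matplotlib and the concrete file path are assumptions) that plots the mean per-layer gradient norm over epochs:

import pickle
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical snippet: the path matches what main() writes above.
with open("snapshot/log_ten_batch.pkl", "rb") as fp:
    result = pickle.load(fp)

grads = np.array(result["gradient"])  # shape: (epochs, number of conv layers)
for layer in range(grads.shape[1]):
    plt.plot(grads[:, layer], label=f"conv {layer}")
plt.xlabel("epoch")
plt.ylabel("mean gradient L2 norm")
plt.legend()
plt.show()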
# --- models.py ---
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class TenLayersModel(nn.Module):
    def __init__(self, normalization):
        super().__init__()
        self.convs = self.create_model(normalization)
        self.linear = nn.Linear(256, 10)

    def conv_norm_relu(self, in_ch, out_ch, normalization):
        layers = []
        if normalization == "spectral":
            # spectral norm reparameterizes the conv weight instead of adding a norm layer
            w = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            layers.append(spectral_norm(w))
        else:
            layers.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
        if normalization == "batch":
            layers.append(nn.BatchNorm2d(out_ch))
        elif normalization == "instance":
            layers.append(nn.InstanceNorm2d(out_ch))
        layers.append(nn.ReLU(True))
        return layers

    def create_model(self, normalization):
        layers = []
        for in_ch in [3, 64, 64]:
            layers += self.conv_norm_relu(in_ch, 64, normalization)
        layers.append(nn.AvgPool2d(2))
        for in_ch in [64, 128, 128]:
            layers += self.conv_norm_relu(in_ch, 128, normalization)
        layers.append(nn.AvgPool2d(2))
        for in_ch in [128, 256, 256]:
            layers += self.conv_norm_relu(in_ch, 256, normalization)
        layers.append(nn.AvgPool2d(8))  # spatial size: 32 -> 16 -> 8 -> 1
        return nn.Sequential(*layers)

    def forward(self, inputs):
        x = self.convs(inputs).view(inputs.size(0), -1)
        x = self.linear(x)
        return x
class ResNetPreactModule(nn.Module):
    def __init__(self, in_ch, out_ch, downsampling, normalization):
        assert normalization in ["batch", "spectral", "instance"]
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # 1x1 conv on the shortcut only when the channel count changes
        self.shortcut_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else None
        self.norm1 = self.get_normalization(in_ch, normalization)
        self.norm2 = self.get_normalization(out_ch, normalization)
        if normalization == "spectral":
            self.conv1 = spectral_norm(self.conv1)
            self.conv2 = spectral_norm(self.conv2)
            self.shortcut_conv = spectral_norm(self.shortcut_conv) if self.shortcut_conv is not None else None
        self.downsampling = nn.AvgPool2d(downsampling) if downsampling > 1 else None

    def get_normalization(self, ch, normalization):
        if normalization == "batch":
            return nn.BatchNorm2d(ch)
        elif normalization == "instance":
            return nn.InstanceNorm2d(ch)
        else:
            return None

    def forward(self, inputs):
        # main path (pre-activation order: norm -> ReLU -> conv)
        x = self.norm1(inputs) if self.norm1 is not None else inputs
        x = F.relu(x)
        x = self.conv1(x)
        x = self.norm2(x) if self.norm2 is not None else x
        x = F.relu(x)
        x = self.conv2(x)
        # shortcut path
        shortcut = self.shortcut_conv(inputs) if self.shortcut_conv is not None else inputs
        # downsampling applies to both paths so shapes match at the addition
        if self.downsampling is not None:
            x = self.downsampling(x)
            shortcut = self.downsampling(shortcut)
        return x + shortcut
class ResNetLikeModel(nn.Module):
    def __init__(self, normalization):
        super().__init__()
        self.conv = nn.Sequential(
            *self.resnet_block(3, 64, 3, normalization),
            *self.resnet_block(64, 128, 4, normalization),
            *self.resnet_block(128, 256, 6, normalization, enable_downsampling=False),
            nn.AvgPool2d(8)
        )
        if normalization == "batch":
            self.last_norm = nn.BatchNorm2d(256)
        elif normalization == "instance":
            self.last_norm = nn.InstanceNorm2d(256)
        else:
            self.last_norm = None
        self.linear = nn.Linear(256, 10)

    def resnet_block(self, in_ch, out_ch, reps, normalization, enable_downsampling=True):
        layers = []
        for i in range(reps):
            current_in = in_ch if i == 0 else out_ch
            # downsample 2x in the last module of the block
            down = 2 if i == reps - 1 and enable_downsampling else 1
            layers.append(ResNetPreactModule(current_in, out_ch, down, normalization))
        return layers

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.last_norm(x) if self.last_norm is not None else x
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
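A quick shape check (not in the original gist) confirms that both architectures map CIFAR-10-sized inputs to 10 logits for every normalization option:

import torch
from models import TenLayersModel, ResNetLikeModel

# Hypothetical smoke test: forward a random batch through every
# model/normalization combination and verify the output shape.
x = torch.randn(2, 3, 32, 32)
for cls in [TenLayersModel, ResNetLikeModel]:
    for norm in ["batch", "instance", "spectral"]:
        assert cls(norm)(x).shape == (2, 10), (cls.__name__, norm)
print("all six combinations produce (N, 10) logits")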