Skip to content

Instantly share code, notes, and snippets.

Created June 25, 2019 02:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DuaneNielsen/87f5c4a35b99b7ddd8003994cd3874ae to your computer and use it in GitHub Desktop.
Save DuaneNielsen/87f5c4a35b99b7ddd8003994cd3874ae to your computer and use it in GitHub Desktop.
Implementation of a ResNet trained with Fixup initialization — no BatchNorm!
import statistics as stats
from pathlib import Path

import torch
import torch.nn as nn
from colorama import Style, Fore
from torch.nn.functional import avg_pool2d
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader, random_split, Dataset, ConcatDataset
from torchvision.datasets import CIFAR10
from torchvision.transforms import *
from tqdm import tqdm
def precision(confusion):
    """Return per-class and overall hit rates from a square confusion matrix.

    For each column ``j`` the per-class value is ``confusion[j, j]`` divided by
    the column total. The overall value is the trace divided by the grand total.

    NOTE(review): if the matrix is filled as ``confusion[predicted, target]``,
    the per-class value is really *recall* (hits / true-class count), not
    precision. The name is kept so callers keep working.

    Args:
        confusion: square tensor of counts (any numeric dtype, any device).

    Returns:
        (per_class, overall):
        - per_class is a float tensor of length ``confusion.shape[0]``; it is
          NaN for a column whose total is zero.
        - overall is a Python float; it is NaN when the matrix is all zeros
          (the old code raised ZeroDivisionError here).
    """
    # The diagonal is the "correct" count. Taking it directly avoids building an
    # identity matrix on a possibly different device/dtype.
    hits = torch.diagonal(confusion)
    # Column total == correct + incorrect for that column.
    column_totals = confusion.sum(0)
    per_class = hits / column_totals

    total = confusion.sum().item()
    overall = hits.sum().item() / total if total else float('nan')
    return per_class, overall
class FixupResLayer(nn.Module):
    """Two-convolution residual block that uses Fixup initialisation instead of BatchNorm.

    Args:
        depth: the ``L`` in the Fixup scale ``L ** -0.5`` applied to the first
            convolution's He-initialised weights (two layers per branch).
        in_layers: number of input channels.
        filters: number of output channels. It must be >= ``in_layers``,
            because the shortcut is widened with zero channels.
        stride: stride of the first convolution. When it is > 1, the shortcut
            is average-pooled by the same factor.
    """

    def __init__(self, depth, in_layers, filters, stride=1):
        # nn.Module.__init__ must run before any submodule or parameter is assigned.
        super().__init__()
        self.c1 = nn.Conv2d(in_layers, filters, 3, stride=stride, padding=1, bias=False)
        self.c2 = nn.Conv2d(filters, filters, 3, stride=1, padding=1, bias=False)
        self.stride = stride
        # Fixup scalar multiplier (starts at 1) and four scalar biases (start at 0).
        self.gain = nn.Parameter(torch.ones(1))
        self.bias = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(4)])

        # Fixup: He (fan-out) normal init scaled by depth ** -0.5,
        # i.e. L ** (-1 / (2m - 2)) with m = 2 layers per branch.
        out_channels, _, kernel_h, kernel_w = self.c1.weight.shape
        std = (2 / (out_channels * kernel_h * kernel_w)) ** 0.5 * depth ** -0.5
        nn.init.normal_(self.c1.weight, mean=0, std=std)
        # Fixup: the last layer of each residual branch starts at zero.
        nn.init.zeros_(self.c2.weight)

    def forward(self, input):
        """Return relu(residual_branch(input) + shortcut(input))."""
        hidden = input + self.bias[0]
        hidden = self.c1(hidden) + self.bias[1]
        hidden = torch.relu(hidden) + self.bias[2]
        hidden = self.c2(hidden) * self.gain + self.bias[3]
        return torch.relu(hidden + self._shortcut(input, hidden.size(1)))

    def _shortcut(self, input, channels):
        """Downsample ``input`` by ``self.stride``, then zero-pad it to ``channels`` channels."""
        if self.stride == 1:
            # Kernel-1/stride-1 pooling is the identity. Padding would also be
            # rejected by avg_pool2d (padding must be <= kernel_size / 2).
            identity = input
        else:
            # Pad odd sizes by one. For stride 2 this makes the pooled size equal
            # the conv output size, ceil(n / 2).
            padding_h = input.size(2) % 2
            padding_w = input.size(3) % 2
            identity = avg_pool2d(input, self.stride, stride=self.stride,
                                  padding=(padding_h, padding_w))
        # Widen the shortcut with leading zero channels to match the branch output.
        missing = channels - identity.size(1)
        if missing > 0:
            zeros = identity.new_zeros(identity.size(0), missing,
                                       identity.size(2), identity.size(3))
            identity = torch.cat((zeros, identity), dim=1)
        return identity
class DeepResNetFixup(nn.Module):
    """10-class image classifier built from FixupResLayer blocks (no BatchNorm).

    The input has 3 channels and the output is ``(batch, 10)`` logits.
    """

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned.
        super().__init__()
        self.first = nn.Conv2d(3, 64, 3, bias=False)
        # NOTE(review): the second block of each stage also uses stride=2, so a
        # 32x32 image is reduced to 1x1 by layer4. Standard ResNets use stride 1
        # there; confirm this is intended. The strides are kept as written.
        self.layer1 = nn.Sequential(FixupResLayer(2, 64, 64), FixupResLayer(3, 64, 64))
        self.layer2 = nn.Sequential(FixupResLayer(4, 64, 128, stride=2), FixupResLayer(5, 128, 128, stride=2))
        self.layer3 = nn.Sequential(FixupResLayer(6, 128, 256, stride=2), FixupResLayer(7, 256, 256, stride=2))
        self.layer4 = nn.Sequential(FixupResLayer(8, 256, 512, stride=2), FixupResLayer(9, 512, 512, stride=2))
        # Global average pool works for any spatial size. A fixed AvgPool2d(4)
        # fails on the 1x1 map that layer4 produces from a 32x32 input.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.out = nn.Linear(512, 10)
        # Fixup: the classification layer starts at zero.
        nn.init.zeros_(self.out.weight)
        nn.init.zeros_(self.out.bias)

    def forward(self, input):
        """Return class logits of shape (batch, 10)."""
        hidden = torch.relu(self.first(input))
        hidden = self.layer1(hidden)
        hidden = self.layer2(hidden)
        hidden = self.layer3(hidden)
        hidden = self.layer4(hidden)
        # flatten(1) keeps the batch dim even when batch size is 1 (squeeze() did not).
        hidden = self.pool(hidden).flatten(1)
        return self.out(hidden)
if __name__ == '__main__':
    # Train DeepResNetFixup on a 90/10 split of CIFAR-10. Each epoch prints
    # per-class scores and saves a checkpoint.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 128
    num_classes = 10
    reload = 169  # epoch numbering starts after this; its checkpoint is resumed if present
    run_id = 6
    epochs = 100

    # image size 3, 32, 32. ToTensor must precede Normalize, which operates on tensors.
    transform = Compose([
        ToTensor(),
        Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    ds = CIFAR10(r'c:\data\tv', download=True, transform=transform)
    len_train = len(ds) // 10 * 9
    len_test = len(ds) - len_train
    train, test = random_split(ds, [len_train, len_test])
    train_l = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
    test_l = DataLoader(test, batch_size=batch_size, shuffle=True, drop_last=True)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # expanduser(): pathlib does not expand '~' on its own.
    run_dir = Path('~/data/deepinfomax/models/run' + str(run_id)).expanduser()

    def checkpoint_path(epoch_number):
        """Location of the checkpoint written at the end of ``epoch_number``."""
        return run_dir / ('w_dim' + str(epoch_number) + '.mdl')

    classifier = DeepResNetFixup().to(device)
    resume_path = checkpoint_path(reload)
    if resume_path.exists():
        # Restore weights so epoch numbering and model state agree.
        classifier.load_state_dict(torch.load(str(resume_path), map_location=device))

    optim = SGD(classifier.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    # The range end is exclusive; +1 so exactly `epochs` epochs run.
    for epoch in range(reload + 1, reload + epochs + 1):
        classifier.train()
        ll = []
        batch = tqdm(train_l, total=len(train_l))
        for x, target in batch:
            x = x.to(device)
            target = target.to(device)
            optim.zero_grad()
            y = classifier(x)
            loss = criterion(y, target)
            loss.backward()
            optim.step()
            ll.append(loss.item())
            batch.set_description(f'{epoch} Train Loss: {stats.mean(ll)}')

        classifier.eval()
        # confusion[predicted, target] counts, accumulated on the CPU.
        confusion = torch.zeros(num_classes, num_classes)
        batch = tqdm(test_l, total=len(test_l))
        ll = []
        with torch.no_grad():
            for x, target in batch:
                x = x.to(device)
                target = target.to(device)
                y = classifier(x)
                loss = criterion(y, target)
                ll.append(loss.item())
                batch.set_description(f'{epoch} Test Loss: {stats.mean(ll)}')
                _, predicted = y.max(1)
                # Encode each (predicted, target) pair as one flat index and count
                # them all at once, instead of a per-sample Python loop.
                flat = (predicted * num_classes + target).cpu()
                counts = torch.bincount(flat, minlength=num_classes * num_classes)
                confusion += counts.view(num_classes, num_classes).to(confusion.dtype)

        precis, ave_precis = precision(confusion)
        for i, cls in enumerate(classes):
            print(f'{Fore.LIGHTMAGENTA_EX}{cls} : {precis[i].item()}{Style.RESET_ALL}')
        print(f'{Fore.GREEN}ave precision : {ave_precis}{Style.RESET_ALL}')

        classifier_save_path = checkpoint_path(epoch)
        classifier_save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(classifier.state_dict(), str(classifier_save_path))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment