Skip to content

Instantly share code, notes, and snippets.

@DuaneNielsen
Created June 25, 2019 02:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DuaneNielsen/87f5c4a35b99b7ddd8003994cd3874ae to your computer and use it in GitHub Desktop.
Save DuaneNielsen/87f5c4a35b99b7ddd8003994cd3874ae to your computer and use it in GitHub Desktop.
Implementation of a ResNet trained with Fixup initialization — no BatchNorm! https://arxiv.org/abs/1901.09321
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split, Dataset, ConcatDataset
from tqdm import tqdm
from torch.optim import Adam, SGD
from torchvision.transforms import *
import statistics as stats
from pathlib import Path
from torch.nn.functional import avg_pool2d
from colorama import Style, Fore
def precision(confusion):
    """Return per-class column ratios and the overall accuracy of a confusion matrix.

    For each column ``j`` the per-class value is
    ``confusion[j, j] / confusion[:, j].sum()``. If rows index true labels and
    columns index predictions, this is per-class precision. With the opposite
    layout (rows = predictions, columns = true labels) it is per-class recall.

    Args:
        confusion: square ``(C, C)`` tensor of counts. It may be integer or
            floating point, and no extra tensors are allocated, so it may live
            on any device.

    Returns:
        A tuple ``(per_class, percent_correct)``:

        - ``per_class``: floating-point tensor of length ``C``. It is NaN for a
          column with no entries.
        - ``percent_correct``: Python float equal to trace / total, or NaN when
          the matrix is all zeros.
    """
    correct = torch.diagonal(confusion)
    column_totals = confusion.sum(0)
    # torch true division, so integer count matrices still give fractional ratios.
    per_class = correct / column_totals
    total_correct = correct.sum().item()
    total = confusion.sum().item()
    # An empty matrix would otherwise raise ZeroDivisionError. Report NaN instead,
    # matching how empty columns appear in per_class.
    percent_correct = total_correct / total if total else float('nan')
    return per_class, percent_correct
class FixupResLayer(nn.Module):
    """Residual block initialised with Fixup, with no normalisation layers.

    The residual branch is ``conv3x3 -> relu -> conv3x3``. Scalar biases are
    inserted before each conv and the activation, and a scalar multiplier
    follows the second conv (https://arxiv.org/abs/1901.09321). The second conv
    starts at zero, so at initialisation the block computes
    ``relu(shortcut(input))``.

    Args:
        depth: its inverse square root rescales the first conv's default weights.
        in_layers: number of input channels.
        filters: number of output channels. The shortcut assumes this is either
            ``in_layers`` or ``2 * in_layers``; it widens by duplicating channels.
        stride: spatial stride of the first conv and of the shortcut pooling.
    """

    def __init__(self, depth, in_layers, filters, stride=1):
        super().__init__()
        self.c1 = nn.Conv2d(in_layers, filters, 3, stride=stride, padding=1, bias=False)
        self.c1.weight.data.mul_(depth ** -0.5)
        self.c2 = nn.Conv2d(filters, filters, 3, stride=1, padding=1, bias=False)
        # Zero-initialised last conv: at start the branch only outputs bias[3].
        self.c2.weight.data.zero_()
        self.stride = stride
        self.gain = nn.Parameter(torch.ones(1))
        self.bias = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(4)])

    def forward(self, input):
        hidden = self.c1(input + self.bias[0]) + self.bias[1]
        hidden = torch.relu(hidden) + self.bias[2]
        hidden = self.c2(hidden) * self.gain + self.bias[3]
        identity = self._shortcut(input)
        # Widen the shortcut by duplication when the block doubles the channel count.
        if identity.size(1) != hidden.size(1):
            identity = torch.cat((identity, identity), dim=1)
        return torch.relu(hidden + identity)

    def _shortcut(self, input):
        """Downsample the input to the spatial size produced by the residual branch."""
        if self.stride == 1:
            # A kernel-1 pool is the identity. It must not receive padding, which
            # avg_pool2d rejects for a 1x1 kernel when the input size is odd.
            return input
        # Pad odd dimensions by one so the output size matches c1's padded conv.
        padding_h = input.size(2) % 2
        padding_w = input.size(3) % 2
        return avg_pool2d(input, self.stride, stride=self.stride, padding=(padding_h, padding_w))
class DeepResNetFixup(nn.Module):
    """Fixup ResNet classifier for 3-channel images, without BatchNorm.

    The network is a 3x3 stem conv, then four stages of two FixupResLayer
    blocks (64, 128, 256 and 512 channels). Global average pooling and a
    zero-initialised linear head follow.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # No padding: a 32x32 input becomes 30x30 here.
        self.first = nn.Conv2d(3, 64, 3, bias=False)
        self.layer1 = nn.Sequential(FixupResLayer(2, 64, 64), FixupResLayer(3, 64, 64))
        # NOTE(review): both blocks of stages 2-4 use stride=2, so a 32x32 input
        # shrinks 30 -> 15 -> 8 -> 4 -> 2 -> 1 -> 1 through these stages.
        # Standard CIFAR ResNets downsample only in the first block of each
        # stage. Confirm this layout is intended.
        self.layer2 = nn.Sequential(FixupResLayer(4, 64, 128, stride=2), FixupResLayer(5, 128, 128, stride=2))
        self.layer3 = nn.Sequential(FixupResLayer(6, 128, 256, stride=2), FixupResLayer(7, 256, 256, stride=2))
        self.layer4 = nn.Sequential(FixupResLayer(8, 256, 512, stride=2), FixupResLayer(9, 512, 512, stride=2))
        # Adaptive pooling averages whatever spatial size remains. A fixed
        # AvgPool2d(4) fails on the 1x1 maps the stack above yields for 32x32 inputs.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.out = nn.Linear(512, num_classes)
        # Zero-initialised head, so the initial logits are all zero.
        self.out.weight.data.zero_()
        self.out.bias.data.zero_()

    def forward(self, input):
        hidden = torch.relu(self.first(input))
        hidden = self.layer1(hidden)
        hidden = self.layer2(hidden)
        hidden = self.layer3(hidden)
        hidden = self.layer4(hidden)
        # flatten (not squeeze) keeps the batch dimension when the batch size is 1.
        hidden = torch.flatten(self.pool(hidden), 1)
        return self.out(hidden)
if __name__ == '__main__':
    # Train a Fixup ResNet on CIFAR-10 and hold out 10% of the training images
    # for evaluation. The whole model is checkpointed after every epoch.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 128
    num_classes = 10
    # NOTE(review): `reload` only offsets the epoch numbers shown in the logs and
    # used in checkpoint filenames. No checkpoint is loaded, so every run trains
    # a freshly initialised model. Confirm whether resuming was intended.
    reload = 169
    run_id = 6
    epochs = 100

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # image size 3, 32, 32
    ds = CIFAR10(r'c:\data\tv', download=True, transform=transform)
    len_train = len(ds) // 10 * 9
    len_test = len(ds) - len_train
    train, test = random_split(ds, [len_train, len_test])
    # Training loader: batch size must be an even number; shuffle must be True.
    train_l = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
    # Score every held-out image: no shuffling, and keep the final partial batch.
    test_l = DataLoader(test, batch_size=batch_size, shuffle=False, drop_last=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Use the detected device instead of assuming CUDA is available.
    classifier = DeepResNetFixup().to(device)
    optim = SGD(classifier.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    # pathlib does not expand '~' by itself. Without expanduser() a literal "~"
    # directory would be created under the working directory.
    save_dir = Path('~/data/deepinfomax/models/run' + str(run_id)).expanduser()
    save_dir.mkdir(parents=True, exist_ok=True)

    # The upper bound is inclusive of the last epoch, so exactly `epochs` epochs run.
    for epoch in range(reload + 1, reload + epochs + 1):
        classifier.train()
        loss_sum = 0.0
        batch = tqdm(train_l)
        for step, (x, target) in enumerate(batch, start=1):
            x = x.to(device)
            target = target.to(device)
            optim.zero_grad()
            loss = criterion(classifier(x), target)
            loss.backward()
            optim.step()
            # Running mean instead of re-averaging a growing list every batch.
            loss_sum += loss.item()
            batch.set_description(f'{epoch} Train Loss: {loss_sum / step}')

        classifier.eval()
        # Layout is confusion[true class, predicted class]. precision() divides
        # the diagonal by column sums (how often each class was predicted), which
        # makes the per-class value precision, as labelled.
        confusion = torch.zeros(num_classes, num_classes)
        loss_sum = 0.0
        batch = tqdm(test_l)
        with torch.no_grad():  # evaluation only: do not build autograd graphs
            for step, (x, target) in enumerate(batch, start=1):
                x = x.to(device)
                target = target.to(device)
                y = classifier(x)
                loss_sum += criterion(y, target).item()
                batch.set_description(f'{epoch} Test Loss: {loss_sum / step}')
                predicted = y.argmax(1)
                # Count all (true, predicted) pairs of the batch in one vectorised call.
                counts = torch.bincount(target * num_classes + predicted,
                                        minlength=num_classes * num_classes)
                confusion += counts.view(num_classes, num_classes).cpu().to(confusion.dtype)

        precis, ave_precis = precision(confusion)
        print('')
        for i, cls in enumerate(classes):
            print(f'{Fore.LIGHTMAGENTA_EX}{cls} : {precis[i].item()}{Style.RESET_ALL}')
        print(f'{Fore.GREEN}ave precision : {ave_precis}{Style.RESET_ALL}')
        torch.save(classifier, str(save_dir / ('w_dim' + str(epoch) + '.mdl')))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment