Created
June 25, 2019 02:17
-
-
Save DuaneNielsen/87f5c4a35b99b7ddd8003994cd3874ae to your computer and use it in GitHub Desktop.
Implementation of a ResNet trained with Fixup initialization. No BatchNorm!
https://arxiv.org/abs/1901.09321
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torchvision.datasets import CIFAR10 | |
from torch.utils.data import DataLoader, random_split, Dataset, ConcatDataset | |
from tqdm import tqdm | |
from torch.optim import Adam, SGD | |
from torchvision.transforms import * | |
import statistics as stats | |
from pathlib import Path | |
from torch.nn.functional import avg_pool2d | |
from colorama import Style, Fore | |
def precision(confusion):
    """Compute per-class precision and overall accuracy from a confusion matrix.

    Every column is scored as one class: its diagonal entry counts the correct
    decisions and the remaining entries of the column count the errors. The
    per-class score is therefore precision in the usual sense only when
    ``confusion[i, j]`` counts samples of true class ``i`` that were predicted
    as class ``j`` (columns index the prediction). With the transposed layout
    the very same formula yields per-class recall.

    Args:
        confusion: square 2-D tensor of counts (any device, integer or float).

    Returns:
        A tuple ``(per_class, percent_correct)`` where ``per_class`` is a 1-D
        tensor with one score per column (``nan`` for an all-zero column) and
        ``percent_correct`` is a Python float, the diagonal sum divided by the
        sum of all entries (``nan`` when the matrix holds no samples).
    """
    # Work in floating point so integer count matrices divide correctly.
    counts = confusion if confusion.is_floating_point() else confusion.float()

    # Read the diagonal directly instead of masking with torch.eye(), which
    # would always be allocated on the CPU and break for matrices that live
    # on another device.
    correct = counts.diagonal()
    column_totals = counts.sum(0)
    per_class = correct / column_totals

    total_correct = correct.sum().item()
    total = column_totals.sum().item()
    # An empty matrix has no defined accuracy; report nan rather than raising
    # ZeroDivisionError, matching the nan that empty columns already produce.
    percent_correct = total_correct / total if total else float('nan')
    return per_class, percent_correct
class FixupResLayer(nn.Module):
    """Residual block of two 3x3 convolutions with no normalisation layers.

    The block follows the Fixup recipe (https://arxiv.org/abs/1901.09321):
    the first convolution's default initial weights are scaled down, the
    second convolution starts at zero, and learnable scalar biases and a
    scalar gain are placed around the convolutions. Because the second
    convolution starts at zero, a freshly built block computes
    ``relu(shortcut(input))``.

    Args:
        depth: scaling argument; the first convolution's initial weights are
            multiplied by ``depth ** -0.5``.
        in_layers: number of input channels.
        filters: number of output channels. The shortcut can only handle
            ``filters == in_layers`` or ``filters == 2 * in_layers``.
        stride: stride of the first convolution and of the shortcut pooling.
            The odd-size padding rule in ``forward`` is written for strides
            1 and 2.
    """

    def __init__(self, depth, in_layers, filters, stride=1):
        super().__init__()
        self.c1 = nn.Conv2d(in_layers, filters, 3, stride=stride, padding=1, bias=False)
        self.c1.weight.data.mul_(depth ** -0.5)
        self.c2 = nn.Conv2d(filters, filters, 3, stride=1, padding=1, bias=False)
        self.c2.weight.data.zero_()
        self.stride = stride
        self.gain = nn.Parameter(torch.ones(1))
        self.bias = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(4)])

    def forward(self, input):
        """Apply the residual branch and add the (possibly pooled) shortcut.

        Args:
            input: 4-D tensor; dimension 1 is channels, dimensions 2 and 3
                are the spatial sizes.

        Returns:
            ``relu(branch(input) + shortcut(input))``.
        """
        hidden = input + self.bias[0]
        hidden = self.c1(hidden) + self.bias[1]
        hidden = torch.relu(hidden) + self.bias[2]
        hidden = self.c2(hidden) * self.gain + self.bias[3]

        if self.stride == 1:
            # A 1x1 average pool with stride 1 is the identity, so skip it.
            # This also avoids asking avg_pool2d for padding=1 with a kernel
            # of 1 on odd-sized inputs, which PyTorch rejects.
            shortcut = input
        else:
            # Pad odd-sized inputs so the pooled shortcut has the same
            # spatial size as the strided convolution's output.
            padding_h = input.size(2) % 2
            padding_w = input.size(3) % 2
            shortcut = avg_pool2d(input, self.stride, stride=self.stride,
                                  padding=(padding_h, padding_w))

        # This assumes the channel count is either kept or exactly doubled
        # as the network gets deeper.
        if shortcut.size(1) != hidden.size(1):
            shortcut = torch.cat((shortcut, shortcut), dim=1)

        return torch.relu(hidden + shortcut)
class DeepResNetFixup(nn.Module):
    """ResNet-style 10-class image classifier built from FixupResLayer blocks.

    The network is a 3x3 stem convolution (3 -> 64 channels, no padding),
    four stages of two residual blocks each (64, 128, 256 and 512 channels),
    global average pooling and a zero-initialised linear classifier.

    The pooling is adaptive so that the classifier always receives one
    512-vector per sample whatever the final feature-map size is. With the
    strides used below, a 32x32 input is reduced to a 1x1 map, which a fixed
    4x4 average pool cannot process.
    """

    def __init__(self):
        super().__init__()
        self.first = nn.Conv2d(3, 64, 3, bias=False)
        self.layer1 = nn.Sequential(
            FixupResLayer(2, 64, 64),
            FixupResLayer(3, 64, 64),
        )
        self.layer2 = nn.Sequential(
            FixupResLayer(4, 64, 128, stride=2),
            FixupResLayer(5, 128, 128, stride=2),
        )
        self.layer3 = nn.Sequential(
            FixupResLayer(6, 128, 256, stride=2),
            FixupResLayer(7, 256, 256, stride=2),
        )
        self.layer4 = nn.Sequential(
            FixupResLayer(8, 256, 512, stride=2),
            FixupResLayer(9, 512, 512, stride=2),
        )
        # Global average pooling: always yields a 1x1 map per channel.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.out = nn.Linear(512, 10)
        # The classifier starts at zero, so initial logits are all zero.
        self.out.weight.data.zero_()
        self.out.bias.data.zero_()

    def forward(self, input):
        """Return class logits for a batch of images.

        Args:
            input: 4-D tensor with 3 channels in dimension 1.

        Returns:
            Tensor with one row of 10 logits per sample.
        """
        hidden = torch.relu(self.first(input))
        hidden = self.layer1(hidden)
        hidden = self.layer2(hidden)
        hidden = self.layer3(hidden)
        hidden = self.layer4(hidden)
        # flatten(1) keeps the batch dimension even for a batch of one,
        # which a bare squeeze() would remove.
        hidden = self.pool(hidden).flatten(1)
        return self.out(hidden)
if __name__ == '__main__':
    # Train DeepResNetFixup on CIFAR-10, evaluating on a held-out 10% split
    # and saving the model after every epoch.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 128
    num_classes = 10
    # Offset for the epoch counter; it only affects the epoch numbers shown
    # in the progress bar and used in checkpoint file names. No checkpoint
    # is loaded here.
    reload = 169
    run_id = 6
    epochs = 100

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # image size 3, 32, 32
    ds = CIFAR10(r'c:\data\tv', download=True, transform=transform)
    len_train = len(ds) // 10 * 9
    len_test = len(ds) - len_train
    train, test = random_split(ds, [len_train, len_test])
    train_l = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
    # Evaluate on every held-out sample in a fixed order.
    test_l = DataLoader(test, batch_size=batch_size, shuffle=False, drop_last=False)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Use the detected device so the script also runs on CPU-only machines.
    classifier = DeepResNetFixup().to(device)
    optim = SGD(classifier.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    # Run exactly `epochs` epochs, numbered reload + 1 .. reload + epochs.
    for epoch in range(reload + 1, reload + epochs + 1):

        # ---- training ----
        classifier.train()
        loss_sum = 0.0
        batch = tqdm(train_l)
        for step, (x, target) in enumerate(batch, start=1):
            x = x.to(device)
            target = target.to(device)

            optim.zero_grad()
            y = classifier(x)
            loss = criterion(y, target)
            loss.backward()
            optim.step()

            # Running mean of the per-batch losses seen so far this epoch.
            loss_sum += loss.item()
            batch.set_description(f'{epoch} Train Loss: {loss_sum / step}')

        # ---- evaluation ----
        classifier.eval()
        # Rows index the true class, columns the predicted class, which is
        # the layout under which precision() reports per-class precision.
        confusion = torch.zeros(num_classes, num_classes)
        loss_sum = 0.0
        batch = tqdm(test_l)
        with torch.no_grad():  # no gradients are needed for evaluation
            for step, (x, target) in enumerate(batch, start=1):
                x = x.to(device)
                target = target.to(device)

                y = classifier(x)
                loss = criterion(y, target)

                loss_sum += loss.item()
                batch.set_description(f'{epoch} Test Loss: {loss_sum / step}')

                predicted = y.argmax(1)
                # Add one count per sample in a single vectorised update.
                confusion.index_put_(
                    (target.cpu(), predicted.cpu()),
                    torch.ones(target.numel()),
                    accumulate=True,
                )

        precis, ave_precis = precision(confusion)
        print('')
        for i, cls in enumerate(classes):
            print(f'{Fore.LIGHTMAGENTA_EX}{cls} : {precis[i].item()}{Style.RESET_ALL}')
        print(f'{Fore.GREEN}ave precision : {ave_precis}{Style.RESET_ALL}')

        # expanduser() resolves '~' to the home directory; without it a
        # literal '~' directory is created under the working directory.
        classifier_save_path = Path(
            '~/data/deepinfomax/models/run' + str(run_id) + '/w_dim' + str(epoch) + '.mdl'
        ).expanduser()
        classifier_save_path.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): this pickles the whole module object, so loading the
        # file needs the same class definitions; saving
        # classifier.state_dict() would be more portable.
        torch.save(classifier, str(classifier_save_path))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment