ResNet-164 training experiment on CIFAR10 using PyTorch; see the paper "Identity Mappings in Deep Residual Networks" (He et al., ECCV 2016).
import torch
import torch.nn as nn
import math

## the model definition
# see Kaiming He's Torch implementation:
# https://github.com/KaimingHe/resnet-1k-layers/blob/master/README.md
class Bottleneck(nn.Module):
    expansion = 4  # ratio: output channels / bottleneck channels

    def __init__(self, inplanes, outplanes, stride=1):
        assert outplanes % self.expansion == 0
        super(Bottleneck, self).__init__()
        self.inplanes = inplanes
        self.outplanes = outplanes
        self.bottleneck_planes = outplanes // self.expansion  # integer division keeps this an int
        self.stride = stride
        self._make_layer()
    def _make_layer(self):
        # conv 1x1
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.conv1 = nn.Conv2d(self.inplanes, self.bottleneck_planes,
                               kernel_size=1, stride=self.stride, bias=False)
        # conv 3x3
        self.bn2 = nn.BatchNorm2d(self.bottleneck_planes)
        self.conv2 = nn.Conv2d(self.bottleneck_planes, self.bottleneck_planes,
                               kernel_size=3, stride=1, padding=1, bias=False)
        # conv 1x1
        self.bn3 = nn.BatchNorm2d(self.bottleneck_planes)
        self.conv3 = nn.Conv2d(self.bottleneck_planes, self.outplanes,
                               kernel_size=1, stride=1)
        if self.inplanes != self.outplanes:
            self.shortcut = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=1,
                                      stride=self.stride, bias=False)
        else:
            self.shortcut = None
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        residual = x
        # pre-activation ordering: BN -> ReLU before each conv,
        # instead of the original conv -> BN -> ReLU
        out = self.relu(self.bn1(x))
        out = self.conv1(out)
        out = self.relu(self.bn2(out))
        out = self.conv2(out)
        out = self.relu(self.bn3(out))
        out = self.conv3(out)
        # project the shortcut with a 1x1 conv when shapes differ
        if self.shortcut is not None:
            residual = self.shortcut(residual)
        out += residual
        return out
class ResNet(nn.Module):
    def __init__(self, block, depth, output_classes=1000):
        assert (depth - 2) % 9 == 0  # i.e. depth is 164 or 1001
        super(ResNet, self).__init__()
        n = (depth - 2) // 9  # number of bottleneck blocks per stage
        nstages = [16, 64, 128, 256]
        # one conv at the beginning (spatial size: 32x32)
        self.conv1 = nn.Conv2d(3, nstages[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # use `block` as the unit to construct the network
        # Stage 1 (32x32 -> 32x32)
        self.layer1 = self._make_layer(block, nstages[0], nstages[1], n)
        # Stage 2 (32x32 -> 16x16)
        self.layer2 = self._make_layer(block, nstages[1], nstages[2], n, stride=2)
        # Stage 3 (16x16 -> 8x8)
        self.layer3 = self._make_layer(block, nstages[2], nstages[3], n, stride=2)
        # final BN/ReLU on the 8x8 feature map (needed with pre-activation blocks)
        self.bn = nn.BatchNorm2d(nstages[3])
        self.relu = nn.ReLU(inplace=True)
        # classifier
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(nstages[3], output_classes)
        # weight initialization
        self._init_weights()
    def _init_weights(self):
        # He initialization for convs, identity-like initialization for BN
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, inplanes, outplanes, nstage, stride=1):
        # the first block of a stage changes channels (and possibly stride);
        # the remaining blocks keep the shape
        layers = []
        layers.append(block(inplanes, outplanes, stride))
        for i in range(1, nstage):
            layers.append(block(outplanes, outplanes, stride=1))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.relu(self.bn(x))
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

def resnet_164(output_classes):
    model = ResNet(Bottleneck, 164, output_classes)
    return model
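
## quick sanity check (illustrative addition, not part of the original gist):
## ResNet-164 has (164 - 2) / 9 = 18 bottleneck blocks per stage and roughly
## 1.7M parameters; a CIFAR-sized input should map to a 10-way output.
if __name__ == '__main__':
    net = resnet_164(output_classes=10)
    n_params = sum(p.numel() for p in net.parameters())
    print('parameters: {:.2f}M'.format(n_params / 1e6))
    with torch.no_grad():
        y = net(torch.randn(2, 3, 32, 32))
    print('output shape:', tuple(y.shape))  # expected: (2, 10)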
## training script for CIFAR10
import os, shutil, time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import resnet_164

CIFAR10_DIR = '/data/'
WORKERS = 4
BATCH_SIZE = 128
USE_CUDA = torch.cuda.is_available()
MAX_EPOCH = 150
PRINT_FREQUENCY = 100

if USE_CUDA:
    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True
# load data
if not os.path.exists(CIFAR10_DIR):
    raise RuntimeError('Cannot find CIFAR10 directory')
# note: these are the ImageNet normalization statistics; CIFAR10-specific
# per-channel means/stds (roughly 0.49/0.25) would also work here
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_set = CIFAR10(root=CIFAR10_DIR, train=True, transform=transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop((32, 32), 4),  # pad by 4 pixels, crop back to 32x32
    transforms.ToTensor(), normalize]))
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=WORKERS, pin_memory=True)
val_loader = DataLoader(CIFAR10(root=CIFAR10_DIR, train=False,
                                transform=transforms.Compose([
                                    transforms.ToTensor(), normalize])),
                        batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=WORKERS, pin_memory=True)
# get resnet-164
def get_model():
    model = resnet_164(output_classes=10)
    if USE_CUDA:
        model = model.cuda()
    return model

# remove existing log directory
def remove_log():
    if os.path.exists('./log'):
        shutil.rmtree('./log')
    os.mkdir('./log')
# Metric
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
# top-k accuracy
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # reshape() rather than view(): the sliced tensor may be non-contiguous
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
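
## worked example (illustrative, not in the original gist): with batch size 4
## and topk=(1,), if the argmax of `output` matches `target` for 3 of the 4
## rows, accuracy(output, target) returns [tensor([75.])], i.e. precision@1
## expressed as a percentage of the batch.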
# validation
def validate(model, criterion):
    model.eval()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    with torch.no_grad():  # no gradients needed during evaluation
        for ind, (x, label) in enumerate(val_loader):
            if USE_CUDA:
                x, label = x.cuda(), label.cuda()
            score = model(x)
            loss = criterion(score, label)
            prec1 = accuracy(score, label)
            losses.update(loss.item(), x.size(0))
            top1.update(prec1[0].item(), x.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
    print('Test: [{0}/{0}]\t'
          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
              len(val_loader), batch_time=batch_time, loss=losses, top1=top1))
    return top1.avg, losses.avg
# train
def train(model):
    remove_log()
    writer = SummaryWriter('./log')
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                          weight_decay=0.0001)
    criterion = nn.CrossEntropyLoss()
    step = 1
    for epoch in range(1, MAX_EPOCH + 1):
        # decay the learning rate by 10x at epochs 80 and 120
        if epoch == 80 or epoch == 120:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        data_time = AverageMeter()
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        model.train()
        end = time.time()
        for ind, (x, label) in enumerate(train_loader):
            data_time.update(time.time() - end)
            if USE_CUDA:
                x, label = x.cuda(), label.cuda()
            score = model(x)
            loss = criterion(score, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            batch_time.update(time.time() - end)
            prec1 = accuracy(score, label)
            losses.update(loss.item(), x.size(0))
            top1.update(prec1[0].item(), x.size(0))
            writer.add_scalar('train_loss', loss.item(), step)
            writer.add_scalar('train_acc', prec1[0].item(), step)
            if (ind + 1) % PRINT_FREQUENCY == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, ind + 1, len(train_loader), batch_time=batch_time,
                          data_time=data_time, loss=losses, top1=top1))
            end = time.time()
        top1, test_loss = validate(model, criterion)
        writer.add_scalar('test_loss', test_loss, step)
        writer.add_scalar('test_acc', top1, step)
        # checkpoint every 30 epochs
        if epoch % 30 == 0:
            torch.save({'state_dict': model.state_dict(),
                        'accuracy': top1},
                       'epoch-{:03d}-model.pth.tar'.format(epoch))
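
# reloading one of the periodic checkpoints (illustrative sketch, not part of
# the original gist; assumes the 'state_dict'/'accuracy' keys saved above)
def evaluate_checkpoint(path):
    ckpt = torch.load(path, map_location='cpu')
    model = get_model()
    model.load_state_dict(ckpt['state_dict'])
    print('checkpoint accuracy: {:.3f}'.format(ckpt['accuracy']))
    return model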
if __name__ == '__main__':
    model = get_model()
    train(model)