Created December 14, 2017, 14:39.
Gist jiweibo/dd2d4f21fe4dcf4404c0b7b271c32afa — a simple ResNet model (PyTorch, CIFAR-10). Save it to your computer or open it in GitHub Desktop.
Note: this file may contain bidirectional Unicode text that could be interpreted or compiled differently than it appears below; to review it, open the file in an editor that reveals hidden Unicode characters.
import torch | |
import torch.nn as nn | |
import torchvision | |
import torchvision.datasets as datasets | |
import torchvision.transforms as transforms | |
from torch.autograd import Variable | |
from torch.utils.data import DataLoader | |
import time | |
from torch.backends import cudnn | |
cudnn.benchmark = True | |
use_cuda = True if torch.cuda.is_available() else False | |
# ---------------------------------------------------------------------------
# Data pipeline
# ---------------------------------------------------------------------------
# Training-time augmentation: upscale to 40x40, random horizontal flip,
# then a random 32x32 crop back to the native CIFAR resolution.
transform = transforms.Compose([
    transforms.Resize(40),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
])

# CIFAR-10 (downloaded on first run). The test split gets no augmentation,
# only tensor conversion.
train_dataset = datasets.CIFAR10(
    root=r'E:\DataSets\cifar10',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.CIFAR10(
    root=r'E:\DataSets\cifar10',
    train=False,
    transform=transforms.ToTensor(),
)

# Mini-batch loaders; only the training set is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)
def conv3x3(in_channels, out_channels, stride=1):
    """Return a 3x3 convolution with padding 1 and no bias.

    Padding 1 preserves the spatial size at stride 1. The bias is omitted
    because every convolution in this network is followed by BatchNorm,
    which makes a bias term redundant.
    """
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block: out = relu(F(x) + shortcut(x)).

    F(x) = bn2(conv2(relu(bn1(conv1(x))))). When `downsample` is provided
    it projects the input so its shape matches F(x) — used whenever the
    stride or the channel count changes between stages.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        # Shortcut path: identity, or a projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
# ---------------------------------------------------------------------------
# ResNet model
# ---------------------------------------------------------------------------
class ResNet(nn.Module):
    """CIFAR-style ResNet: a stem convolution followed by three residual stages.

    Feature maps flow 32x32x16 -> 32x32x16 -> 16x16x32 -> 8x8x64, then a
    global average pool and a linear classifier.

    Args:
        block: residual block class to instantiate (e.g. ResidualBlock).
        layers: per-stage block counts; only the first three entries are read.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Stack `blocks` residual blocks; only the first may downsample."""
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            # Projection shortcut: matches both spatial size and channels.
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride),
                nn.BatchNorm2d(out_channels),
            )
        stage = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        stage.extend(block(out_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))  # 32 x 32 x 16
        out = self.layer1(out)                  # 32 x 32 x 16
        out = self.layer2(out)                  # 16 x 16 x 32
        out = self.layer3(out)                  # 8 x 8 x 64
        out = self.avg_pool(out)                # 1 x 1 x 64
        out = out.view(out.size(0), -1)         # (batch, 64)
        return self.fc(out)                     # (batch, num_classes)
# ---------------------------------------------------------------------------
# Model, loss, and optimizer
# ---------------------------------------------------------------------------
# The network has exactly three residual stages, so pass three block counts.
# (Previously a fourth entry was passed and silently ignored.)
resnet = ResNet(ResidualBlock, [2, 2, 2])

criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer = torch.optim.Adam(resnet.parameters(), lr)

# Module.cuda() moves parameters in place, so the optimizer created above
# keeps valid references to them.
if use_cuda:
    resnet = resnet.cuda()
    criterion = criterion.cuda()
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
num_epochs = 80
print('\nstart train')
start_time = time.time()
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Variable is a no-op wrapper on PyTorch >= 0.4 but kept for
        # compatibility with the original script's style.
        images = Variable(images)
        labels = Variable(labels)
        if use_cuda:
            images = images.cuda()
            labels = labels.cuda()
        # Forward + backward + optimize.
        optimizer.zero_grad()
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            # loss.item() replaces loss.data[0], which raises IndexError on
            # 0-dim tensors in PyTorch >= 0.5; len(train_loader) replaces the
            # hard-coded iteration count 500.
            print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, len(train_loader),
                     loss.item()))
    # Decaying learning rate: divide by 3 every 20 epochs. The LR is updated
    # in place via param_groups — recreating the Adam optimizer (as the
    # original did) would discard its running moment estimates.
    if (epoch + 1) % 20 == 0:
        lr /= 3
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
print('\nend train, total time', time.time() - start_time)
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
# Switch to eval mode so BatchNorm uses its running statistics instead of
# per-batch statistics (the original evaluated in training mode, which
# skews accuracy).
resnet.eval()
correct = 0
total = 0
# no_grad() avoids building autograd graphs during inference.
with torch.no_grad():
    for images, labels in test_loader:
        if use_cuda:
            images = images.cuda()
        outputs = resnet(images)
        _, predicted = torch.max(outputs.cpu().data, 1)
        total += labels.size(0)
        # .item() converts the 0-dim tensor to a Python int so the
        # accuracy arithmetic and %d formatting below behave correctly.
        correct += (predicted == labels).sum().item()
print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))

# Save only the learned weights (state dict), not the whole module.
torch.save(resnet.state_dict(), 'resnet.pkl')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.