Created
August 20, 2018 20:14
-
-
Save koshian2/f1ecf57390d5efe24f6d67f3e596b43b to your computer and use it in GitHub Desktop.
DenseNet CIFAR10 in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchvision | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision.transforms as transforms | |
import time | |
# Select GPU when available; let cuDNN benchmark conv algorithms for speed.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.backends.cudnn.benchmark = True

# Data pipeline: augment the training split, only normalize the test split.
# Both pipelines share one (stateless) Normalize module.
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)
# Model | |
class DenseBlock(nn.Module):
    """One DenseNet bottleneck unit: BN-ReLU-Conv1x1 -> BN-ReLU-Conv3x3.

    The unit computes `growth_rate` new feature maps and concatenates them
    onto its input along the channel axis, so
    `output_channels == input_channels + growth_rate`.
    """

    def __init__(self, input_channels, growth_rate):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = input_channels + growth_rate
        # Bottleneck: 1x1 conv to a fixed 128 channels, then 3x3 down to
        # growth_rate channels (padding=1 keeps the spatial size).
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, growth_rate, kernel_size=3, padding=1)

    def forward(self, x):
        bottleneck = self.conv1(self.relu(self.bn1(x)))
        new_features = self.conv2(self.relu(self.bn2(bottleneck)))
        # Dense connectivity: keep the input and append the new maps.
        return torch.cat([x, new_features], 1)
class TransitionBlock(nn.Module):
    """DenseNet transition layer: BN-ReLU-Conv1x1 channel compression
    followed by 2x2 average pooling that halves the spatial resolution.

    `output_channels == int(input_channels * compression)`.
    """

    def __init__(self, input_channels, compression):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = int(input_channels * compression)
        # Layers
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(input_channels, self.output_channels, kernel_size=1)
        # Built once here instead of being re-instantiated on every forward
        # pass (as the original did).  AvgPool2d has no parameters, so the
        # state_dict is unchanged.
        self.pool = nn.AvgPool2d(kernel_size=2)

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        return self.pool(out)
class DenseNet(nn.Module):
    """DenseNet classifier for 10-class 32x32 inputs (CIFAR-10).

    Four dense stages separated by compressing transition layers, followed
    by global average pooling and a linear classifier.  With
    blocks=[6, 12, 24, 16] the stage layout matches DenseNet-121.
    """

    def __init__(self, growth_rate, compression_factor=0.5, blocks=(1, 2, 4, 3)):
        super().__init__()
        # growth_rate (k): number of filters each DenseBlock appends.
        self.k = growth_rate
        # compression_factor: filter-reduction ratio of the transition layers.
        self.compression = compression_factor
        # Per-stage DenseBlock counts.  The default is a tuple instead of the
        # original mutable list default, which was shared across instances.
        self.blocks = blocks
        # Metric history populated by the training/validation loops.
        self.history = {"loss": [], "acc": [], "val_loss": [], "val_acc": [], "time": []}
        # Build the layers.
        self.make_model()

    def make_dense_block(self, input_channels, nb_blocks):
        """Chain `nb_blocks` DenseBlocks; return (module, output channel count)."""
        n_channels = input_channels
        layers = []
        for i in range(nb_blocks):
            item = DenseBlock(n_channels, self.k)
            layers.append(item)
            n_channels = item.output_channels
        return nn.Sequential(*layers), n_channels

    def make_transition_block(self, input_channels):
        """Create one TransitionBlock; return (module, output channel count)."""
        item = TransitionBlock(input_channels, self.compression)
        return item, item.output_channels

    def make_model(self):
        """Assemble stem, dense/transition stages, pooling head and classifier."""
        # 16 stem filters keep every compressed channel count an integer.
        n = 16
        self.conv1 = nn.Conv2d(3, n, kernel_size=1)
        # DenseBlock - TransitionLayer - DenseBlock - ...
        self.dense1, n = self.make_dense_block(n, self.blocks[0])
        self.trans1, n = self.make_transition_block(n)
        self.dense2, n = self.make_dense_block(n, self.blocks[1])
        self.trans2, n = self.make_transition_block(n)
        self.dense3, n = self.make_dense_block(n, self.blocks[2])
        self.trans3, n = self.make_transition_block(n)
        self.dense4, n = self.make_dense_block(n, self.blocks[3])
        # Feature maps are 4x4 here (32 -> 16 -> 8 -> 4 after three poolings).
        self.gap = nn.AvgPool2d(kernel_size=4)
        self.fc = nn.Linear(n, 10)  # softmax is folded into the loss function
        self.gap_channels = n

    def forward(self, x):
        out = self.conv1(x)
        out = self.dense1(out)
        out = self.trans1(out)
        out = self.dense2(out)
        out = self.trans2(out)
        out = self.dense3(out)
        out = self.trans3(out)
        out = self.dense4(out)
        out = self.gap(out)
        out = out.view(-1, self.gap_channels)
        out = self.fc(out)
        return out
net = DenseNet(16, blocks=[6, 12, 24, 16])
# Move the model to the selected device.  The original guarded on
# device == "cuda" and called .cuda(); .to(device) covers both cases.
net = net.to(device)
criterion = nn.CrossEntropyLoss()
# The original printed `net.register_parameter` — a bound method object —
# which was almost certainly a typo for printing the model itself.
print(net)
optimizer = optim.Adam(net.parameters(), lr=0.001, weight_decay=2e-4)
nb_epochs = 80
# Drop the learning rate x0.1 at 50% and 80% of training.  Milestones are
# cast to int; the original passed floats (40.0, 64.0).
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[int(nb_epochs * 0.5), int(nb_epochs * 0.8)], gamma=0.1)
# Training | |
def train(epoch):
    """Train `net` for one epoch over `trainloader` and record metrics.

    Appends epoch-mean loss, accuracy (as a fraction) and wall-clock time
    to `net.history`.  Prints running stats every 50 batches.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    tic = time.time()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Standard step: clear grads, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()
        if batch_idx % 50 == 0:
            print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                  % (running_loss / (batch_idx + 1), 100. * n_correct / n_seen, n_correct, n_seen))
    # Epoch-level bookkeeping; accuracy is stored as a fraction, not a percent.
    net.history["loss"].append(running_loss / (batch_idx + 1))
    net.history["acc"].append(1. * n_correct / n_seen)
    net.history["time"].append(time.time() - tic)
def validate(epoch):
    """Evaluate `net` on `testloader` without gradients and record metrics.

    Appends epoch-mean validation loss and accuracy (as a fraction) to
    `net.history` and prints a one-line summary.
    NOTE: the original declared `global best_acc` and computed a per-batch
    local `acc`, but neither was ever used; both were dead code and removed.
    """
    net.eval()
    val_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        print(batch_idx, len(testloader), 'ValLoss: %.3f | ValAcc: %.3f%% (%d/%d)'
              % (val_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    net.history["val_loss"].append(val_loss / (batch_idx + 1))
    net.history["val_acc"].append(1. * correct / total)
# Main-loop
for epoch in range(nb_epochs):
    train(epoch)
    validate(epoch)
    # Step the LR scheduler AFTER the epoch's optimizer updates.  The
    # original stepped it before train(); in PyTorch >= 1.1 that skips the
    # initial learning rate and fires every milestone one epoch early.
    scheduler.step()

# Persist the collected metrics for later plotting/analysis.
import pickle
with open("history.dat", "wb") as fp:
    pickle.dump(net.history, fp)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment