@koshian2
Created August 20, 2018 20:14
DenseNet CIFAR10 in PyTorch
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import time
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.backends.cudnn.benchmark = True
# Data
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
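# Note: with mean=std=0.5 per channel, Normalize maps pixel values from [0, 1] to [-1, 1].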
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)
# Model
class DenseBlock(nn.Module):
    def __init__(self, input_channels, growth_rate):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = input_channels + growth_rate
        # Layers
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, growth_rate, kernel_size=3, padding=1)

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        return torch.cat([x, out], 1)
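# Quick illustrative sanity check (the names below are hypothetical and not part
# of the training pipeline): a DenseBlock keeps the spatial size and adds
# growth_rate channels via concatenation.
_block_demo = DenseBlock(16, 12)
print(_block_demo(torch.randn(2, 16, 32, 32)).shape)  # torch.Size([2, 28, 32, 32])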
class TransitionBlock(nn.Module):
    def __init__(self, input_channels, compression):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = int(input_channels * compression)
        # Layers
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(input_channels, self.output_channels, kernel_size=1)
        self.pool = nn.AvgPool2d(kernel_size=2)

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        return self.pool(out)
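# Likewise illustrative: a TransitionBlock compresses the channel count by the
# given ratio and halves the spatial resolution with average pooling.
_trans_demo = TransitionBlock(28, 0.5)
print(_trans_demo(torch.randn(2, 28, 32, 32)).shape)  # torch.Size([2, 14, 16, 16])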
class DenseNet(nn.Module):
    def __init__(self, growth_rate, compression_factor=0.5, blocks=[1,2,4,3]):
        super().__init__()
        # Growth rate: number of filters each DenseBlock adds
        self.k = growth_rate
        # Compression factor: ratio by which the transition layers shrink the filters
        self.compression = compression_factor
        # Block configuration
        self.blocks = blocks
        # Training history
        self.history = {"loss":[], "acc":[], "val_loss":[], "val_acc":[], "time":[]}
        # Build the model
        self.make_model()

    def make_dense_block(self, input_channels, nb_blocks):
        n_channels = input_channels
        layers = []
        for i in range(nb_blocks):
            item = DenseBlock(n_channels, self.k)
            layers.append(item)
            n_channels = item.output_channels
        return nn.Sequential(*layers), n_channels

    def make_transition_block(self, input_channels):
        item = TransitionBlock(input_channels, self.compression)
        return item, item.output_channels

    def make_model(self):
        # blocks=[6,12,24,16] follows the DenseNet-121 configuration
        # Start with 16 filters so compression never produces fractional channel counts
        n = 16
        self.conv1 = nn.Conv2d(3, n, kernel_size=1)
        # DenseBlock - TransitionLayer - DenseBlock ...
        self.dense1, n = self.make_dense_block(n, self.blocks[0])
        self.trans1, n = self.make_transition_block(n)
        self.dense2, n = self.make_dense_block(n, self.blocks[1])
        self.trans2, n = self.make_transition_block(n)
        self.dense3, n = self.make_dense_block(n, self.blocks[2])
        self.trans3, n = self.make_transition_block(n)
        self.dense4, n = self.make_dense_block(n, self.blocks[3])
        self.gap = nn.AvgPool2d(kernel_size=4)  # feature map is (4, 4) at this point
        self.fc = nn.Linear(n, 10)  # softmax is handled by the loss function
        self.gap_channels = n

    def forward(self, x):
        out = self.conv1(x)
        out = self.dense1(out)
        out = self.trans1(out)
        out = self.dense2(out)
        out = self.trans2(out)
        out = self.dense3(out)
        out = self.trans3(out)
        out = self.dense4(out)
        out = self.gap(out)
        out = out.view(-1, self.gap_channels)
        out = self.fc(out)
        return out
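# Sanity check with a toy configuration (illustrative only, hypothetical names):
# the network should map a CIFAR-sized batch to 10 class logits for any block counts.
_toy = DenseNet(8, blocks=[1, 1, 1, 1])
print(_toy(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])
print(sum(p.numel() for p in _toy.parameters()), "parameters")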
net = DenseNet(16, blocks=[6,12,24,16])
if device == "cuda":
    net = net.cuda()
criterion = nn.CrossEntropyLoss()
print(net)
optimizer = optim.Adam(net.parameters(), lr=0.001, weight_decay=2e-4)
nb_epochs = 80
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(nb_epochs*0.5), int(nb_epochs*0.8)], gamma=0.1)
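# With nb_epochs=80, the LR decays 10x at epochs 40 and 64: 1e-3 -> 1e-4 -> 1e-5.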
# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    start_time = time.time()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 50 == 0:
            print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                  % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    net.history["loss"].append(train_loss/(batch_idx+1))
    net.history["acc"].append(1.*correct/total)
    net.history["time"].append(time.time()-start_time)
def validate(epoch):
    net.eval()
    val_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = 100.*correct/total
    print(batch_idx, len(testloader), 'ValLoss: %.3f | ValAcc: %.3f%% (%d/%d)'
          % (val_loss/(batch_idx+1), acc, correct, total))
    net.history["val_loss"].append(val_loss/(batch_idx+1))
    net.history["val_acc"].append(1.*correct/total)
# Main-loop
for epoch in range(nb_epochs):
    train(epoch)
    validate(epoch)
    scheduler.step()  # step the scheduler after the epoch's optimizer updates
import pickle
with open("history.dat", "wb") as fp:
    pickle.dump(net.history, fp)
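# The saved history can be reloaded later, e.g. for plotting (illustrative usage):
#   with open("history.dat", "rb") as fp:
#       history = pickle.load(fp)
#   print(max(history["val_acc"]))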