#%%
import torch
import torchvision
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
%matplotlib inline
import time
import os
import numpy as np
from matplotlib import pyplot as plt
#%%
print(torch.cuda.is_available())
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#%%
!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip hymenoptera_data.zip
#%%
# Define the image preprocessing pipelines
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Normalize with the ImageNet channel means and standard deviations
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
#%%
# Load the images and labels
data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
train_loader = torch.utils.data.DataLoader(image_datasets['train'], batch_size=5,
                                           shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(image_datasets['val'], batch_size=5,
                                          shuffle=False, num_workers=4)
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
#%%
# Inspect the shape and dtype of one training batch
for train in train_loader:
    print(train[0].shape)
    print(train[0].dtype)
    break
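#%%
# Optional sanity check (not part of the original gist): a minimal sketch that
# visualizes one batch by undoing the ImageNet normalization defined above.
# Assumes the train_loader and class_names variables from the previous cells.
def imshow_batch(images, labels):
    # Arrange the batch into a single image grid
    grid = torchvision.utils.make_grid(images)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # CHW -> HWC, then undo Normalize for display
    img = grid.numpy().transpose((1, 2, 0)) * std + mean
    plt.imshow(np.clip(img, 0, 1))
    plt.title([class_names[l] for l in labels.tolist()])
    plt.axis('off')

images, labels = next(iter(train_loader))
imshow_batch(images, labels)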
#%%
# Load AlexNet pretrained on ImageNet
# (the loss and optimizer are defined below, after the final layer is replaced)
net = models.alexnet(pretrained=True)
net = net.to(device)
net
#%%
# Freeze the network's parameters
for param in net.parameters():
    param.requires_grad = False
net = net.to(device)
# Replace the final layer for the two classes (ants / bees)
num_ftrs = net.classifier[6].in_features
net.classifier[6] = nn.Linear(num_ftrs, 2).to(device)
net
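#%%
# Optional check (not part of the original gist): confirm that only the newly
# added final layer is trainable after freezing.
trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)  # expected: ['classifier.6.weight', 'classifier.6.bias']
n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{n_trainable} trainable parameters')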
#%%
# Loss function, optimizer, and learning-rate scheduler
criterion = nn.CrossEntropyLoss()
# Frozen parameters receive no gradients, so SGD only updates the new final layer
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
# Decay the learning rate by 10x every 10 epochs
# (named `scheduler` so it does not shadow the imported lr_scheduler module)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
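#%%
# Optional illustration (not part of the original gist): preview the StepLR
# decay on a throwaway optimizer so the real one above is left untouched.
dummy_opt = optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.01)
dummy_sched = torch.optim.lr_scheduler.StepLR(dummy_opt, step_size=10, gamma=0.1)
lrs = []
for _ in range(15):
    lrs.append(dummy_opt.param_groups[0]['lr'])
    dummy_opt.step()    # step the optimizer first, as the scheduler expects
    dummy_sched.step()
print(lrs)  # lr drops from 0.01 to 0.001 after epoch 10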
#%%
num_epochs = 15
train_loss_list = []
train_acc_list = []
val_loss_list = []
val_acc_list = []

for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0

    # train
    net.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        train_loss += loss.item()
        train_acc += (outputs.max(1)[1] == labels).sum().item()
        loss.backward()
        optimizer.step()
    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_train_acc = train_acc / len(train_loader.dataset)

    # val
    net.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            val_acc += (outputs.max(1)[1] == labels).sum().item()
    avg_val_loss = val_loss / len(test_loader.dataset)
    avg_val_acc = val_acc / len(test_loader.dataset)

    print('Epoch [{}/{}], loss: {loss:.4f}, acc: {acc:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}, lr: {lr}'
          .format(epoch + 1, num_epochs, loss=avg_train_loss, acc=avg_train_acc,
                  val_loss=avg_val_loss, val_acc=avg_val_acc,
                  lr=optimizer.param_groups[0]['lr']))

    # Adjust the learning rate
    scheduler.step()

    train_loss_list.append(avg_train_loss)
    train_acc_list.append(avg_train_acc)
    val_loss_list.append(avg_val_loss)
    val_acc_list.append(avg_val_acc)
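#%%
# Plot the tracked metrics (the original gist only collects the lists); this
# uses the matplotlib import and the four lists built in the training loop.
epochs = range(1, num_epochs + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(epochs, train_loss_list, label='train')
ax1.plot(epochs, val_loss_list, label='val')
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend()
ax2.plot(epochs, train_acc_list, label='train')
ax2.plot(epochs, val_acc_list, label='val')
ax2.set_xlabel('epoch')
ax2.set_ylabel('accuracy')
ax2.legend()
plt.show()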