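# Fine-tune a pretrained torchvision ResNet-50 on an ImageFolder-style dataset:
# the convolutional backbone is frozen and only a small two-layer classification
# head is trained, using a random train/validation split of one directory.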
from __future__ import print_function, division

import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, models, transforms

torch.backends.cudnn.enabled = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_split_train_test(args, datadir, valid_size=0.2):
    """Build train and validation loaders from a single ImageFolder directory."""
    train_transforms = transforms.Compose([
        # transforms.RandomRotation(30),      # data augmentations are great,
        # transforms.RandomResizedCrop(224),  # but not in this case of map tiles
        # transforms.RandomHorizontalFlip(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # transforms.Normalize([0.485, 0.456, 0.406],  # PyTorch recommends these,
        #                      [0.229, 0.224, 0.225])  # but they gave poor results here
    ])
    test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # transforms.Normalize([0.485, 0.456, 0.406],
        #                      [0.229, 0.224, 0.225])
    ])
    train_data = datasets.ImageFolder(datadir, transform=train_transforms)
    test_data = datasets.ImageFolder(datadir, transform=test_transforms)

    # Shuffle the indices once, then carve off the first `split` for validation.
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    np.random.shuffle(indices)
    train_idx, test_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(test_idx)

    trainloader = DataLoader(train_data, sampler=train_sampler,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers)
    testloader = DataLoader(test_data, sampler=test_sampler,
                            batch_size=args.batch_size,
                            num_workers=args.num_workers)
    return trainloader, testloader
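
# Design note: the directory is loaded twice only so that the train and
# validation batches can use different transform pipelines; since the two
# SubsetRandomSamplers draw from disjoint halves of one shuffled index list,
# no image ends up in both splits.
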
def train_model(args, model, criterion, optimizer, scheduler, trainloader, testloader):
    epochs = (args.start_epoch, args.num_epochs)
    running_loss = 0
    train_loss = 0
    running_accuracy = 0
    train_accuracy = 0
    print_every = args.print_freq
    train_losses, test_losses = [], []
    for epoch in range(epochs[0], epochs[1]):
        for steps, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            logps = model(inputs)
            loss = criterion(logps, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_loss += loss.item()
            # The head ends in LogSoftmax, so exp() recovers class probabilities.
            ps = torch.exp(logps)
            top_p, top_class = ps.topk(1, dim=1)
            equals = top_class == labels.view(*top_class.shape)
            running_accuracy += torch.mean(equals.float()).item()
            train_accuracy += torch.mean(equals.float()).item()

            if (steps + 1) % print_every == 0:
                print(f"Epoch [{epoch+1}|{epochs[1]}] "
                      f"Iter [{steps+1}|{len(trainloader)}] "
                      f"Train loss: {running_loss/print_every:.3f} "
                      f"Train accuracy: {running_accuracy/print_every:.3f}")
                running_loss = 0
                running_accuracy = 0

        print("Calculating val loss ...")
        test_loss = 0
        accuracy = 0
        model.eval()
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                logps = model(inputs)
                batch_loss = criterion(logps, labels)
                test_loss += batch_loss.item()
                ps = torch.exp(logps)
                top_p, top_class = ps.topk(1, dim=1)
                equals = top_class == labels.view(*top_class.shape)
                accuracy += torch.mean(equals.float()).item()
        train_losses.append(train_loss/len(trainloader))
        test_losses.append(test_loss/len(testloader))
        print(f"Epoch [{epoch+1}|{epochs[1]}] "
              f"Train loss: {train_loss/len(trainloader):.3f} "
              f"Train accuracy: {train_accuracy/len(trainloader):.3f} "
              f"Test loss: {test_loss/len(testloader):.3f} "
              f"Test accuracy: {accuracy/len(testloader):.3f}")
        model.train()
        train_loss = 0
        train_accuracy = 0
        scheduler.step()  # decay the learning rate once per epoch
        if epoch % args.save_epoch_freq == 0:
            print("Saving state ...")
            torch.save(model, os.path.join(args.save_path,
                                           'resnet_epoch_' + str(epoch) + '.pth'))
    return model

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="PyTorch ResNet-50 fine-tuning")
    parser.add_argument('--data-dir', type=str, default="/ImageNet")
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-class', type=int, default=1000)
    parser.add_argument('--num-epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--gpus', type=str, default="0")
    parser.add_argument('--print-freq', type=int, default=10)
    parser.add_argument('--save-epoch-freq', type=int, default=1)
    parser.add_argument('--save-path', type=str, default="output")
    parser.add_argument('--resume', type=str, default="",
                        help="Path to a checkpoint to resume training from")
    parser.add_argument('--start-epoch', type=int, default=0,
                        help="Epoch corresponding to the resumed checkpoint")
    args = parser.parse_args()

    # read data
    trainloader, testloader = load_split_train_test(args, args.data_dir, 0.2)
    print('classes ', trainloader.dataset.classes)
    print('using device: ' + str(device))

    # get model: freeze the pretrained backbone and attach a new classification head
    # model = mobilenetv2_19(num_classes=args.num_class)
    # model = models.mobilenet_v2(num_classes=args.num_class)
    model = models.resnet50(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    model.fc = nn.Sequential(nn.Linear(2048, 512),
                             nn.ReLU(),
                             nn.Dropout(0.5),
                             nn.Linear(512, len(trainloader.dataset.classes)),
                             nn.LogSoftmax(dim=1))

    if args.resume:
        if os.path.isfile(args.resume):
            # Checkpoints are saved with torch.save(model), so load the whole
            # module and copy its weights into the current model.
            model.load_state_dict(torch.load(args.resume).state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    model.to(device)
    # model = torch.nn.DataParallel(model, device_ids=[int(i) for i in args.gpus.strip().split(',')])

    # define loss function: NLLLoss pairs with the LogSoftmax output layer
    # criterion = nn.CrossEntropyLoss()
    # criterion = nn.MSELoss()
    criterion = nn.NLLLoss()

    # Only the new head's parameters are optimized; the backbone stays frozen.
    # optimizer_ft = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.00004)
    # optimizer_ft = optim.Adam(model.parameters(), lr=args.lr)
    optimizer_ft = optim.Adam(model.fc.parameters(), lr=args.lr)

    # Decay LR by a factor of 0.98 every epoch
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=1, gamma=0.98)

    os.makedirs(args.save_path, exist_ok=True)
    model = train_model(args=args,
                        model=model,
                        criterion=criterion,
                        optimizer=optimizer_ft,
                        scheduler=exp_lr_scheduler,
                        trainloader=trainloader,
                        testloader=testloader)
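
# Example invocation (script name and paths are illustrative; ImageFolder
# expects one subdirectory per class under --data-dir):
#   python train_resnet.py --data-dir ./data/tiles --batch-size 32 \
#       --num-epochs 20 --lr 0.003 --save-path checkpoints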