Created
February 1, 2019 06:59
-
-
Save luistelmocosta/ebfde00e892750dfed726b345dffb9e9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
def main(args):
    """Fine-tune a pretrained Inception v3 on an image-folder dataset.

    Reads the ``train`` and ``valid`` sub-directories of ``args.data_loc``
    with ``torchvision.datasets.ImageFolder``, trains for ``args.epochs``
    epochs using SGD with a step learning-rate schedule, restores the weights
    of the epoch with the best validation accuracy and saves their
    ``state_dict`` to ``args.save_path``.

    Args:
        args: Namespace providing ``data_loc``, ``save_path``, ``epochs``,
            ``batch_size`` and ``num_workers``.
    """
    import os
    import sys
    import glob
    import random
    import pickle
    import numpy as np
    from PIL import Image
    import time
    import copy
    from data_set.dataset_train import get_dataset
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    import torchvision.transforms as transforms
    import torch.utils.data
    import torchvision.models as models
    from torchvision import datasets, models, transforms
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from user_define import hyperparameter as hp
    from utils import progress_bar

    SEED = 101
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # Seed every generator the pipeline can draw from: NumPy alone does not
    # cover DataLoader shuffling or the random torchvision transforms.
    np.random.seed(SEED)
    random.seed(SEED)
    torch.manual_seed(SEED)

    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')

    phases = ['train', 'valid']
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # NOTE(review): no resize/crop is applied, so the images on disk are
    # presumably already the size the network expects (299x299 for
    # torchvision's Inception v3) -- confirm against the dataset.
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'valid': transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
    }
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(args.data_loc, x), data_transforms[x])
        for x in phases
    }
    dataloaders = {
        # Only the training split needs shuffling; validation metrics do not
        # depend on sample order.
        x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batch_size,
                                       shuffle=(x == 'train'),
                                       num_workers=args.num_workers)
        for x in phases
    }
    dataset_sizes = {x: len(image_datasets[x]) for x in phases}
    print(dataset_sizes)
    class_names = image_datasets['train'].classes
    print(class_names)
    print('Data loading END')

    def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
        """Run the train/valid loop and return the best-scoring model.

        Args:
            model: Network to optimise; may return a tensor or a tuple of
                tensors (main output first), as Inception does in train mode.
            criterion: Loss callable taking ``(outputs, labels)``.
            optimizer: Optimizer stepped once per training batch.
            scheduler: LR scheduler stepped once per epoch.
            num_epochs: Number of epochs to run.

        Returns:
            ``model`` with the weights of the epoch that achieved the highest
            validation accuracy loaded.
        """
        since = time.time()
        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = 0.0

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

            # Each epoch has a training and a validation phase.
            for phase in phases:
                is_train = phase == 'train'
                model.train(is_train)  # train(False) is equivalent to eval()

                running_loss = 0.0      # sample-weighted loss sum for the epoch
                running_corrects = 0    # plain int, so divisions below are exact
                total = 0
                batch_loss_sum = 0.0    # unweighted sum, for the progress bar

                # NOTE(review): batch_id is 1-based as in the original call;
                # confirm that is what utils.progress_bar expects.
                for batch_id, (inputs, labels) in enumerate(dataloaders[phase], 1):
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    optimizer.zero_grad()

                    # Track gradients only while training; validation needs
                    # neither the graph nor the memory it costs.
                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        # Nets with auxiliary heads (e.g. Inception) return a
                        # tuple: sum the losses, predict from the main output.
                        if isinstance(outputs, tuple):
                            loss = sum(criterion(o, labels) for o in outputs)
                            logits = outputs[0]
                        else:
                            loss = criterion(outputs, labels)
                            logits = outputs
                        _, preds = torch.max(logits, 1)

                        if is_train:
                            loss.backward()
                            optimizer.step()

                    # statistics
                    n_samples = labels.size(0)
                    batch_loss = loss.item()
                    batch_loss_sum += batch_loss
                    running_loss += batch_loss * n_samples
                    total += n_samples
                    running_corrects += int(torch.sum(preds == labels).item())

                    progress_bar(batch_id, len(dataloaders[phase]), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                                 % (batch_loss_sum / batch_id, 100. * running_corrects / total,
                                    running_corrects, total))

                if is_train:
                    # Step the schedule after the epoch's optimizer updates;
                    # stepping first skips the initial learning-rate value.
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = float(running_corrects) / dataset_sizes[phase]
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc))

                # Keep a copy of the best weights seen on the validation set.
                if phase == 'valid' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best valid Acc: {:.4f}'.format(best_acc))

        # load best model weights
        model.load_state_dict(best_model_wts)
        return model

    model_ft = models.inception_v3(pretrained=True)
    num_ftrs = model_ft.fc.in_features
    # NOTE(review): only the main classifier is replaced; the auxiliary head
    # keeps its pretrained output layer. Left as-is so saved state_dicts keep
    # the same tensor shapes.
    model_ft.fc = nn.Linear(num_ftrs, 2)
    model_ft = model_ft.to(device)

    criterion = nn.CrossEntropyLoss()
    # all parameters are being optimized
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=4, gamma=0.1)

    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                           num_epochs=args.epochs)
    torch.save(model_ft.state_dict(), args.save_path)
if __name__ == "__main__":
    # Command-line entry point: parse the training options and hand them to
    # main(). The two paths have no sensible default, so they are required;
    # omitting them would otherwise only fail later on a None path.
    parser = argparse.ArgumentParser(description='train inception_v3')
    parser.add_argument('--data_loc', type=str, required=True,
                        help="dataset root containing 'train' and 'valid' sub-directories")
    parser.add_argument('--save_path', type=str, required=True,
                        help='file the best model state_dict is written to')
    parser.add_argument('--epochs', type=int, default=10,
                        help='number of training epochs (default: 10)')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='samples per batch (default: 16)')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='DataLoader worker processes (default: 8)')
    args = parser.parse_args()
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment