PyTorch template.
# Copyright (c) 2020 Haytham Fayek.
# All rights reserved.
# Subject to the MIT license:
# https://opensource.org/licenses/MIT.
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
# Training settings
parser = argparse.ArgumentParser()
# Input
parser.add_argument('--train-set', type=str, default='train.h5',
                    help='Path to training data.')
parser.add_argument('--test-set', type=str, default='test.h5',
                    help='Path to test data.')
# Settings
parser.add_argument('--seed', type=int, default=0,
                    help='Random seed.')
parser.add_argument('--cpu', action='store_true', default=False,
                    help='Disable CUDA.')
parser.add_argument('--multi-gpus', action='store_true', default=False,
                    help='Use all available GPUs.')
parser.add_argument('--resume', action='store_true', default=False,
                    help='Resume training.')
parser.add_argument('--model', type=str, default='model',
                    help='Model to train.')
parser.add_argument('--optimizer', type=str, default='adam',
                    help='Optimizer.')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='Learning rate.')
parser.add_argument('--l2', type=float, default=0.0001,
                    help='Weight decay.')
parser.add_argument('--batch-size', type=int, default=128,
                    help='Batch size for training.')
parser.add_argument('--epochs', type=int, default=10,
                    help='Number of epochs to train.')
parser.add_argument('--early-stopping', action='store_true', default=False,
                    help='Early stopping.')
parser.add_argument('--anneal-learning-rate', action='store_true', default=False,
                    help='Anneal learning rate.')
parser.add_argument('--patience', type=int, default=10,
                    help='Number of epochs before early stopping.')
# Output
parser.add_argument('--save-model', action='store_true', default=False,
                    help='Save current model.')
parser.add_argument('--model-dir', type=str, default='models',
                    help='Path to model.')
parser.add_argument('--model-name', type=str, default='model.pt',
                    help='Model name.')
parser.add_argument('--infer-only', action='store_true', default=False,
                    help='Run in test mode only.')
# Dataset
class Dataset(torch.utils.data.Dataset):
    """Template dataset; replace the stubs below with real data loading."""

    def __init__(self, dataset_pt, stats=None):
        # Load the data at `dataset_pt`; `stats` can carry normalization
        # statistics computed on the training set.
        pass

    def __len__(self):
        # Return the number of examples in the dataset.
        pass

    def __getitem__(self, index):
        # Return one (input, target) pair as tensors.
        pass
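# A minimal sketch of one way to fill in the stubs above: an HDF5-backed
# dataset. This is an illustrative assumption, not part of the original
# template; it presumes h5py is installed and that the file stores
# 'data' and 'labels' arrays.
class H5Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset_pt, stats=None):
        import h5py  # local import so the template runs without h5py
        with h5py.File(dataset_pt, 'r') as f:
            data = torch.as_tensor(f['data'][...], dtype=torch.float32)
            self.targets = torch.as_tensor(f['labels'][...], dtype=torch.long)
        # Reuse training-set statistics if given (e.g. for the test set).
        self.stats = stats if stats is not None else (data.mean(), data.std())
        self.data = (data - self.stats[0]) / self.stats[1]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]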
# Model
class Model(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # Flatten each example and apply a single linear layer.
        x = x.reshape((x.shape[0], -1))
        x = self.linear(x)
        return x
def save_state(checkpoint_pt, epoch, model, optimizer, scheduler,
               train_loss, train_perf, test_loss, test_perf,
               best_perf, patience, best=False):
    kwargs = {
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'train_perf': train_perf,
        'test_loss': test_loss,
        'test_perf': test_perf,
        'best_perf': best_perf,
        'patience': patience,
    }
    if best and isinstance(model, nn.DataParallel):  # unwrap model when saving best
        kwargs['model_state_dict'] = model.module.state_dict()
    else:
        kwargs['model_state_dict'] = model.state_dict()
    if best:
        checkpoint_pt += '.best'
    torch.save(kwargs, checkpoint_pt)
def load_state(checkpoint_pt, model, optimizer, scheduler):
    checkpoint = torch.load(checkpoint_pt)
    sepoch = checkpoint.get('epoch', 0)
    model.load_state_dict(checkpoint['model_state_dict'])
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    train_loss = checkpoint.get('train_loss', [])
    train_perf = checkpoint.get('train_perf', [])
    test_loss = checkpoint.get('test_loss', [])
    test_perf = checkpoint.get('test_perf', [])
    best_perf = checkpoint.get('best_perf', 0)
    patience = checkpoint.get('patience', 0)
    return sepoch, model, optimizer, scheduler, \
        train_loss, train_perf, test_loss, test_perf, \
        best_perf, patience
def train(model, device, train_loader, optimizer):
    model.train()
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        # model.zero_grad()
        # Setting grads to None instead of zeroing them skips the memset
        # (equivalent to optimizer.zero_grad(set_to_none=True)).
        for p in model.parameters():
            p.grad = None
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
def test(model, device, test_loader):
    model.eval()
    test_loss, correct, examples = 0., 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output,
                                         target, reduction='sum').item()
            label = output.argmax(dim=1, keepdim=True)
            correct += label.eq(target.view_as(label)).sum().item()
            examples += len(data)
    test_loss /= examples
    perf = correct / examples
    return test_loss, perf
def main(args):
    # Seed
    # np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    # torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True
    # Infra
    use_cuda = not args.cpu and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    if args.save_model and not os.path.isdir(args.model_dir):
        os.makedirs(args.model_dir)
    checkpoint_pt = os.path.join(args.model_dir, args.model_name)
    # Dataset
    train_set = datasets.CIFAR10(
        args.train_set, transform=transforms.ToTensor(), train=True, download=True)
    test_set = datasets.CIFAR10(
        args.test_set, transform=transforms.ToTensor(), train=False, download=True)
    train_set.input_dim, train_set.output_dim = 3072, 10  # 32x32x3 images, 10 classes
    # train_set = Dataset(args.train_set)
    # test_set = Dataset(args.test_set, train_set.stats)
    loader_kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              **loader_kwargs)
    # Model
    model = Model(input_dim=train_set.input_dim,
                  output_dim=train_set.output_dim)
    model = model.to(device)
    # GPU / multi-GPU
    if use_cuda and args.multi_gpus and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        print('DataParallel! Using', torch.cuda.device_count(), 'GPUs!')
    else:
        print('Single CPU/GPU! Using', device)
    # Optimizer and scheduler
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.l2)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        raise ValueError('Unknown optimizer: ' + args.optimizer)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           factor=0.1,
                                                           patience=int(args.patience / 2),
                                                           verbose=True)
    # Inference
    if args.infer_only:
        if os.path.isfile(checkpoint_pt):
            print('Testing:', checkpoint_pt)
            _, model, optimizer, scheduler, _, _, _, _, _, _ = \
                load_state(checkpoint_pt, model, optimizer, scheduler)
            print('\n' + 'Hyperparameters' + '\n' + str(args))
            print('\n' + 'Model' + '\n' + str(model) + '\n')
            print('Start testing...')
            test_loss, test_perf = test(model, device, test_loader)
            print(('Test loss: {:.3f}, Test Perf: {:.3f}%.').format(test_loss,
                                                                    100. * test_perf))
        else:
            print('Could not find model to test.')
        return  # inference done, nothing else to do here.
    # Initialize or load from existing checkpoint
    if args.resume and os.path.isfile(checkpoint_pt):
        print('Continue training from:', checkpoint_pt)
        sepoch, model, optimizer, scheduler, \
            train_loss_lst, train_perf_lst, test_loss_lst, test_perf_lst, \
            best_perf, patience = load_state(
                checkpoint_pt, model, optimizer, scheduler)
    else:
        sepoch = 0
        train_loss_lst, train_perf_lst, test_loss_lst, test_perf_lst = [], [], [], []
        best_perf, patience = 0., 0
    print('\n' + 'Hyperparameters' + '\n' + str(args))
    print('\n' + 'Model' + '\n' + str(model) + '\n')
    print('Start training...')
    # Evaluate the initial (or resumed) model before training.
    train_loss, train_perf = test(model, device, train_loader)
    test_loss, test_perf = test(model, device, test_loader)
    print(('Epoch {:03d}. Train loss: {:.3f}, Train Perf: {:.3f}%'
           + '. Test loss: {:.3f}, Test Perf: {:.3f}%.').format(sepoch,
                                                                train_loss,
                                                                100. * train_perf,
                                                                test_loss,
                                                                100. * test_perf))
    train_loss_lst.append(train_loss)
    train_perf_lst.append(train_perf)
    test_loss_lst.append(test_loss)
    test_perf_lst.append(test_perf)
    # Training loop
    for epoch in range(sepoch + 1, args.epochs + 1):
        # Train
        train(model, device, train_loader, optimizer)
        # Eval
        train_loss, train_perf = test(model, device, train_loader)
        test_loss, test_perf = test(model, device, test_loader)
        print(('Epoch {:03d}. Train loss: {:.3f}, Train Perf: {:.3f}%'
               + '. Test loss: {:.3f}, Test Perf: {:.3f}%.').format(epoch,
                                                                    train_loss,
                                                                    100. * train_perf,
                                                                    test_loss,
                                                                    100. * test_perf))
        train_loss_lst.append(train_loss)
        train_perf_lst.append(train_perf)
        test_loss_lst.append(test_loss)
        test_perf_lst.append(test_perf)
        if args.anneal_learning_rate:
            scheduler.step(test_perf)
        # Monitor best performance so far, assuming higher is better.
        if test_perf > best_perf:
            best_perf, patience = test_perf, 0
            print('Best Model at Epoch', str(epoch))
            if args.save_model:
                save_state(checkpoint_pt, epoch, model, optimizer, scheduler,
                           train_loss_lst, train_perf_lst, test_loss_lst, test_perf_lst,
                           best_perf, patience, best=True)
        else:
            patience += 1
        if args.save_model:
            save_state(checkpoint_pt, epoch, model, optimizer, scheduler,
                       train_loss_lst, train_perf_lst, test_loss_lst, test_perf_lst,
                       best_perf, patience)
        if args.early_stopping and patience >= args.patience:
            print('Early Stopping!')
            break
if __name__ == '__main__':
    pargs = parser.parse_args()
    main(pargs)
    print('Success!')
A template for a PyTorch project. This boilerplate implements common command-line arguments, base dataset and model classes, save and load functions, train and test loops, and a main function that sets up the infrastructure, datasets and data loaders, model, optimizer, auto-checkpointing, and early stopping. Contributions welcome!
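For example, assuming the template is saved as train.py (the filename is not specified here), a typical training run with checkpointing and early stopping might look like:

python train.py --optimizer adam --lr 0.0001 --batch-size 128 --epochs 50 --save-model --early-stopping --anneal-learning-rate

Add --resume to continue from the latest checkpoint, or --infer-only to evaluate a saved model without training.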