PyTorch template.
# Copyright (c) 2020 Haytham Fayek.
# All rights reserved.
# Subject to the MIT license:
# https://opensource.org/licenses/MIT.
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
# Training settings
parser = argparse.ArgumentParser()
# Input
parser.add_argument('--train-set', type=str, default='train.h5',
                    help='Path to training data.')
parser.add_argument('--test-set', type=str, default='test.h5',
                    help='Path to test data.')
# Settings
parser.add_argument('--seed', type=int, default=0,
                    help='Random seed.')
parser.add_argument('--cpu', action='store_true', default=False,
                    help='Disable CUDA.')
parser.add_argument('--multi-gpus', action='store_true', default=False,
                    help='Use all available GPUs.')
parser.add_argument('--resume', action='store_true', default=False,
                    help='Resume training.')
parser.add_argument('--model', type=str, default='model',
                    help='Model to train.')
parser.add_argument('--optimizer', type=str, default='adam',
                    help='Optimizer.')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='Learning rate.')
parser.add_argument('--l2', type=float, default=0.0001,
                    help='Weight decay.')
parser.add_argument('--batch-size', type=int, default=128,
                    help='Batch size for training.')
parser.add_argument('--epochs', type=int, default=10,
                    help='Number of epochs to train.')
parser.add_argument('--early-stopping', action='store_true', default=False,
                    help='Early stopping.')
parser.add_argument('--anneal-learning-rate', action='store_true', default=False,
                    help='Anneal learning rate.')
parser.add_argument('--patience', type=int, default=10,
                    help='Number of epochs before early stopping.')
# Output
parser.add_argument('--save-model', action='store_true', default=False,
                    help='Save current model.')
parser.add_argument('--model-dir', type=str, default='models',
                    help='Path to model.')
parser.add_argument('--model-name', type=str, default='model.pt',
                    help='Model name.')
parser.add_argument('--infer-only', action='store_true', default=False,
                    help='Run in test mode only.')
# Dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset_pt, stats=None):
        pass

    def __len__(self):
        pass

    def __getitem__(self, index):
        pass

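# Illustrative only: one way the Dataset stub above might be filled in, assuming the
# data live in an HDF5 file with 'x' (features) and 'y' (labels) arrays. The file
# layout, key names, and normalization below are assumptions for this sketch, not
# part of the original template.
class H5Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset_pt, stats=None):
        import h5py  # local import so the template still runs without h5py installed
        with h5py.File(dataset_pt, 'r') as f:
            self.x = torch.as_tensor(f['x'][...], dtype=torch.float32)
            self.y = torch.as_tensor(f['y'][...], dtype=torch.long)
        # Compute normalization stats on the training set and reuse them for the test set.
        self.stats = stats if stats is not None else (self.x.mean(), self.x.std())
        self.x = (self.x - self.stats[0]) / self.stats[1]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return self.x[index], self.y[index]
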
# Model
class Model(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = x.reshape((x.shape[0], -1))
        x = self.linear(x)
        return x

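# Illustrative only: a small convolutional alternative to the linear Model above,
# sized for 3x32x32 inputs such as CIFAR-10. The architecture is an arbitrary choice
# for demonstration; input_dim is accepted but unused so it stays a drop-in
# replacement for Model in main().
class ConvModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc = nn.Linear(64 * 8 * 8, output_dim)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 32x32 -> 16x16
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 16x16 -> 8x8
        x = x.reshape((x.shape[0], -1))
        return self.fc(x)
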
def save_state(checkpoint_pt, epoch, model, optimizer, scheduler,
               train_loss, train_perf, test_loss, test_perf,
               best_perf, patience, best=False):
    kwargs = {
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'train_perf': train_perf,
        'test_loss': test_loss,
        'test_perf': test_perf,
        'best_perf': best_perf,
        'patience': patience,
    }
    if best and isinstance(model, nn.DataParallel):  # unwrap model when saving best
        kwargs['model_state_dict'] = model.module.state_dict()
    else:
        kwargs['model_state_dict'] = model.state_dict()
    if best:
        checkpoint_pt += '.best'
    torch.save(kwargs, checkpoint_pt)

def load_state(checkpoint_pt, model, optimizer, scheduler):
    checkpoint = torch.load(checkpoint_pt)
    sepoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0
    model.load_state_dict(checkpoint['model_state_dict'])
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    train_loss = checkpoint['train_loss'] if 'train_loss' in checkpoint else []
    train_perf = checkpoint['train_perf'] if 'train_perf' in checkpoint else []
    test_loss = checkpoint['test_loss'] if 'test_loss' in checkpoint else []
    test_perf = checkpoint['test_perf'] if 'test_perf' in checkpoint else []
    best_perf = checkpoint['best_perf'] if 'best_perf' in checkpoint else 0
    patience = checkpoint['patience'] if 'patience' in checkpoint else 0
    return sepoch, model, optimizer, scheduler, \
        train_loss, train_perf, test_loss, test_perf, \
        best_perf, patience

def train(model, device, train_loader, optimizer):
    model.train()
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        # model.zero_grad()
        for p in model.parameters():
            p.grad = None
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

def test(model, device, test_loader):
    model.eval()
    test_loss, correct, examples = 0., 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output,
                                         target, reduction='sum').item()
            label = output.argmax(dim=1, keepdim=True)
            correct += label.eq(target.view_as(label)).sum().item()
            examples += len(data)
    test_loss /= examples
    perf = correct / examples
    return test_loss, perf

def main(args):
    # Seed
    # np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    # torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True
    # Infra
    use_cuda = not args.cpu and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    if args.save_model and not os.path.isdir(args.model_dir):
        os.makedirs(args.model_dir)
    checkpoint_pt = os.path.join(args.model_dir, args.model_name)
    # Dataset
    train_set = datasets.CIFAR10(
        args.train_set, transform=transforms.ToTensor(), train=True, download=True)
    test_set = datasets.CIFAR10(
        args.test_set, transform=transforms.ToTensor(), train=False, download=True)
    train_set.input_dim, train_set.output_dim = 3072, 10
    # train_set = Dataset(args.train_set)
    # test_set = Dataset(args.test_set, train_set.stats)
    loader_kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              **loader_kwargs)
    # Model
    model = Model(input_dim=train_set.input_dim,
                  output_dim=train_set.output_dim)
    model = model.to(device)
    # GPU / multi-GPU
    if use_cuda and args.multi_gpus and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        print('DataParallel! Using', torch.cuda.device_count(), 'GPUs!')
    else:
        print('Single CPU/GPU! Using', device)
    # Optimizer and scheduler
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.l2)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        assert False, 'Unknown optimizer.'
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           factor=0.1,
                                                           patience=int(args.patience / 2),
                                                           verbose=True)
    # Inference
    if args.infer_only:
        if os.path.isfile(checkpoint_pt):
            print('Testing:', checkpoint_pt)
            _, model, optimizer, scheduler, _, _, _, _, _, _ = \
                load_state(checkpoint_pt, model, optimizer, scheduler)
            print('\n' + 'Hyperparameters' + '\n' + str(args))
            print('\n' + 'Model' + '\n' + str(model) + '\n')
            print('Start testing...')
            test_loss, test_perf = test(model, device, test_loader)
            print(('Test loss: {:.3f}, Test Perf: {:.3f}%.').format(test_loss,
                                                                    100. * test_perf))
        else:
            print('Could not find model to test.')
        return  # inference done, nothing else to do here.
    # Initialize or load from existing checkpoint
    if args.resume and os.path.isfile(checkpoint_pt):
        print('Continue training from:', checkpoint_pt)
        sepoch, model, optimizer, scheduler, \
            train_loss_lst, train_perf_lst, test_loss_lst, test_perf_lst, \
            best_perf, patience = load_state(
                checkpoint_pt, model, optimizer, scheduler)
    else:
        sepoch = 0
        train_loss_lst, train_perf_lst, test_loss_lst, test_perf_lst = [], [], [], []
        best_perf, patience = 0., 0
    print('\n' + 'Hyperparameters' + '\n' + str(args))
    print('\n' + 'Model' + '\n' + str(model) + '\n')
    print('Start training...')
    train_loss, train_perf = test(model, device, train_loader)
    test_loss, test_perf = test(model, device, test_loader)
    print(('Epoch {:03d}. Train loss: {:.3f}, Train Perf: {:.3f}%'
           + '. Test loss: {:.3f}, Test Perf: {:.3f}%.').format(sepoch,
                                                                train_loss,
                                                                100. * train_perf,
                                                                test_loss,
                                                                100. * test_perf))
    train_loss_lst.append(train_loss)
    train_perf_lst.append(train_perf)
    test_loss_lst.append(test_loss)
    test_perf_lst.append(test_perf)
    # Training loop
    for epoch in range(sepoch + 1, args.epochs + 1):
        # Train
        train(model, device, train_loader, optimizer)
        # Eval
        train_loss, train_perf = test(model, device, train_loader)
        test_loss, test_perf = test(model, device, test_loader)
        print(('Epoch {:03d}. Train loss: {:.3f}, Train Perf: {:.3f}%'
               + '. Test loss: {:.3f}, Test Perf: {:.3f}%.').format(epoch,
                                                                    train_loss,
                                                                    100. * train_perf,
                                                                    test_loss,
                                                                    100. * test_perf))
        train_loss_lst.append(train_loss)
        train_perf_lst.append(train_perf)
        test_loss_lst.append(test_loss)
        test_perf_lst.append(test_perf)
        if args.anneal_learning_rate:
            scheduler.step(test_perf)
        # Monitor best performance so far, assuming higher is better
        if test_perf > best_perf:
            best_perf, patience = test_perf, 0
            print('Best Model at Epoch', str(epoch))
            if args.save_model:
                save_state(checkpoint_pt, epoch, model, optimizer, scheduler,
                           train_loss_lst, train_perf_lst, test_loss_lst, test_perf_lst,
                           best_perf, patience, best=True)
        else:
            patience += 1
        if args.save_model:
            save_state(checkpoint_pt, epoch, model, optimizer, scheduler,
                       train_loss_lst, train_perf_lst, test_loss_lst, test_perf_lst,
                       best_perf, patience)
        if args.early_stopping and patience >= args.patience:
            print('Early Stopping!')
            break

if __name__ == '__main__':
    pargs = parser.parse_args()
    main(pargs)
    print('Success!')
@haythamfayek (Author) commented:
A template for a PyTorch project. This boilerplate code implements common command-line arguments, base dataset and model classes, save and load functions, train and test loops, and a main function that sets up the infrastructure, datasets and data loaders, model, optimizer, auto-checkpointing, and early stopping. Contributions welcome!
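As a quick usage sketch (the file name and hyperparameter values here are hypothetical, not part of the template), the flags can be passed on the command line, e.g. python train.py --optimizer sgd --lr 0.01 --epochs 20 --save-model --early-stopping, or handed to main() programmatically:

args = parser.parse_args(['--optimizer', 'sgd', '--lr', '0.01',
                          '--epochs', '20', '--save-model', '--early-stopping'])
main(args)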
