# Created December 3, 2020 19:01
# Simple requeueable Slurm job.
import argparse
from pathlib import Path
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.tensorboard
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from numpy import random
import matplotlib
# **kwargs below is used to allow additional unused keywords to be passed to the train function
def train(lr, n_epochs, logdir=None, checkpoint=None, data_root=None, **kwargs):
    """Train the small CIFAR-10 CNN (`Net`), optionally resuming from a checkpoint.

    Args:
        lr: Learning rate passed to SGD (momentum is fixed at 0.9).
        n_epochs: Total number of epochs. A resumed run continues from the epoch stored in the
            checkpoint up to (but excluding) this number.
        logdir: Directory that receives the tensorboard logs, `latest_checkpoint.torch` (rewritten after
            every epoch) and `results.torch` (final test accuracy). Created if missing. If falsy, nothing
            is written to disk.
        checkpoint: Path of a checkpoint to resume from, absolute or relative to `logdir`. If the file
            does not exist (e.g. the job was requeued before the first epoch finished) training starts
            from scratch.
        data_root: Directory containing the already-downloaded CIFAR-10 data (`download=False` is used).
        **kwargs: Ignored; lets callers pass a whole argparse namespace without filtering it first.
    """
    print(f'We are using learning rate {lr}.')
    print(f'We are using n_epochs {n_epochs}.')

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    if data_root is None:
        # Try to use the datasets in /data/localhost/not-backed-up/datasets-ziz-all
        # See [URL with discussion about how we organize datasets, TBC] for details.
        # NOTE(review): the exact CIFAR-10 sub-directory is not visible in this file -- confirm this default.
        data_root = "/data/localhost/not-backed-up/datasets-ziz-all"

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=False, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=10, shuffle=True)
    testset = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=False, transform=transform
    )
    testloader = torch.utils.data.DataLoader(testset, batch_size=10, shuffle=False)

    # The inputs are moved to `device` below, so the model has to live there as well.
    net = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

    writer = None
    if logdir:
        # If logdir doesn't exist we create it.
        # We expect logdir to be an absolute path, somewhere in /data/ziz/not-backed-up/scratch/$USER.
        Path(logdir).mkdir(parents=True, exist_ok=True)
        # logdir is on a network-filesystem (NFS), a drive that is shared from ziz to the compute nodes zizgpu0x,
        # so it's slow to read or write and we don't want to write large files or write too often to that
        # directory.
        # However, it is a convenient place to put tensorboard logs (just the scalar metrics, not images etc.)
        # because that allows us to monitor the progress of all of our experiments, across all compute nodes
        # zizgpu0x, by running just a single tensorboard service on ziz.
        writer = torch.utils.tensorboard.SummaryWriter(log_dir=logdir)

    start_epoch = 0
    if checkpoint:
        # Path(a, b) returns b unchanged when b is absolute, so this handles both relative and absolute paths.
        checkpoint_path = Path(logdir, checkpoint) if logdir else Path(checkpoint)
        if checkpoint_path.is_file():
            state = torch.load(checkpoint_path, map_location=device)
            net.load_state_dict(state['net.state_dict'])
            optimizer.load_state_dict(state['optimizer.state_dict'])
            # if state['epoch']==i that means the last checkpoint was made at the end of epoch i
            # so we restart from i+1
            start_epoch = state['epoch'] + 1
            print(f'Restarting training from epoch {start_epoch} of checkpoint: {checkpoint_path.absolute()}')
        else:
            # A job can be requeued before its first checkpoint has been written.
            print(f'Checkpoint {checkpoint_path.absolute()} not found; starting training from scratch.')

    num_batches = len(trainloader)
    log_every = 1000  # print the averaged loss every `log_every` mini-batches

    for epoch in range(start_epoch, n_epochs):  # loop over the dataset multiple times
        net.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            if writer is not None:
                # Use a step that keeps increasing across epochs (and restarts) so curves don't overlap.
                writer.add_scalar('loss/train', loss_value, epoch * num_batches + i)

            # print statistics
            running_loss += loss_value
            if (i + 1) % log_every == 0:
                print(f"Epoch: {epoch}. Steps: {i + 1}/{num_batches}. Loss: {running_loss / log_every}")
                running_loss = 0.0

        if logdir:
            # Let's save a checkpoint after every epoch. Write to a temporary file first and rename it, so a
            # job that is killed mid-write never leaves a truncated `latest_checkpoint.torch` behind.
            checkpoint_file = Path(logdir, 'latest_checkpoint.torch')
            tmp_file = checkpoint_file.with_name(checkpoint_file.name + '.tmp')
            torch.save(
                {
                    'net.state_dict': net.state_dict(),
                    'optimizer.state_dict': optimizer.state_dict(),
                    'epoch': epoch,
                },
                tmp_file,
            )
            tmp_file.replace(checkpoint_file)

    # At the end of the training evaluate the accuracy on the test set and save it to the logdir/results.
    net.eval()
    correct = 0
    total = 0
    # We don't need to compute gradients evaluating the performance on the test set.
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = torch.argmax(net(inputs), dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy on the test set: %d%%' % (100 * correct / total))

    if logdir:
        # Save the results to logdir, on the central storage.
        torch.save({'accuracy': correct / total}, Path(logdir, 'results.torch'))
    if writer is not None:
        writer.close()

    print("Finished training.")
class Net(nn.Module):
    """Small convolutional classifier producing 10 logits per image.

    The first linear layer expects 16 * 5 * 5 features, which is what the two conv/pool stages produce
    for 3-channel 32x32 inputs (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the (batch, 10) logits for a batch of images `x`."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike `view(-1, 16 * 5 * 5)` this never folds
        # extra spatial positions into the batch dimension, so wrongly sized inputs fail in fc1 instead of
        # silently producing more rows than there are images.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
if __name__ == '__main__':
    # Command-line entry point: parse the hyper-parameters and paths, then run the training.
    parser = argparse.ArgumentParser()
    # Both are required: without them None would reach the optimizer / epoch loop and fail with an obscure error.
    parser.add_argument('--lr', type=float, required=True, help="Learning rate.")
    parser.add_argument('--n_epochs', type=int, required=True, help="Total number of training epochs.")
    parser.add_argument('--checkpoint', type=str, default=None,
                        help="Absolute or relative (wrt to the specified --logdir) path to the checkpoint.")
    # On our cluster this would be intended to be on one of the network-filesystem (NFS) drives from ziz,
    # which is available on all zizgpu0x, so at /data/ziz/not-backed-up/scratch/$USER
    parser.add_argument('--logdir', type=str, default=None,
                        help="Absolute or relative (wrt to working directory) log directory.")

    # parse_known_args tolerates extra flags instead of aborting; they are only reported.
    args, unknown_args = parser.parse_known_args()
    print(f"Unrecognized args: {unknown_args}")

    # Forward every parsed option as a keyword argument; train() ignores the ones it does not use.
    train(**vars(args))
    # Equivalent to
    # train(lr=args.lr, n_epochs=args.n_epochs, checkpoint=args.checkpoint, logdir=args.logdir)
#!/bin/bash
# This script is a working example which allows you to do `scontrol requeue JOBID`.
# This example builds on top of your understanding of the basic Slurm example.
# NOTE(review): the file name of that basic example was lost from this comment -- restore it.
# Usage: `sbatch --output=/data/ziz/not-backed-up/scratch/$USER/slurm-%j.o --error=/data/ziz/not-backed-up/scratch/$USER/slurm-%j.o <path to this script>`

#SBATCH --job-name=simple_requeueable_job
#SBATCH --partition=ziz-gpu
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=1
#SBATCH --time=14-00:00:00
#SBATCH --mem=5G
#SBATCH --ntasks=1
# This allows your job to be requeued.
#SBATCH --requeue
# The setting makes sure that once your job is restarted it doesn't overwrite the --output and --error logs from before
# the restart, but just appends to them.
#SBATCH --open-mode=append

export PATH_TO_CONDA="/data/ziz/not-backed-up/software/ziz_toolkit/miniconda3"
# Activate conda virtual environment
source "$PATH_TO_CONDA/bin/activate" example_environment

LOGDIR="/data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs"
# NOTE(review): the training script's file name was lost from this file; `model.py` is assumed -- confirm.
TRAIN_SCRIPT="/data/ziz/not-backed-up/software/ziz_toolkit/slurm_gpu_examples/example_basic/model.py"

# Just to make sure the directory exists
mkdir -p "$LOGDIR"

# If it's a restart, add "--checkpoint ..." to the python command.
# SLURM_RESTART_COUNT is unset on the first run, hence the :-0 default.
CHECKPOINT_ARGS=()
if [[ "${SLURM_RESTART_COUNT:-0}" -gt 0 ]]; then
    echo "Restarting count: $SLURM_RESTART_COUNT"
    CHECKPOINT_ARGS=(--checkpoint "$LOGDIR/latest_checkpoint.torch")
fi

# Build the command once so that what is echoed to the log is exactly what gets executed.
CMD=(python -u "$TRAIN_SCRIPT"
     --lr 0.02
     --n_epochs 10
     --logdir "$LOGDIR"
     "${CHECKPOINT_ARGS[@]}")
echo "${CMD[*]}"
"${CMD[@]}"

echo "Job completed."
