@talesa
Created December 3, 2020 19:01
simple requeueable slurm job
import argparse
from pathlib import Path
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.tensorboard
import torchvision
import torchvision.transforms as transforms
import matplotlib
matplotlib.use('agg')  # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
from numpy import random
# **kwargs below is used to allow additional, unused keyword arguments to be passed to the train function.
def train(lr, n_epochs, logdir=None, checkpoint=None, **kwargs):
    print(f'We are using learning rate {lr}.')
    print(f'We are using n_epochs {n_epochs}.')

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = torchvision.datasets.CIFAR10(
        # Try to use the datasets in /data/localhost/not-backed-up/datasets-ziz-all.
        # See [URL with discussion about how we organize datasets, TBC] for details.
        root="/data/localhost/not-backed-up/datasets-ziz-all/torchvision/CIFAR10",
        train=True, download=False, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=10, shuffle=True)
    testset = torchvision.datasets.CIFAR10(
        root="/data/localhost/not-backed-up/datasets-ziz-all/torchvision/CIFAR10",
        train=False, download=False, transform=transform
    )
    testloader = torch.utils.data.DataLoader(testset, batch_size=10, shuffle=False)

    net = Net()
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

    if logdir:
        # If logdir doesn't exist, we create it.
        # We expect logdir to be an absolute path, somewhere in /data/ziz/not-backed-up/scratch/$USER.
        Path(logdir).mkdir(parents=True, exist_ok=True)
        # logdir is on a network filesystem (NFS), a drive shared from ziz to the compute nodes zizgpu0x,
        # so it is slow to read and write; we don't want to write large files (e.g. checkpoints) to that
        # directory, or write to it too often.
        # However, it is a convenient place to put tensorboard logs (just the scalar metrics, not images etc.)
        # because it lets us monitor the progress of all of our experiments, across all compute nodes zizgpu0x,
        # by running just a single tensorboard service on ziz.
        writer = torch.utils.tensorboard.SummaryWriter(log_dir=logdir)
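        # For illustration only (a hypothetical invocation, not part of this example): such a service
        # could be started on ziz with something like
        #   tensorboard --logdir /data/ziz/not-backed-up/scratch/$USER --port 6006
        # SummaryWriter writes its event files under logdir, and tensorboard aggregates every run found
        # below the directory passed to --logdir, so one service covers all experiments.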
    if checkpoint:
        # Path(logdir, checkpoint) handles both relative and absolute checkpoint paths:
        # an absolute checkpoint path is used as-is, a relative one is taken relative to logdir.
        checkpoint_path = Path(logdir, checkpoint)
        checkpoint = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint['net.state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer.state_dict'])
        # If checkpoint['epoch'] == i, the last checkpoint was made at the end of epoch i,
        # so we restart from epoch i + 1.
        start_epoch = checkpoint['epoch'] + 1
        print(f'Restarting training from epoch {start_epoch} of checkpoint: {checkpoint_path.absolute()}')
    else:
        start_epoch = 0

    num_batches = len(trainloader)
    for epoch in range(start_epoch, n_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if logdir:
                # Log the value of the loss against the global step so that successive epochs
                # don't overwrite one another in tensorboard.
                writer.add_scalar('loss/train', loss.item(), epoch * num_batches + i)

            # print statistics
            running_loss += loss.item()
            if i % 1000 == 0 and i != 0:  # print every 1000 mini-batches
                print(f"Epoch: {epoch}. Steps: {i}/{num_batches}. Loss: {running_loss / 1000}")
                running_loss = 0.0

        if logdir:
            # Let's save a checkpoint after every epoch.
            torch.save({'net.state_dict': net.state_dict(),
                        'optimizer.state_dict': optimizer.state_dict(),
                        'epoch': epoch},
                       Path(logdir, 'latest_checkpoint.torch'))
    # At the end of training, evaluate the accuracy on the test set and save it to logdir/results.torch.
    correct = 0
    total = 0
    # We don't need to compute gradients when evaluating performance on the test set.
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            predicted = torch.argmax(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on the test set: %d%%' % (100 * correct / total))

    if logdir:
        # Save the results to logdir, on the central storage.
        torch.save({'accuracy': correct / total}, Path(logdir, 'results.torch'))

    print("Finished training.")
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float)
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--checkpoint', type=str, default=None,
                        help="Absolute or relative (wrt the specified --logdir) path to the checkpoint.")
    # On our cluster --logdir is intended to be on one of the network-filesystem (NFS) drives from ziz,
    # which are available on all zizgpu0x nodes, i.e. under /data/ziz/not-backed-up/scratch/$USER.
    parser.add_argument('--logdir', type=str, default=None,
                        help="Absolute or relative (wrt the working directory) log directory.")
    args, unknown_args = parser.parse_known_args()
    print(f"Unrecognized args: {unknown_args}")
    train(**args.__dict__)
    # Equivalent to
    # train(lr=args.lr, n_epochs=args.n_epochs, logdir=args.logdir, checkpoint=args.checkpoint)
    # since args.__dict__ maps each argument name to its parsed value.
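# Illustrative command-line usage of this file (model.py); the values mirror the sbatch script below
# and the paths are examples, not requirements:
#   python model.py --lr 0.02 --n_epochs 10 \
#       --logdir /data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs
# and, to resume, --checkpoint may also be given relative to --logdir:
#   python model.py --lr 0.02 --n_epochs 10 \
#       --logdir /data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs \
#       --checkpoint latest_checkpoint.torch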
#!/bin/bash
# This script is a working example which allows you to do `scontrol requeue JOBID`.
# This example builds on top of your understanding of `single_training.sh`.
# Usage: `sbatch --output=/data/ziz/not-backed-up/scratch/$USER/slurm-%j.o --error=/data/ziz/not-backed-up/scratch/$USER/slurm-%j.o /data/ziz/not-backed-up/software/ziz_toolkit/slurm_gpu_examples/example_basic/simple_requeueable_job.sh`
#SBATCH --job-name=simple_requeueable_job
#SBATCH --partition=ziz-gpu
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=1
#SBATCH --time=14-00:00:00
#SBATCH --mem=5G
#SBATCH --ntasks=1
# THE PART THAT IS NEW IN simple_requeueable_job.sh (compared to single_training.sh)
# This allows your job to be requeued.
#SBATCH --requeue
# This setting makes sure that once your job is restarted, it doesn't overwrite the --output and --error logs
# from before the restart but appends to them instead.
#SBATCH --open-mode=append
export PATH_TO_CONDA="/data/ziz/not-backed-up/software/ziz_toolkit/miniconda3"
# Activate conda virtual environment
source $PATH_TO_CONDA/bin/activate example_environment
# Just to make sure the log directory exists.
mkdir -p /data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs
# If this is a restart, add "--checkpoint ..." to the python command.
if [[ $SLURM_RESTART_COUNT -gt 0 ]]; then
echo "Restarting count: $SLURM_RESTART_COUNT"
export CHECKPOINT="--checkpoint /data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs/latest_checkpoint.torch"
fi
echo "python
/data/ziz/not-backed-up/software/ziz_toolkit/slurm_gpu_examples/example_basic/model.py
--lr 0.02
--n_epochs 10
--logdir /data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs
$CHECKPOINT"
python -u /data/ziz/not-backed-up/software/ziz_toolkit/slurm_gpu_examples/example_basic/model.py \
--lr 0.02 \
--n_epochs 10 \
--logdir /data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs \
$CHECKPOINT
echo "Job completed."