Last active Mar 14, 2019
Example startup script / boot script "user data" for running machine learning experiments on EC2 Spot Instances with git & dvc


  • Write your training script so that it can be killed and will automatically resume from the beginning of the current epoch when restarted. (See the example training loop below, which incorporates these recommendations.)
    • Save checkpoints at every epoch... (See the save_training_state helper function below.)
      • model(s)
      • optimizer(s)
      • any hyperparameter schedules — I usually write the epoch number to a JSON file and compute the hyperparameter schedules as a function of the epoch number.
    • At the beginning of training, check for any saved training checkpoints and load all relevant info (models, optimizers, hyperparameter schedules). (See the load_training_state helper function below.)
    • Consider using smaller epochs by limiting the number of batches pulled from your (shuffled) dataloader during each epoch.
      • This will cause your training to be checkpointed more often, so in the case that your spot instance is shut down, you will limit the amount of lost training.
      • Consider only computing validation metrics every 3-5 epochs when using shorter epochs.
    • Test by running your script, killing it in the middle of the 2nd epoch, and restarting your script with the same command. Verify that it loads the model checkpoint and continues at the beginning of the second epoch.
  • Make sure your training script will not continue training if re-started after training has completed in the past.
    • Your instance will shut down when your training script completes (assuming you follow the rest of these instructions), and then AWS will attempt to restart or relaunch your instance in order to maintain your spot "fleet". When this happens, you want the training script to recognize that training is complete and return quickly, so that your instance shuts down again. AWS will then give up on maintaining your fleet after it fails to start up a few times (i.e. when the instance shuts down right away after starting up).
    • One idea is to make your training schedule based on the epoch number (as mentioned above), and then make sure that your training script will abort if started on the final epoch. If you follow the other examples here, it will already work like that, but if you do it a different way (or use early stopping, for example), you will need to set this up another way.
    • Another idea would be to write a final checkpoint after your training completes, or even just write an empty file training_completed.txt, and check for these files before launching training.
  • Create a machine image (AMI) with your repo and training data on it OR put your code on github and data in S3 so that you can download them on startup.
    • Make sure git status doesn't show any changes which need to be committed — verify that you can checkout a different branch or commit without git complaining.
    • Make sure that any training checkpoints are removed — verify that if you were to restart your training script now, it would start at epoch 0.
    • Make sure that git credentials are available on the system if you need to push to or pull from a remote git repository at any point. Use git config --global credential.helper store or inject your username and developer access token in the remote url, e.g. https://<username>:<token>@github.com/<user>/<repo>.git
    • Make sure that AWS credentials are available on the system (at ~/.aws/credentials) if you need to push to or pull from S3.
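The restart-safe bookkeeping described above (epoch number in a JSON file, hyperparameter schedules computed as a function of the epoch, and a completion marker so a restarted instance exits immediately) can be sketched as follows. The filenames and the linear warmup schedule are illustrative, not prescribed by these notes:

```python
import json
import os

EPOCH_FILE = 'epoch.json'                    # assumed filename
COMPLETED_MARKER = 'training_completed.txt'  # marker file suggested above


def last_completed_epoch(path=EPOCH_FILE):
    """Return the last completed epoch, or -1 if no checkpoint exists yet."""
    if not os.path.exists(path):
        return -1
    with open(path) as f:
        return json.load(f)['completed_epoch']


def save_epoch(i_epoch, path=EPOCH_FILE):
    """Record the epoch that just finished, so a restart can resume from the next one."""
    with open(path, 'w') as f:
        json.dump({'completed_epoch': i_epoch}, f)


def learning_rate(i_epoch, base_lr=1e-3, warmup_epochs=5):
    """Hyperparameter schedule as a pure function of the epoch number,
    so it resumes correctly after a spot interruption."""
    if i_epoch < warmup_epochs:
        return base_lr * (i_epoch + 1) / warmup_epochs
    return base_lr


def training_completed(marker=COMPLETED_MARKER):
    """Guard so a restarted instance returns quickly and shuts down again."""
    return os.path.exists(marker)
```

Because the schedule depends only on the epoch number on disk, killing the process anywhere mid-epoch and rerunning the script reproduces the same hyperparameters for that epoch.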

Launching a spot instance

  • Open the AWS EC2 service console.
  • Go to Spot Requests in the left sidebar menu, then click Request Spot Instances in the header.
  • Select Load balancing workloads
  • Choose your AMI (click Search for AMI).
  • Ignore the Minimum Compute Unit settings.
  • Choose your key pair.
  • Expand the Additional Configurations section.
    • Leave the Delete column checked next to your volume.
    • Check the box for EBS-optimized instances if read/write performance is important to your application.
    • Add security groups for SSH, tensorboard, etc.
    • Add tags such as project=lulc, owner=collin
    • Find the User Data field at the bottom, which is where you enter the startup script for your instance.
      • An example startup script is provided in this gist (see user-data-template.txt and the example bash script below).
      • Start by pasting the contents of user-data-template.txt as plain text.
      • Then scroll to the bottom and replace the contents of the userdata.txt file (starting with #!/bin/bash and up to but not including --//) with your customized bash script, based on the example provided in this gist.
      • Make sure the first line in your script is #!/bin/bash (or alternative).
      • Make sure the last line in your script is sudo shutdown. This ensures that the instance will shut down after your training is completed. When AWS attempts to restart or relaunch your instance, the training script should recognize that training has already completed and shut down again.
      • Make sure the last line in the textarea after your script is --//.
      • This clunky format will ensure that your script is run every time the instance is restarted. If you directly paste in your script without the template from user-data-template.txt, then your User Data script will only run when the instance is launched (for the first time).
  • Target capacity
    • How many spot instances do you want to keep running simultaneously? Typically the answer is 1, unless you're trying to do distributed training across multiple instances.
    • Check the box for Maintain target capacity.
    • Change interruption behavior to Stop for easier debugging via access to boot logs, which will contain any console logs during execution of your User Data script. Also useful in case there were any hiccups in persisting your checkpoints after training.
  • Fleet request settings
    • Uncheck Apply Recommendations
    • Remove all recommended instance types
    • Click Select Instance Types and add your preferred instance type(s).
    • For Fleet allocation strategy, choose Lowest price.
  • Click Launch button 🎉
  • What's next?
    • Monitor your spot requests in the Spot Requests section in the left sidebar of the EC2 console. Expand the spot request to see launched instances, or select it to see logs, status, and other details in the bottom pane.
    • Any instance(s) launched by your spot requests also appear in the Instances section in the EC2 console, but with limited control.
    • Check logs and/or your tensorboard server to make sure training starts successfully.
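The console choices above can also be expressed as a spot fleet request config for the EC2 API (e.g. via boto3's request_spot_fleet). This is a sketch: the AMI ID, key pair name, instance type, and IAM role ARN are placeholders, and the console walkthrough above remains the source of truth:

```python
# Spot fleet request config mirroring the console settings described above.
spot_fleet_config = {
    'IamFleetRole': 'arn:aws:iam::123456789012:role/aws-ec2-spot-fleet-tagging-role',  # placeholder ARN
    'AllocationStrategy': 'lowestPrice',     # "Fleet allocation strategy: Lowest price"
    'TargetCapacity': 1,                     # usually 1, unless doing distributed training
    'Type': 'maintain',                      # "Maintain target capacity"
    'InstanceInterruptionBehavior': 'stop',  # easier debugging via access to boot logs
    'LaunchSpecifications': [{
        'ImageId': 'ami-0123456789abcdef0',  # placeholder AMI
        'InstanceType': 'p2.xlarge',         # your preferred instance type(s)
        'KeyName': 'my-key-pair',            # placeholder key pair
        # 'UserData': '<base64-encoded multipart user data>',
    }],
}

# To actually submit it (requires AWS credentials):
# import boto3
# boto3.client('ec2').request_spot_fleet(SpotFleetRequestConfig=spot_fleet_config)
```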


  • User Data documentation:
  • How can I execute user data to automatically run with every restart of my Amazon EC2 instance?
  • System Logs from User Data
    • First access your system logs:
      • From the EC2 console under Instances, right-click the instance, expand Instance Settings, choose System Logs.
      • SSH into the instance and find logs in /var/log/cloud-init-output.log.
      • From the command-line: `aws ec2 get-console-output --region us-east-1 --instance-id i-123 | python -c 'import sys, json; print(json.load(sys.stdin)["Output"])'`
    • Search them for the text login: which will take you to the console output from your User Data script.
  • Training not running?
    • Check system logs for errors in the startup script.
  • Instances are terminated?
    • Check spot request details for clues about why they were terminated.
    • A bug in your training script would cause the rest of your User Data script to be executed, which usually ends in a sudo shutdown command. So if it looks like the instance starts up then quickly shuts down again, this could be the problem. Check the system logs for error messages from your training script.
import itertools

from tqdm import tqdm
import torch

gpu_if_available = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train_model(
        model, dataloader, criterion, optimizer,
        start_epoch=0, max_epochs=1, max_epoch_batches=None,
        after_epoch=None, device=gpu_if_available):
    for i_epoch in range(start_epoch, max_epochs):
        # Optionally limit the number of batches per epoch so checkpoints happen more often.
        epoch_dataloader = itertools.islice(dataloader, max_epoch_batches) if max_epoch_batches else dataloader
        for i_batch, (X, Y) in enumerate(tqdm(epoch_dataloader, total=max_epoch_batches)):
            Y_pred, loss = train_model_batch(model, X, Y, optimizer, criterion=criterion, device=device)
        if callable(after_epoch):
            after_epoch(i_epoch)


def train_model_batch(model, X, Y, optimizer, criterion, device=gpu_if_available):
    # TODO Perform one forward and backward pass on the given model, batch, optimizer, criterion, and device.
    raise NotImplementedError


if __name__ == '__main__':
    # get_dataloader, get_model, get_optimizer, and get_criterion are project-specific;
    # load_training_state / save_training_state are the checkpoint helpers below.
    dl = get_dataloader()
    m = get_model()
    o = get_optimizer(m)
    l = get_criterion()

    # Default training state
    training_state = {
        'completed_epoch': -1,
    }
    try:
        training_state = load_training_state('training-checkpoint', './checkpoints/', load_modules={
            'model': m,
            'optim': o,
        })
        print('Loaded training state.')
    except FileNotFoundError:
        print('Warning: Failed to load any training state checkpoint. This is correct iff you expected to start a brand new training.')

    next_epoch = training_state['completed_epoch'] + 1

    def after_each_epoch(i_epoch):
        # Save training checkpoint
        save_training_state({
            'completed_epoch': i_epoch,
            'model': m,
            'optim': o,
        }, 'training-checkpoint', './checkpoints/')
        # Evaluate model
        if i_epoch % 3 == 0:
            # TODO evaluate the model.
            # TODO save a checkpoint if it's a nice one.
            pass

    train_model(m, dl, l, o, start_epoch=next_epoch, max_epochs=200, max_epoch_batches=1000, after_epoch=after_each_epoch)
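The TODO in train_model_batch above could be filled in along these lines. This is a sketch of a standard supervised step, not the author's exact code:

```python
import torch

gpu_if_available = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train_model_batch(model, X, Y, optimizer, criterion, device=gpu_if_available):
    """One forward and backward pass on a single batch."""
    model.train()
    model.to(device)
    X, Y = X.to(device), Y.to(device)
    optimizer.zero_grad()
    Y_pred = model(X)
    loss = criterion(Y_pred, Y)
    loss.backward()
    optimizer.step()
    # Detach predictions so callers don't accidentally hold onto the graph.
    return Y_pred.detach(), loss.item()
```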
#!/bin/bash
# Note: You are not logged-in as the usual `ubuntu` user — I believe you start logged-in as the root user.
# Login as user `ubuntu` so that file permissions behave as expected, and set up the typical bash profile.
sudo su ubuntu
export HOME=/home/ubuntu
source ~/.profile
cd ~/path/to/project

# The experiment_name is used (1) to make sure we only checkout code and data one time (on first launch),
# and (2) to avoid any git conflicts if committing and pushing anything back to github after your
# experiment.
# If you only care about (1), then you don't need to change this value across experiments. If you care
# about (2), then experiment_name must be UNIQUE with respect to all existing git branches on github.
experiment_name='testing123'

if ! git status | grep "On branch $experiment_name$"; then
  # This block only runs on first launch, not on restarts.
  # Checkout your experiment code and data (if different from AMI snapshot).
  git pull
  # EITHER checkout a particular branch which contains your experiment code:
  git checkout experiment-branch
  git pull
  # OR checkout a specific commit:
  # git checkout abc123

  # Update project dependencies
  pipenv install
  # Download some datasets or other remote dependencies
  pipenv run dvc pull

  # We are going to create a new branch for running our experiment — you can always merge it into master later.
  git checkout -b "$experiment_name"
  git push -u origin "$experiment_name"
fi

# Run experiment
# python
pipenv run dvc repro train.dvc

# Commit & push experiment results
git commit -am "Training completed: $experiment_name"
git push
pipenv run dvc push
aws s3 sync ./path/to/logdir/ s3://bucketname/shared-tensorboard-logdir/project/$experiment_name/

sudo shutdown
Content-Type: multipart/mixed; boundary="//"
MIME-Version: 1.0

--//
Content-Type: text/cloud-config; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="cloud-config.txt"

#cloud-config
cloud_final_modules:
- [scripts-user, always]

--//
Content-Type: text/x-shellscript; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="userdata.txt"

#!/bin/bash
# Replace everything from #!/bin/bash down to (but not including) the final --// with your own script.
sudo shutdown
--//
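If you'd rather not hand-edit the template, the same multipart document can be generated with Python's standard email package. A sketch (the structure mirrors user-data-template.txt above):

```python
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText


def build_user_data(script: str) -> str:
    """Wrap a shell script in the cloud-init multipart format so it runs on every boot."""
    outer = MIMEMultipart(boundary='//')

    # Part 1: cloud-config telling cloud-init to run user scripts on every boot.
    cloud_config = MIMEText(
        '#cloud-config\ncloud_final_modules:\n- [scripts-user, always]\n',
        'cloud-config')
    cloud_config.add_header('Content-Disposition', 'attachment', filename='cloud-config.txt')
    outer.attach(cloud_config)

    # Part 2: the shell script itself.
    shellscript = MIMEText(script, 'x-shellscript')
    shellscript.add_header('Content-Disposition', 'attachment', filename='userdata.txt')
    outer.attach(shellscript)

    return outer.as_string()
```

Paste the returned string into the User Data field (as plain text), or base64-encode it when submitting via the API.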
import json
import os

import torch

STATE_DICT_KEY = '__state_dict__'


def save_training_state(training_state, checkpoint_basename, checkpoints_dir):
    """Save `training_state` as a JSON checkpoint, with torch objects saved to separate files.

    Args:
        training_state (dict): Arbitrary state information to be saved & loaded as json. Any values which are
                               instances of torch.nn.Module or torch.optim.Optimizer will have their
                               state_dicts saved & loaded to separate files using torch.save.

    Writes multiple files to disk:
        - `{checkpoint_basename}.json` containing the `training_state` dict with any Modules or Optimizers
          replaced by a dict pointing to their separate saved filepaths
        - `{checkpoint_basename}.{training_state_key}.pt` for each Module in `training_state` top-level values
        - `{checkpoint_basename}.{training_state_key}.pt` for each Optimizer in `training_state` top-level values
    """
    checkpoints_dir = os.path.join(os.path.dirname(__file__), checkpoints_dir)
    os.makedirs(checkpoints_dir, exist_ok=True)
    pathto = lambda fname: os.path.join(checkpoints_dir, fname)
    modules = {}
    for k, v in training_state.items():
        if isinstance(v, torch.nn.Module) or isinstance(v, torch.optim.Optimizer):
            filename = f'{checkpoint_basename}.{k}.pt'
            torch.save(v.state_dict(), pathto(filename))
            modules[k] = {STATE_DICT_KEY: filename}
    with open(pathto(f'{checkpoint_basename}.json'), 'w') as f:
        # Replace any Modules/Optimizers with pointers to their saved state_dict files.
        json.dump({**training_state, **modules}, f)


def load_training_state(checkpoint_basename, checkpoints_dir, load_modules=None):
    """Load a training state previously written by save_training_state.

    Args:
        load_modules (dict): If you included any instances of torch.nn.Module or torch.optim.Optimizer when
                             saving the training state, and you include modules here with the same key, their
                             state_dict will be loaded for you and stripped from the returned training_state.
                             Otherwise, any saved Module or Optimizer state_dicts will be returned as part of
                             `training_state` with the same key as they were saved with.

    Reads multiple files from disk:
        - `{checkpoint_basename}.json` containing the `training_state` dict with any Modules or Optimizers
          replaced by a dict pointing to their separate saved filepaths.
        - `{filepath}` for each of `training_state` top-level values which contains a reference (see code
          for format) to a separate filepath containing a saved state_dict.

    Returns: training_state (dict) as previously saved using save_training_state. If any of the original
             values passed to save_training_state were a torch.nn.Module or torch.optim.Optimizer, and
             they are not provided in `load_modules`, they will now be raw state_dicts in the training
             state, which you will need to load using `your_module.load_state_dict(training_state['original_key_when_saved'])`.
    """
    checkpoints_dir = os.path.join(os.path.dirname(__file__), checkpoints_dir)
    pathto = lambda fname: os.path.join(checkpoints_dir, fname)
    with open(pathto(f'{checkpoint_basename}.json'), 'r') as f:
        training_state = json.load(f)
    loaded_modules = []
    for k, v in training_state.items():
        if isinstance(v, dict) and STATE_DICT_KEY in v:
            # Read state dict from disk
            training_state[k] = torch.load(pathto(v[STATE_DICT_KEY]), map_location=lambda storage, location: storage)
            if isinstance(load_modules, dict) and k in load_modules:
                # Load state dict into provided module
                load_modules[k].load_state_dict(training_state[k])
                # Remove state dict from returned training state
                loaded_modules += [k]
    for k in loaded_modules:
        del training_state[k]
    return training_state