PyTorch checkpointing
import logging
import os

import torch

logger = logging.getLogger(__name__)


def restart_from_checkpoint(checkpoint_path, restore_objects=None, **kwargs):
    """
    Restart training or inference from a previous checkpoint.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        restore_objects (dict): Dict mapping variable names to default values;
            entries found in the checkpoint are updated in place.
        **kwargs: Objects with a ``load_state_dict`` method (model, optimizer,
            scheduler, ...) whose states should be reloaded.

    Returns:
        None

    Example:
        # Run once to create a checkpoint; run again to load the checkpoint.
        import torch

        model = torch.nn.Linear(10, 5)
        optimizer = torch.optim.Adam(model.parameters())
        num_epochs = 10
        to_restore = {"epoch": 0}
        # If the checkpoint does not exist, this is a no-op.
        restart_from_checkpoint(
            "checkpoint.pth", restore_objects=to_restore, model=model, optimizer=optimizer
        )
        start_epoch = to_restore["epoch"]
        for epoch in range(start_epoch, num_epochs):
            # Load data, move to GPU, pass through the model, calculate the
            # loss, step the optimizer, etc., then save a checkpoint.
            checkpoint = {
                "epoch": epoch + 1,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(checkpoint, "checkpoint.pth")
    """
    if checkpoint_path is None or not os.path.isfile(checkpoint_path):
        return
    logger.info(f"Found checkpoint at {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Load a state dict for every object passed as a keyword argument.
    for key, model in kwargs.items():
        if key in checkpoint and model is not None:
            try:
                msg = model.load_state_dict(checkpoint[key], strict=False)
                logger.info(
                    f"Loaded '{key}' from checkpoint '{checkpoint_path}' with msg {msg}"
                )
            except TypeError:
                # Optimizers and schedulers do not accept a `strict` argument.
                model.load_state_dict(checkpoint[key])
                logger.info(f"Loaded '{key}' from checkpoint '{checkpoint_path}'")
            except ValueError:
                # e.g. an optimizer whose parameter groups no longer match.
                logger.warning(
                    f"Failed to load '{key}' from checkpoint '{checkpoint_path}'"
                )
        else:
            logger.info(f"Key '{key}' not found in checkpoint '{checkpoint_path}'")
    # Reload plain variables (epoch counters, best metrics, ...) in place.
    if restore_objects is not None:
        for var_name in restore_objects:
            if var_name in checkpoint:
                restore_objects[var_name] = checkpoint[var_name]
                logger.info(f"Loaded '{var_name}' from checkpoint '{checkpoint_path}'")