Skip to content

Instantly share code, notes, and snippets.

@grey-area
Created February 6, 2021 15:03
Show Gist options
  • Save grey-area/a3f6171377d122472d2bb0babfccce28 to your computer and use it in GitHub Desktop.
import torch
from pathlib import Path
import platform
from datetime import datetime
import git
from pip._internal.operations import freeze
def _collect_run_info():
    """Describe when, where and on which host this checkpoint is being written.

    Returns:
        dict with 'date' (local date as YYYY-MM-DD), 'directory' (absolute path
        of the current working directory) and 'hostname' (``platform.node()``).
    """
    return {
        'date': datetime.now().strftime('%Y-%m-%d'),
        'directory': str(Path('.').resolve()),
        'hostname': platform.node(),
    }


def _active_branch_name(repo):
    """Return the name of the checked-out branch, or None on a detached HEAD."""
    try:
        return repo.active_branch.name
    except TypeError:
        # GitPython raises TypeError when HEAD is detached (e.g. a CI job that
        # checked out a specific commit); the commit hash is still recorded.
        return None


def _collect_git_info():
    """Describe the git repository enclosing the working directory, if any.

    Returns:
        dict with 'commit' (HEAD hexsha), 'branch' (name or None if detached),
        'author' ({'name', 'email'} of the HEAD commit) and, when the default
        remote exists, 'remotes' (list of its URLs). Returns None when no
        repository is found or HEAD cannot be resolved (e.g. no commits yet).
    """
    try:
        repo = git.Repo(search_parent_directories=True)
        head_obj = repo.head.object
        git_info = {
            'commit': head_obj.hexsha,
            'branch': _active_branch_name(repo),
            'author': {
                'name': head_obj.author.name,
                'email': head_obj.author.email,
            },
        }
    except (git.exc.InvalidGitRepositoryError, ValueError):
        # Git metadata is best-effort: not being in a repo must not block saving.
        return None

    try:
        # repo.remote() with no argument looks up the default remote
        # ('origin' in GitPython); ValueError means it does not exist.
        git_info['remotes'] = list(repo.remote().urls)
    except ValueError:
        pass
    return git_info


def save_checkpoint(checkpoint_path, model, optimizer, config_dict):
    """Save model/optimizer state together with reproducibility metadata.

    The written dict has keys 'model' and 'optimizer' (their ``state_dict()``)
    and 'metadata', which holds:
        'config': ``config_dict`` as given,
        'run':    date, working directory and hostname,
        'pip':    installed packages as ``pip freeze`` lines,
        'git':    commit/branch/author/remotes (only when inside a git repo).

    Args:
        checkpoint_path: destination passed straight to ``torch.save``.
        model: object exposing ``state_dict()`` (presumably a torch Module).
        optimizer: object exposing ``state_dict()`` (presumably a torch Optimizer).
        config_dict: run configuration stored verbatim in the metadata.
    """
    metadata_dict = {
        'config': config_dict,
        'run': _collect_run_info(),
        'pip': list(freeze.freeze()),
    }
    git_info = _collect_git_info()
    if git_info is not None:
        metadata_dict['git'] = git_info

    checkpoint_dict = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'metadata': metadata_dict,
    }
    torch.save(checkpoint_dict, checkpoint_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment