Skip to content

Instantly share code, notes, and snippets.

@colesbury
Created May 31, 2017 16:45
Show Gist options
  • Save colesbury/5613e009b7abd6975dc1c2198a43d707 to your computer and use it in GitHub Desktop.
Save colesbury/5613e009b7abd6975dc1c2198a43d707 to your computer and use it in GitHub Desktop.
import hashlib
import os
import shutil
import subprocess
import sys
import tempfile

import tinys3
import torch
def _strip_module_components(key):
    """Return *key* with every path component equal to ``module`` removed.

    Wrapper modules such as ``nn.DataParallel`` keep the wrapped network in an
    attribute named ``module``, so keys look like ``module.fc.weight`` or
    ``features.module.0.weight`` (presumably how the input checkpoint was
    produced). Dropping whole components gives ``fc.weight`` /
    ``features.0.weight`` without mangling names that merely contain the
    substring, e.g. ``submodule.weight``.
    """
    return '.'.join(part for part in key.split('.') if part != 'module')


def _sha256_hexdigest(fileobj, chunk_size=1 << 20):
    """Return the lowercase hex SHA-256 of *fileobj*'s full contents.

    Reads from offset 0 in chunks and rewinds to offset 0 afterwards, so the
    caller can immediately re-read (e.g. upload) the same file object.
    """
    fileobj.seek(0)
    digest = hashlib.sha256()
    for chunk in iter(lambda: fileobj.read(chunk_size), b''):
        digest.update(chunk)
    fileobj.seek(0)
    return digest.hexdigest()


def main(argv=None):
    """Convert a training checkpoint to a CPU-only state dict and upload it.

    Usage: ``<script> CHECKPOINT``

    Steps:
      1. Check ``S3_ACCESS_KEY`` / ``S3_SECRET_KEY`` in the environment.
      2. Load CHECKPOINT and print its ``epoch`` entry.
      3. Remove ``module`` path components from every ``state_dict`` key and
         move every value to the CPU.
      4. Save the result to a temporary file.
      5. Upload it publicly, passing ``'pytorch'`` as the bucket, under
         ``models/<checkpoint basename>-<first 8 hex chars of sha256>.pth``.

    Args:
        argv: argument vector including the program name; defaults to
            ``sys.argv``.

    Raises:
        SystemExit: with a message when CHECKPOINT is not given or a
            credential variable is unset. Both checks run before the
            (potentially slow) checkpoint load.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        prog = os.path.basename(argv[0]) if argv else 'upload'
        sys.exit('usage: {} CHECKPOINT'.format(prog))

    credential_names = ('S3_ACCESS_KEY', 'S3_SECRET_KEY')
    missing = [name for name in credential_names if not os.environ.get(name)]
    if missing:
        sys.exit('error: environment variable(s) not set: {}'.format(
            ', '.join(missing)))

    inp = argv[1]
    # map_location='cpu' lets a checkpoint saved from CUDA tensors load on a
    # machine without a GPU; the tensors end up on the CPU regardless.
    checkpoint = torch.load(inp, map_location='cpu')
    print('epoch', checkpoint['epoch'])
    state_dict = checkpoint['state_dict']

    # type(state_dict)() keeps the original mapping type, e.g. an
    # OrderedDict, so key order is preserved in the saved file.
    converted = type(state_dict)()
    for key, value in state_dict.items():
        converted[_strip_module_components(key)] = value.cpu()

    conn = tinys3.Connection(os.environ['S3_ACCESS_KEY'],
                             os.environ['S3_SECRET_KEY'], tls=True)
    with tempfile.NamedTemporaryFile() as f:
        torch.save(converted, f)
        f.flush()
        # Hash through the open handle (no external `sha256sum`, no reopening
        # by name); the helper rewinds so upload reads from the start.
        digest = _sha256_hexdigest(f)
        base_name = os.path.splitext(os.path.basename(inp))[0]
        name = '{}-{}.pth'.format(base_name, digest[:8])
        key = 'models/{}'.format(name)
        print('Uploading to', key)
        conn.upload(key, f, 'pytorch', public=True)
if __name__ == '__main__':
    # Run the conversion/upload only when executed as a script, never on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment