@mkroutikov
Created April 21, 2019 22:25
Save and load PyTorch models to and from cloud storage locations (requires TensorFlow).
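```python
'''
Utilities to save/load torch models.
Same as the built-in torch.save/torch.load, but supports cloud URLs
(any path TensorFlow's gfile understands, e.g. gs://bucket/model.pt).
'''
import io

import tensorflow as tf
import torch


def save(obj, fname):
    # Serialize into an in-memory buffer first, then write the bytes out via gfile.
    memfile = io.BytesIO()
    torch.save(obj, memfile)
    # tf.io.gfile.GFile is the TF 1.14+/2.x name; older TF 1.x exposed tf.gfile.Open.
    with tf.io.gfile.GFile(fname, 'wb') as f:
        f.write(memfile.getvalue())


def load(fname, map_location=None):
    # Read the whole file into memory so torch.load gets a seekable buffer.
    with tf.io.gfile.GFile(fname, 'rb') as f:
        memfile = io.BytesIO(f.read())
    return torch.load(memfile, map_location=map_location)
```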
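The same buffer-then-write pattern can be exercised without TensorFlow. Below is a minimal sketch, assuming a hypothetical `opener` parameter: `save_via` and `load_via` are illustrative names, not part of this gist, and `opener` is any callable with the builtin `open(path, mode)` signature. With the default `open` it round-trips a checkpoint through a local temporary file; passing TensorFlow's gfile opener instead should give the cloud-URL behaviour described above.

```python
import io
import os
import tempfile

import torch


def save_via(obj, fname, opener=open):
    # Hypothetical helper: serialize with torch.save into memory,
    # then write the raw bytes through whatever file opener is supplied.
    buf = io.BytesIO()
    torch.save(obj, buf)
    with opener(fname, 'wb') as f:
        f.write(buf.getvalue())


def load_via(fname, opener=open, map_location=None):
    # Hypothetical helper: read all bytes via the opener, then hand
    # torch.load a seekable in-memory buffer.
    with opener(fname, 'rb') as f:
        buf = io.BytesIO(f.read())
    return torch.load(buf, map_location=map_location)


if __name__ == '__main__':
    state = {'weight': torch.arange(6.0).reshape(2, 3), 'step': 10}
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, 'ckpt.pt')
        save_via(state, path)
        restored = load_via(path, map_location='cpu')
    print(torch.equal(restored['weight'], state['weight']), restored['step'])
```

Keeping the opener pluggable separates the serialization step (always `torch.save`/`torch.load` on a `BytesIO`) from the storage backend, which is the only part that needs TensorFlow in the gist.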