Created April 21, 2019 22:25
Gist mkroutikov/745d5f2f2139bb7caff2224c87957ae5
Save and load PyTorch models to and from cloud locations (requires TensorFlow)
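'''
Utilities to save/load torch models.
Same as the built-in torch.save/torch.load, but they also accept cloud URLs
(any path tf.io.gfile can open, e.g. gs:// paths).
'''
import io

import tensorflow as tf  # tf.gfile was removed in TF 2.x; tf.io.gfile replaces it
import torch


def save(obj, fname):
    # Serialize into an in-memory buffer first, then copy the bytes to the
    # (possibly remote) destination in a single write.
    memfile = io.BytesIO()
    torch.save(obj, memfile)
    with tf.io.gfile.GFile(fname, 'wb') as f:
        f.write(memfile.getvalue())


def load(fname, map_location=None):
    # Read the whole (possibly remote) file into memory, then deserialize.
    with tf.io.gfile.GFile(fname, 'rb') as f:
        memfile = io.BytesIO(f.read())
    return torch.load(memfile, map_location=map_location)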
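Both functions follow the same pattern: serialize into an in-memory `io.BytesIO` buffer, then move the bytes to or from the destination in one call, so the file handle only ever sees a single sequential write or read. The sketch below shows that pattern using only the standard library. It assumes `pickle` as a stand-in for `torch.save`/`torch.load` and the builtin `open` as a stand-in for the TensorFlow file opener. The names `save_via_buffer`, `load_via_buffer`, and the `opener` parameter are hypothetical, added for illustration.

```python
import io
import os
import pickle
import tempfile
from typing import Any, Callable

# Hypothetical opener type: any callable shaped like builtin open(path, mode),
# returning a file-like object usable as a context manager.
Opener = Callable[[str, str], Any]


def save_via_buffer(obj: Any, fname: str, opener: Opener = open) -> None:
    """Serialize obj into memory, then write all bytes to fname in one call."""
    memfile = io.BytesIO()
    pickle.dump(obj, memfile)  # stand-in for torch.save(obj, memfile)
    with opener(fname, 'wb') as f:
        f.write(memfile.getvalue())


def load_via_buffer(fname: str, opener: Opener = open) -> Any:
    """Read all bytes from fname in one call, then deserialize from memory."""
    with opener(fname, 'rb') as f:
        memfile = io.BytesIO(f.read())
    return pickle.load(memfile)  # stand-in for torch.load(memfile)


if __name__ == '__main__':
    state = {'weights': [0.5, -1.25], 'step': 10}
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'model.pkl')
        save_via_buffer(state, path)
        restored = load_via_buffer(path)
    print(restored == state)  # True
```

Passing a cloud-capable opener and using the torch calls gives back functions equivalent to `save`/`load` above. The trade-off of this design is that the full serialized object is held in memory once during each save and load.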