@akaanirban
Created January 24, 2021 16:16
How to read a pickled collection (list, dictionary, etc.) of PyTorch CUDA tensors on a CPU

What if you saved some loss or accuracy values as a list of PyTorch tensors on a system with CUDA, and then want to plot them on a system with no GPU?

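With some googling, I found that the following code from a comment on pytorch/pytorch#16797 works fine! Each pickled tensor rebuilds its storage by calling `torch.storage._load_from_bytes`, which runs `torch.load` without a `map_location` and therefore tries to put the data back on a CUDA device. You just need to define a custom unpickler that intercepts that call and loads onto the CPU instead, and then use it in place of `pickle.load`!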

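```python
import io
import pickle

import torch


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that loads every torch storage onto the CPU."""

    def find_class(self, module, name):
        # Pickled tensors rebuild their storage via
        # torch.storage._load_from_bytes; swap in a loader that
        # maps CUDA storages to the CPU instead.
        if module == "torch.storage" and name == "_load_from_bytes":
            return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
        # Resolve every other global as usual.
        return super().find_class(module, name)


if __name__ == "__main__":
    file_path = "/something/something.pkl"
    # Instead of pickle.load(file_obj), use the custom unpickler.
    with open(file_path, "rb") as file_obj:
        unpickled_stuff = CPU_Unpickler(file_obj).load()
```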
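The trick relies on `pickle.Unpickler.find_class`, which pickle calls for each global (module, name) it needs while loading. Below is a minimal, stdlib-only sketch of the same pattern that needs no PyTorch or GPU. It assumes a hypothetical module `gpu_only` whose `load_on_gpu` function stands in for `torch.storage._load_from_bytes`; `FakeTensor` and `CPURedirectUnpickler` are likewise illustrative names, not part of any library. The module is removed from `sys.modules` after saving to imitate a machine where the original loader is unavailable.

```python
import io
import pickle
import sys
import types

# Hypothetical stand-in for a module that only works on the GPU machine.
gpu_only = types.ModuleType("gpu_only")


def load_on_gpu(payload):
    """Hypothetical loader that needs hardware we pretend not to have."""
    raise RuntimeError("CUDA is not available")


# Make pickle record this function as the global "gpu_only.load_on_gpu".
load_on_gpu.__module__ = "gpu_only"
gpu_only.load_on_gpu = load_on_gpu
sys.modules["gpu_only"] = gpu_only


class FakeTensor:
    """Toy object that, like a torch storage, pickles as a call to a loader."""

    def __init__(self, payload):
        self.payload = payload

    def __reduce__(self):
        # On load, pickle will call load_on_gpu(self.payload).
        return (load_on_gpu, (self.payload,))


class CPURedirectUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Swap in a CPU-friendly loader for the one global we care about.
        if module == "gpu_only" and name == "load_on_gpu":
            return lambda payload: ("cpu", payload)
        # Resolve every other global normally.
        return super().find_class(module, name)


# "Save" a list of fake tensors, then simulate moving to a machine
# where the gpu_only module does not exist.
data = pickle.dumps([FakeTensor(b"loss-1"), FakeTensor(b"loss-2")])
del sys.modules["gpu_only"]

try:
    pickle.loads(data)
except ImportError as exc:
    # prints: plain pickle.loads failed: ModuleNotFoundError
    print("plain pickle.loads failed:", type(exc).__name__)

restored = CPURedirectUnpickler(io.BytesIO(data)).load()
print(restored)  # prints: [('cpu', b'loss-1'), ('cpu', b'loss-2')]
```

The gist's `CPU_Unpickler` works the same way, except that its replacement still calls `torch.load`, just with `map_location="cpu"`, so you get real CPU tensors back instead of a placeholder tuple.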