Skip to content

Instantly share code, notes, and snippets.

@maharjun
Created May 4, 2023 06:44
Show Gist options
  • Save maharjun/48b4c583572ecaf655a52fd56b420f9b to your computer and use it in GitHub Desktop.
Save maharjun/48b4c583572ecaf655a52fd56b420f9b to your computer and use it in GitHub Desktop.
Torch Dill Shim
"""
This is a shim for dill to be used with torch (namely that when used in a project
that pickles torch objects, dill should be imported from this module).
For example::
from utils.dillshim import dill
The purpose of this shim is to register pickling and unpickling logic
for certain native PyTorch types, such as torch random generators, that
otherwise cannot be pickled by dill, and to make it possible to unpickle
objects that were created on a different device (see device_unpickler).
"""
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2023, maharjun
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
###############################################################################
import torch
import io
import dill
def get_device_name(device: torch.device) -> str:
    """
    Get the string name of a torch device object.

    The returned name is what this module passes to
    ``torch.Generator(device=...)`` when rebuilding pickled generators.

    Parameters
    ----------
    device : torch.device
        The device whose name is requested.

    Returns
    -------
    str
        ``'<type>:<index>'`` when the device carries an explicit index
        (including index 0, e.g. ``'cuda:0'``), otherwise just ``'<type>'``
        (e.g. ``'cpu'`` or ``'cuda'``).

    Raises
    ------
    AssertionError
        If ``device`` is not a ``torch.device``.
    """
    assert isinstance(device, torch.device), "device must be a torch.device"
    # Compare against None rather than relying on truthiness: index 0 is a
    # valid explicit index, and dropping it would turn 'cuda:0' into the
    # index-less 'cuda', which torch interprets as "the current device".
    if device.index is not None:
        return f'{device.type}:{device.index}'
    return device.type
def _recreate_generator(gen_state: torch.Tensor, gen_device):
return_gen: torch.Generator = torch.Generator(device=gen_device)
return_gen.set_state(gen_state)
return return_gen
@dill.register(torch.Generator)
def _save_generator(pickler, gen):
    """
    dill reducer for ``torch.Generator`` objects.

    Records the generator's RNG state together with its device name and
    instructs the pickler to rebuild the object on load by calling
    ``_recreate_generator(state, device_name)``.

    Parameters
    ----------
    pickler
        The active dill pickler.
    gen : torch.Generator
        The generator being pickled.
    """
    reduce_args = (gen.get_state(), get_device_name(gen.device))
    return pickler.save_reduce(_recreate_generator, reduce_args, obj=gen)
class device_unpickler(dill.Unpickler):
    """
    Extension of the dill unpickler that unpickles tensors onto the device
    specified in the member variable ``device``.

    When ``device`` is ``None`` (the default) this behaves exactly like
    ``dill.Unpickler``. Otherwise every reference to
    ``torch.storage._load_from_bytes`` found in the pickle stream is
    redirected to a loader that calls ``torch.load`` with ``map_location``
    set to ``device``.

    ``device`` may be a ``torch.device`` or anything ``torch.device``
    accepts (e.g. the string ``'cpu'`` or ``'cuda:1'``).

    Like any pickle/dill unpickler, loading can execute arbitrary code
    embedded in the stream; only load files from trusted sources.

    Examples
    --------
    One can set the device in the class member `device` and unpickle a file as below::

        from utils.generic.dillshim import device_unpickler
        device_unpickler.device = torch.device('cpu')
        with open('pickle_file.p', 'rb') as fin:
            values = device_unpickler(fin).load()

    One may also set the device for each instance of the device_unpickler as follows::

        from utils.generic.dillshim import device_unpickler
        with open('pickle_file.p', 'rb') as fin:
            unpickler = device_unpickler(fin)
            unpickler.device = torch.device('cpu')
            values = unpickler.load()
    """
    # Target device for unpickled storages; None disables remapping. Setting
    # it on the class affects all instances, setting it on an instance
    # overrides the class value for that instance only.
    device = None

    def find_class(self, module, name):
        """
        Resolve a global referenced by the pickle stream.

        Intercepts ``torch.storage._load_from_bytes`` when a target device is
        set; every other lookup is delegated to ``dill.Unpickler``.
        """
        if self.device is None or (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)

        # Normalize once: accepts a torch.device or a string, and preserves an
        # explicit index such as 0 (so 'cuda:0' is not widened to 'cuda').
        target = torch.device(self.device)

        def _load_from_bytes(b):
            # Stand-in for torch.storage._load_from_bytes that places the
            # storage on `target` instead of the device it was saved from.
            # NOTE(review): recent torch versions pass weights_only=False in
            # the upstream helper and changed torch.load's default; confirm
            # the default is suitable for the torch version in use.
            return torch.load(io.BytesIO(b), map_location=target)

        return _load_from_bytes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment