Skip to content

Instantly share code, notes, and snippets.

@Lunderberg
Created June 28, 2021 21:56
Show Gist options
  • Save Lunderberg/039f2f1d14c9ad7de88e1d2224a430cc to your computer and use it in GitHub Desktop.
Save Lunderberg/039f2f1d14c9ad7de88e1d2224a430cc to your computer and use it in GitHub Desktop.
Deep Copy, but only allow classes designed to work with deepcopy
#!/usr/bin/env python3
import copy
import copyreg
import ctypes
import functools
# First implementation, failed because copy implementations
# (e.g. copy._deepcopy_list) already have a reference to
# copy.deepcopy, and use that instead of our patched version.
def _monkeypatch_deepcopy():
allowed_classes = []
orig = copy.deepcopy
@functools.wraps(orig)
def wrapper(x, *args, **kwargs):
print("Deepcopying", x)
cls = type(x)
if (
cls in copy._deepcopy_dispatch
or issubclass(cls, type)
or getattr(x, "__deepcopy__", None)
or copyreg.dispatch_table.get(cls)
or cls.__reduce__ is not object.__reduce__
or cls.__reduce_ex__ is not object.__reduce_ex__
or cls in allowed_classes
):
return orig(x, *args, **kwargs)
raise TypeError(
(
"Cannot copy fixture of type {}. TVM fixture caching "
"is limited to objects that explicitly provide the ability "
"to be copied (e.g. through __deepcopy__, __getstate__, or __setstate__),"
"and forbids the use of the default `object.__reduce__` and "
"`object.__reduce_ex__`. For classes that are known to be "
"safe to use with copy.deepcopy, please add the class to the "
"`allowed_classes` list in tvm.testing._monkeypatch_deepcopy"
).format(cls.__name__)
)
copy.deepcopy = wrapper
class _DeepCopyAllowedClasses(dict):
def __init__(self, *allowed_class_list):
self.allowed_class_list = allowed_class_list
super().__init__()
def get(self, key, *args, **kwargs):
obj = ctypes.cast(key, ctypes.py_object).value
cls = type(obj)
if (
cls in copy._deepcopy_dispatch
or issubclass(cls, type)
or getattr(obj, "__deepcopy__", None)
or copyreg.dispatch_table.get(cls)
or cls.__reduce__ is not object.__reduce__
or cls.__reduce_ex__ is not object.__reduce_ex__
or cls in self.allowed_class_list
):
return super().get(key, *args, **kwargs)
raise TypeError(
(
"Cannot copy fixture of type {}. TVM fixture caching "
"is limited to objects that explicitly provide the ability "
"to be copied (e.g. through __deepcopy__, __getstate__, or __setstate__),"
"and forbids the use of the default `object.__reduce__` and "
"`object.__reduce_ex__`. For classes that are known to be "
"safe to use with copy.deepcopy, please add the class to "
"`allowed_class_list` in tvm.testing._fixture_cache"
).format(cls.__name__)
)
class Point:
    """Minimal class holding an ``x`` and a ``y`` attribute.

    It deliberately defines no copy hooks (``__deepcopy__``, ``__reduce__``,
    ``__getstate__``, ...), so deep-copying it relies on the default
    ``object`` machinery.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y
def main():
    """Demonstrate plain vs. restricted deep copies of a list holding a Point.

    1. Plain ``copy.deepcopy``: succeeds.
    2. Memo that allow-lists ``Point``: succeeds.
    3. Memo with an empty allow-list: expected to raise ``TypeError``,
       which propagates out of this function.
    """
    sample = [Point(1, 2)]
    print("Deep-copy 1: ", copy.deepcopy(sample))

    permissive_memo = _DeepCopyAllowedClasses(Point)
    print("Deep-copy 2: ", copy.deepcopy(sample, permissive_memo))

    strict_memo = _DeepCopyAllowedClasses()
    print("Deep-copy 3: ", copy.deepcopy(sample, strict_memo))
if __name__ == "__main__":
    # Run the demonstration only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment