Skip to content

Instantly share code, notes, and snippets.

@DavidCEllis
Last active March 20, 2023 16:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DavidCEllis/7f0b3aacea62cded147174b63ee04233 to your computer and use it in GitHub Desktop.
Save DavidCEllis/7f0b3aacea62cded147174b63ee04233 to your computer and use it in GitHub Desktop.
Modification of the dataclasses `asdict` and `astuple` functions that skips deep-copying values for which `deepcopy(obj) is obj`.
import copy
import types
from dataclasses import _is_dataclass_instance, fields
_ATOMIC_TYPES = {
types.NoneType,
types.EllipsisType,
types.NotImplementedType,
int,
float,
bool,
complex,
bytes,
str,
types.CodeType,
type,
range,
types.BuiltinFunctionType,
types.FunctionType,
# weakref.ref, # weakref is not currently imported by dataclasses directly
property,
}
def asdict(obj, *, dict_factory=dict):
    """Convert a dataclass instance into a new mapping of field name -> value.

    The conversion is recursive. Field values that are dataclass instances
    are converted as well. Built-in containers (namedtuples, tuples, lists
    and dicts) are rebuilt with their members converted. Any other value is
    deep-copied, except values whose exact type deepcopy is known to return
    unchanged; those are reused directly.

    The resulting mapping is built by *dict_factory* (the built-in dict
    unless another factory is given).

    Example::

        @dataclass
        class C:
            x: int
            y: int

        c = C(1, 2)
        assert asdict(c) == {'x': 1, 'y': 2}

    Raises TypeError when *obj* is not a dataclass instance (a dataclass
    class itself is rejected too).
    """
    if _is_dataclass_instance(obj):
        return _asdict_inner(obj, dict_factory)
    raise TypeError("asdict() should be called on dataclass instances")
def _asdict_inner(obj, dict_factory):
    """Recursive worker behind asdict().

    - dataclass instance -> dict_factory([(field_name, converted), ...]),
      or a dict filled directly when dict_factory is the built-in dict;
    - namedtuple -> the same namedtuple type with its members converted;
    - list / tuple (or subclass) -> the same type with members converted;
    - dict (or subclass) -> the same type with keys and values converted;
    - anything else -> copy.deepcopy(obj).

    A member whose exact type is in _ATOMIC_TYPES is reused as-is instead of
    being sent back through this function: deepcopy would return that very
    object anyway, so the call is pure overhead.
    """
    if _is_dataclass_instance(obj):
        if dict_factory is not dict:
            # Custom factory: hand it a list of (name, value) pairs.
            pairs = []
            for fld in fields(obj):
                val = getattr(obj, fld.name)
                converted = (
                    val if type(val) in _ATOMIC_TYPES
                    else _asdict_inner(val, dict_factory)
                )
                pairs.append((fld.name, converted))
            return dict_factory(pairs)
        # Default factory: fill the dict directly and skip building an
        # intermediate list of pairs.
        out = {}
        for fld in fields(obj):
            val = getattr(obj, fld.name)
            if type(val) not in _ATOMIC_TYPES:
                val = _asdict_inner(val, dict_factory)
            out[fld.name] = val
        return out

    if isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # namedtuple: recurse into it but keep the same namedtuple type
        # rather than producing a dict. Its constructor takes the members
        # positionally, so they are unpacked (see bpo-34363).
        # namedtuple._asdict() is deliberately not used, for two reasons:
        # - it neither recurses into the members nor honours dict_factory;
        # - keeping a namedtuple means asdict() still works on structures
        #   that use a namedtuple as a dict key.
        # json.dumps renders namedtuples as lists, which loses the field
        # names but is the accepted trade-off.
        members = [
            m if type(m) in _ATOMIC_TYPES else _asdict_inner(m, dict_factory)
            for m in obj
        ]
        return type(obj)(*members)

    if isinstance(obj, (list, tuple)):
        # Plain list/tuple (or subclass): assume the constructor accepts a
        # generator, which namedtuples (handled above) do not.
        return type(obj)(
            m if type(m) in _ATOMIC_TYPES else _asdict_inner(m, dict_factory)
            for m in obj
        )

    if isinstance(obj, dict):
        return type(obj)(
            (
                key if type(key) in _ATOMIC_TYPES else _asdict_inner(key, dict_factory),
                val if type(val) in _ATOMIC_TYPES else _asdict_inner(val, dict_factory),
            )
            for key, val in obj.items()
        )

    return copy.deepcopy(obj)
def astuple(obj, *, tuple_factory=tuple):
    """Convert a dataclass instance into a new tuple of its field values.

    The conversion is recursive. Field values that are dataclass instances
    are converted as well. Built-in containers (namedtuples, tuples, lists
    and dicts) are rebuilt with their members converted. Any other value is
    deep-copied, except values whose exact type deepcopy is known to return
    unchanged; those are reused directly.

    The outer sequence is built by *tuple_factory* (the built-in tuple
    unless another factory is given).

    Example::

        @dataclass
        class C:
            x: int
            y: int

        c = C(1, 2)
        assert astuple(c) == (1, 2)

    Raises TypeError when *obj* is not a dataclass instance (a dataclass
    class itself is rejected too).
    """
    if _is_dataclass_instance(obj):
        return _astuple_inner(obj, tuple_factory)
    raise TypeError("astuple() should be called on dataclass instances")
def _astuple_inner(obj, tuple_factory):
    """Recursive worker behind astuple().

    - dataclass instance -> tuple_factory([converted field values, ...]);
    - namedtuple -> the same namedtuple type with its members converted;
    - list / tuple (or subclass) -> the same type with members converted;
    - dict (or subclass) -> the same type with keys and values converted;
    - anything else -> copy.deepcopy(obj).

    A member whose exact type is in _ATOMIC_TYPES is reused as-is instead of
    being sent back through this function: deepcopy would return that very
    object anyway, so the call is pure overhead.
    """
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = getattr(obj, f.name)
            if type(value) not in _ATOMIC_TYPES:
                value = _astuple_inner(value, tuple_factory)
            result.append(value)
        return tuple_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # obj is a namedtuple. Recurse into it, but the returned
        # object is another namedtuple of the same type. This is
        # similar to how other list- or tuple-derived classes are
        # treated (see below), but we just need to create them
        # differently because a namedtuple's __init__ needs to be
        # called differently (see bpo-34363).
        return type(obj)(*[
            v if type(v) in _ATOMIC_TYPES else _astuple_inner(v, tuple_factory)
            for v in obj
        ])
    elif isinstance(obj, (list, tuple)):
        # Assume we can create an object of this type by passing in a
        # generator (which is not true for namedtuples, handled
        # above).
        # Check type(v), not v: `v in _ATOMIC_TYPES` asks whether the value
        # itself is one of the types. That never holds for ordinary values,
        # so the shortcut would never apply, and it raises TypeError for
        # unhashable members such as nested lists or dicts.
        return type(obj)(
            v if type(v) in _ATOMIC_TYPES else _astuple_inner(v, tuple_factory)
            for v in obj
        )
    elif isinstance(obj, dict):
        # Same type(...) check as above, for both keys and values.
        return type(obj)(
            (k if type(k) in _ATOMIC_TYPES else _astuple_inner(k, tuple_factory),
             v if type(v) in _ATOMIC_TYPES else _astuple_inner(v, tuple_factory))
            for k, v in obj.items())
    else:
        return copy.deepcopy(obj)
def norecurse_asdict(obj):
    """Shallow variant of asdict(): map each field name to its current value.

    Nothing is copied or recursed into; the returned dict holds the very
    objects stored on *obj*. It serves as a lower-bound timing baseline.

    Raises TypeError when *obj* is not a dataclass instance.
    """
    if not _is_dataclass_instance(obj):
        raise TypeError("asdict() should be called on dataclass instances")
    return {fld.name: getattr(obj, fld.name) for fld in fields(obj)}
def atomic_specialcase_test():
    """Benchmark this module's asdict()/astuple() against the stdlib versions.

    Two cases are timed and printed to stdout:

    * best case: every field value has an atomic type, so nothing needs
      deepcopy. The shallow norecurse_asdict() baseline is also timed here.
    * worst case: every field value is a plain-class instance that has to be
      deep-copied.

    Results are compared with dataclasses.asdict()/astuple() before any
    timing runs, so an incorrect implementation fails immediately instead of
    after the benchmarks.
    """
    import dataclasses
    from timeit import timeit

    def report(label, new_func, current_func, iterations, **baselines):
        # Time the new and stdlib implementations, plus any extra named
        # baselines in order, and print them in one shared format.
        new = timeit(new_func, number=iterations)
        current = timeit(current_func, number=iterations)
        baseline_times = [
            (name, timeit(func, number=iterations))
            for name, func in baselines.items()
        ]
        print(f"{label}:\n{current=:.2f}s\n{new=:.2f}s")
        for name, seconds in baseline_times:
            print(f"{name}={seconds:.2f}s")
        print(f"New method takes {(new / current) * 100:.0f}% of the time\n")

    # Best case - everything skips deepcopy
    @dataclasses.dataclass
    class AtomicExample:
        p: str = "usr/bin/python"
        major: int = 3
        minor: int = 11
        installed: bool = True

    # Worst case - everything has to be deepcopied
    class PyVer:
        def __init__(self, major=3, minor=11):
            self.major = major
            self.minor = minor

        def __hash__(self):
            return hash((self.major, self.minor))

        def __eq__(self, other):
            if self.__class__ == other.__class__:
                return (self.major, self.minor) == (other.major, other.minor)
            return NotImplemented

    @dataclasses.dataclass
    class WorstExample:
        v311: PyVer = PyVer()
        v310: PyVer = PyVer(3, 10)
        v309: PyVer = PyVer(3, 9)
        v27: PyVer = PyVer(2, 7)

    atomic_ex = AtomicExample()
    non_atomic_ex = WorstExample()

    # Verify correctness before spending time on the benchmarks.
    assert asdict(atomic_ex) == dataclasses.asdict(atomic_ex)
    assert asdict(non_atomic_ex) == dataclasses.asdict(non_atomic_ex)
    assert astuple(atomic_ex) == dataclasses.astuple(atomic_ex)
    assert astuple(non_atomic_ex) == dataclasses.astuple(non_atomic_ex)

    best_iterations = 500_000
    report(
        "Best case asdict",
        lambda: asdict(atomic_ex),
        lambda: dataclasses.asdict(atomic_ex),
        best_iterations,
        norecurse=lambda: norecurse_asdict(atomic_ex),
    )
    report(
        "Best case astuple",
        lambda: astuple(atomic_ex),
        lambda: dataclasses.astuple(atomic_ex),
        best_iterations,
    )

    # The worst case is much slower, so use 10x fewer iterations.
    worst_iterations = 50_000
    report(
        "Worst case asdict",
        lambda: asdict(non_atomic_ex),
        lambda: dataclasses.asdict(non_atomic_ex),
        worst_iterations,
    )
    report(
        "Worst case astuple",
        lambda: astuple(non_atomic_ex),
        lambda: dataclasses.astuple(non_atomic_ex),
        worst_iterations,
    )
def bigtest():
    """Benchmark asdict() against dataclasses.asdict() on a large nested structure.

    This extends the benchmark ORJSON uses to show how slow json is for
    dataclasses converted through asdict.

    An ObjectHolder maps 1000 string keys to Object dataclasses. Each Object
    mixes atomic fields, a list of 10 Member dataclasses, a plain-class Spam
    instance (which falls through to deepcopy) and a nested Eggs dataclass.

    Both implementations are checked for equal results first; the timings
    are then printed to stdout.
    """
    import dataclasses
    from timeit import timeit

    ITERATIONS = 50

    class Spam:
        def __init__(self, x, y):
            self.x = x
            self.y = y

        def __repr__(self) -> str:
            return f"Spam({self.x!r}, {self.y!r})"

        def __eq__(self, other):
            if self.__class__ == other.__class__:
                return (self.x, self.y) == (other.x, other.y)
            return NotImplemented

        def __hash__(self):
            return hash((self.x, self.y))

    @dataclasses.dataclass
    class Eggs:
        x: int
        y: str

    @dataclasses.dataclass
    class Member:
        id: int
        active: bool

    @dataclasses.dataclass
    class Object:
        id: int
        name: str
        members: list[Member]
        spam: Spam
        eggs: Eggs

    @dataclasses.dataclass
    class ObjectHolder:
        objects: dict[str, Object]

    objects_as_dataclass = {
        f"object_{i}": Object(
            i,
            str(i) * 3,
            [Member(j, True) for j in range(0, 10)],
            Spam(i, str(i)),
            Eggs(i, str(i)),
        )
        for i in range(100000, 101000)
    }
    objectholder = ObjectHolder(objects_as_dataclass)

    # Verify correctness before timing.
    assert asdict(objectholder) == dataclasses.asdict(objectholder)

    current = timeit(
        lambda: dataclasses.asdict(objectholder),
        number=ITERATIONS
    )
    new = timeit(
        lambda: asdict(objectholder),
        number=ITERATIONS
    )

    print("Mixed case asdict:")
    # Include the seconds unit, matching the other benchmark reports.
    print(f"{current=:.2f}s\n{new=:.2f}s")
    print(f"New method takes {(new / current) * 100:.0f}% of the time")
if __name__ == "__main__":
    # Run both benchmark suites, in order, when executed as a script.
    for benchmark in (atomic_specialcase_test, bigtest):
        benchmark()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment