Last active: April 17, 2024 14:17
-
-
Save DavidCEllis/df74d4b181e4f11e74e9072608058e57 to your computer and use it in GitHub Desktop.
Example preprocessor that lets dataclasses set field options through Annotated[...] metadata instead of field()/Field.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect
from copy import copy
from dataclasses import dataclass, field, fields, Field
from typing import Annotated, Any, ClassVar, get_origin

# Modifying objects
class FieldModifier:
    """Bundle of keyword arguments to apply to a dataclass ``field``.

    Instances are meant to sit in the metadata of an ``Annotated[...]`` hint;
    ``preprocess`` gathers every ``FieldModifier`` it finds there and applies
    the stored keywords to the matching field.
    """

    __slots__ = ("modifiers",)
    modifiers: dict[str, Any]

    def __init__(self, **modifiers):
        # Keyword name -> value, e.g. {"init": False}.
        self.modifiers = modifiers

    def __repr__(self):
        parts = []
        for name, value in self.modifiers.items():
            parts.append(f"{name}={value!r}")
        return f"{type(self).__name__}({', '.join(parts)})"

    def __eq__(self, other):
        # Only instances of exactly the same class are comparable here;
        # anything else is handed back to Python via NotImplemented.
        if self.__class__ != other.__class__:
            return NotImplemented
        return self.modifiers == other.modifiers


# Ready-made modifiers for the common cases.
# Named KW_ONLY_MOD so it doesn't clash with dataclasses.KW_ONLY.
KW_ONLY_MOD = FieldModifier(kw_only=True)
NO_INIT = FieldModifier(init=False)
NO_REPR = FieldModifier(repr=False)
NO_COMPARE = FieldModifier(compare=False)
# Same as combining NO_INIT, NO_REPR and NO_COMPARE (in that order).
IGNORE_ALL = FieldModifier(
    **NO_INIT.modifiers,
    **NO_REPR.modifiers,
    **NO_COMPARE.modifiers,
)


def preprocess(cls):
    """Turn ``FieldModifier`` metadata in ``Annotated`` hints into ``field()`` values.

    Apply this *before* ``@dataclass``. For every annotation of the form
    ``Annotated[T, FieldModifier(...), ...]`` the keywords of all
    ``FieldModifier`` objects in the metadata are merged (later ones win) and
    applied to the class attribute with that name:

    * no attribute      -> ``field(**modifiers)``
    * a ``Field`` value -> a new ``Field`` with the same settings, overridden
      by the modifiers (the original object is left untouched because it may
      be shared between attributes)
    * any other value   -> ``field(default=value, **modifiers)``

    :param cls: the class to preprocess; its attributes are replaced in place.
    :return: ``cls`` itself, so this can be used as a decorator.
    :raises TypeError: if ``cls`` already defines ``__slots__``, or if a
        modifier uses a keyword that ``dataclasses.field`` does not accept.
    :raises ValueError: propagated from ``field()`` if the merged settings
        give a field both a ``default`` and a ``default_factory``.
    """
    # Need to actually use the type hints, so can't have them as strings.
    hints = inspect.get_annotations(cls, eval_str=True)
    cls_dict = cls.__dict__

    # __slots__ defined before dataclasses means the fields can't have values.
    # We need to assign values for the preprocessing to work.
    if "__slots__" in cls_dict:
        raise TypeError(
            "Preprocessing does not work on classes with '__slots__' already defined"
        )

    # Keyword names accepted by dataclasses.field() on this Python version.
    # Every one of them is also an attribute of the Field it returns.
    field_params = inspect.signature(field).parameters

    for key, anno in hints.items():
        if get_origin(anno) is not Annotated:
            continue

        modifiers = {}
        for meta in anno.__metadata__:
            if isinstance(meta, FieldModifier):
                modifiers.update(meta.modifiers)
        if not modifiers:
            continue

        # Reject unknown keywords up front so every path fails the same way
        # (previously a Field value raised AttributeError from setattr while
        # other values raised TypeError from field()).
        unknown = [name for name in modifiers if name not in field_params]
        if unknown:
            raise TypeError(
                f"{cls.__name__}.{key}: unsupported field modifier(s) {unknown!r}; "
                f"expected names from {list(field_params)!r}"
            )

        if key in cls_dict:
            val = cls_dict[key]
            if isinstance(val, Field):
                # Rebuild through field() rather than mutating a copy, so the
                # shared original stays untouched and field() applies its own
                # normalisation/validation (e.g. read-only metadata mapping,
                # default vs default_factory conflict).
                settings = {
                    name: getattr(val, name)
                    for name in field_params
                    if hasattr(val, name)
                }
                settings.update(modifiers)
                new_val = field(**settings)
            else:
                new_val = field(default=val, **modifiers)
        else:
            new_val = field(**modifiers)
        setattr(cls, key, new_val)

    return cls


# You can combine these into one decorator if desired.
# preprocess is the inner decorator so it runs first and dataclass then sees
# the field() values it assigned.
@dataclass
@preprocess
class X:
    """Demo class exercising each of the modifiers defined above."""

    x: str  # plain field, no modifiers
    y: ClassVar[str] = "This is okay"  # ClassVar: not a dataclass field
    a: Annotated[str, NO_INIT] = "Not In __init__ signature"
    b: Annotated[str, NO_REPR] = "Not In Repr"
    # An existing field() value is kept; the modifiers are applied on top.
    c: Annotated[list[str], NO_COMPARE] = field(default_factory=list)
    d: Annotated[str, IGNORE_ALL] = field(default="Not Anywhere")
    # No default: only allowed after defaulted fields because it is kw_only.
    e: Annotated[str, KW_ONLY_MOD, NO_COMPARE]


# Demonstration of the modifiers applied to X; only runs as a script.
if __name__ == "__main__":
    ex = X("Value of x", e="Value of e")
    ex2 = X("Value of x", c=["a", "b"], e="Uncompared Value of e")

    # b and d not in repr
    print(ex)
    print(ex2)

    # c and e differ between the two instances but are excluded from eq
    print(f"{ex.c == ex2.c = }")
    print(f"{ex.e == ex2.e = }")
    print(f"{ex == ex2 = }")

    # d is also excluded
    ex2.d = "changed"
    print(f"{ex.d == ex2.d = }")
    print(f"{ex == ex2 = }")

    # a is included
    ex2.a = "changed"
    print(f"{ex.a == ex2.a = }")
    print(f"{ex == ex2 = }")

    # a is not in __init__ signature, so passing it raises TypeError
    try:
        X("value of x", a="value of a", e="value of e")
    except TypeError as err:
        print(err)

    print("\nname init repr compare kw_only")
    for f in fields(X):
        print(f.name, f.init, f.repr, f.compare, f.kw_only)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.