Skip to content

Instantly share code, notes, and snippets.

@mementum
Created April 22, 2024 08:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mementum/707c570d1eccf7377d6289caa1e2d202 to your computer and use it in GitHub Desktop.
Dataclass with Field Annotations using @
#!/usr/bin/env python
# -*- coding: utf-8; py-indent-offset:4 -*-
###############################################################################
from __future__ import annotations
from collections.abc import Callable
import re
import inspect
from typing import Annotated, overload
# dataclasses imports
import dataclasses
# Imports meant for re-export - ignore non-used and values cannot be determined
from dataclasses import * # noqa: F403 F401
# Specific imports for development error-checking
from dataclasses import dataclass as _dataclass, field, KW_ONLY, MISSING
# Public API: the markers and decorator defined in this module, followed by
# every name re-exported from the standard ``dataclasses`` module.
__all__ = [
    "at_dataclass",
    "NO_INIT",
    "NO_INIT_FACTORY",
    "NO_FIELD",
    *dataclasses.__all__,
]
class _Sentinel:
    """Base class for the marker singletons written after the ``@`` annotator.

    Markers are found in the evaluated sub-annotation list with ``in``. No
    ``__eq__`` is defined, so that check falls back to identity. Copying or
    pickling a marker must therefore return the very same module-level object.
    """

    __slots__ = ()
    _name: str = ""  # name of the module-level instance; set by subclasses

    def __repr__(self) -> str:
        # Readable inside Annotated[...] metadata instead of "<... at 0x...>"
        return self._name

    def __reduce__(self) -> str:
        # A string result makes pickle (and copy/deepcopy) reuse the global
        # object of that name, preserving identity for the marker checks.
        return self._name


class _NO_FIELD_TYPE(_Sentinel):
    """Type of ``NO_FIELD``: the attribute must not become a dataclass field."""

    __slots__ = ()
    _name = "NO_FIELD"


NO_FIELD = _NO_FIELD_TYPE()


class _NO_INIT_TYPE(_Sentinel):
    """Type of ``NO_INIT``: field with ``init=False``, class value as default."""

    __slots__ = ()
    _name = "NO_INIT"


NO_INIT = _NO_INIT_TYPE()


class _NO_INIT_FACTORY_TYPE(_Sentinel):
    """Type of ``NO_INIT_FACTORY``: ``init=False``, class value as factory."""

    __slots__ = ()
    _name = "NO_INIT_FACTORY"


NO_INIT_FACTORY = _NO_INIT_FACTORY_TYPE()
ANNOTATOR = "@"
ANN_LBRACKET = "["
ANN_RE = re.compile('''\|(?=(?:[^'"]|'[^']*'|"[^"]*")*$)''')
# Typing-only signatures of at_dataclass; the implementation follows.


# Bare use, ``@at_dataclass``: the class itself is passed and a class returned.
@overload
def at_dataclass(
    cls: type,
    reannotate: bool = True,
    keep_dc_ann: bool = True,
    annotator: str = ANNOTATOR,
    **kwargs,
) -> type:
    ...


# Parametrised use, ``@at_dataclass(...)``: no class yet, so a decorator is
# returned. ``**kwargs`` (e.g. ``frozen=True``) go to dataclasses.dataclass.
@overload
def at_dataclass(
    cls: None = None,
    reannotate: bool = True,
    keep_dc_ann: bool = True,
    annotator: str = ANNOTATOR,
    **kwargs,
) -> Callable[[type], type]:
    ...
def _reannotated(type_str: str, subannotations: list, reannotate: bool) -> object:
    """Return the annotation to store back for an ``@``-annotated attribute.

    Args:
        type_str: the type part of the original annotation, as a string.
        subannotations: what is left of the evaluated sub-annotations.
        reannotate: whether the remaining sub-annotations should be kept.

    Returns:
        ``type_str`` itself when ``reannotate`` is false or nothing is left,
        otherwise ``Annotated[type_str, *subannotations]``.
    """
    if not reannotate or not subannotations:
        return type_str
    # Tuple-subscript form: identical to Annotated[type_str, *subannotations]
    # without needing the Python 3.11 starred-subscript syntax.
    return Annotated[(type_str, *subannotations)]


def _apply_field_markers(
    cls: type, name: str, subannotations: list, keep_dc_ann: bool
) -> None:
    """Turn a NO_INIT / NO_INIT_FACTORY / KW_ONLY marker into a dataclass field.

    Only the first matching marker, checked in that order, is applied. The
    class attribute value (``MISSING`` if absent) becomes the field default
    (or the default factory for NO_INIT_FACTORY). If the attribute already
    holds a ``dataclasses.Field``, that field is adjusted in place. It is not
    wrapped as the default of a new field, so its own options are kept.

    Args:
        cls: class being processed; ``name`` may be rebound on it.
        name: attribute name.
        subannotations: evaluated sub-annotations; the applied marker is
            removed from it when ``keep_dc_ann`` is false.
        keep_dc_ann: keep the applied marker in ``subannotations``.
    """
    default = getattr(cls, name, MISSING)
    if NO_INIT in subannotations:
        marker, f_kwargs = NO_INIT, {"init": False, "default": default}
    elif NO_INIT_FACTORY in subannotations:
        marker, f_kwargs = NO_INIT_FACTORY, {"init": False, "default_factory": default}
    elif KW_ONLY in subannotations:
        marker, f_kwargs = KW_ONLY, {"kw_only": True, "default": default}
    else:
        return  # no marker: attribute left exactly as written

    if not keep_dc_ann:
        subannotations.remove(marker)

    if isinstance(default, dataclasses.Field):
        # User wrote "= field(...)": only switch the flag the marker implies
        for key in ("init", "kw_only"):
            if key in f_kwargs:
                setattr(default, key, f_kwargs[key])
    else:
        setattr(cls, name, field(**f_kwargs))


def at_dataclass(
    cls: type | None = None,
    reannotate: bool = True,
    keep_dc_ann: bool = True,
    annotator: str = ANNOTATOR,
    **kwargs,
) -> type | Callable[[type], type]:
    """Make ``cls`` a dataclass, honouring ``type @ sub-annotations`` syntax.

    Only string annotations are inspected (the module defining the class is
    expected to use ``from __future__ import annotations``). An annotation is
    split at the first ``annotator``. The part after it is evaluated into a
    list of sub-annotations, written either as ``a | b | c`` or ``[a, b, c]``.
    Evaluation uses the globals of the module that defines ``cls``, with
    ``NO_FIELD``, ``NO_INIT``, ``NO_INIT_FACTORY`` and ``KW_ONLY`` always
    resolving to this library's markers. These markers are acted upon:

    - ``NO_FIELD``: the attribute is hidden from ``dataclasses.dataclass``
      and its annotation restored afterwards.
    - ``NO_INIT``: ``field(init=False, default=<class value>)``.
    - ``NO_INIT_FACTORY``: ``field(init=False, default_factory=<class value>)``.
    - ``KW_ONLY``: ``field(kw_only=True, default=<class value>)``.

    NOTE: the sub-annotation text is passed to ``eval``; only decorate classes
    whose source is trusted.

    Args:
        cls: class to process, or ``None`` when used as ``@at_dataclass(...)``.
        reannotate: re-annotate processed attributes as
            ``Annotated[type, *remaining_subannotations]``; if false, only
            the bare type string is kept.
        keep_dc_ann: keep the markers above in the remaining sub-annotations.
        annotator: separator between the type and the sub-annotations.
        **kwargs: forwarded to ``dataclasses.dataclass``.

    Returns:
        The dataclass, or a decorator producing it when ``cls`` is ``None``.
    """

    def _annotifier(cls: type) -> type:
        # Resolve names where the class was written (not in this library),
        # falling back to an empty namespace for module-less classes.
        module = inspect.getmodule(cls)
        globalns = vars(module) if module is not None else {}
        # Passed as eval locals, so they win over same-named globals
        markers = {
            "NO_FIELD": NO_FIELD,
            "NO_INIT": NO_INIT,
            "NO_INIT_FACTORY": NO_INIT_FACTORY,
            "KW_ONLY": KW_ONLY,
        }
        # NO_FIELD attributes: name -> (type_str, subannotations), re-added
        # after dataclass processing so they never become fields
        no_fields = {}

        # get_annotations returns a new dict, so editing cls.__annotations__
        # inside this loop is safe
        for name, annotation in inspect.get_annotations(cls).items():
            if not isinstance(annotation, str):
                continue  # only string annotations can carry sub-annotations
            try:  # split "type @ sub-annotations" at the first annotator
                type_str, f_ann = annotation.split(annotator, maxsplit=1)
            except ValueError:
                continue  # no annotator: plain annotation, leave untouched
            type_str = type_str.rstrip()
            f_ann = f_ann.lstrip()

            if f_ann.startswith(ANN_LBRACKET):
                expr = f_ann  # already a list display: "[a, b, ...]"
            else:
                # "a | b | c" -> "[a ,b ,c]" splitting only on "|" outside quotes
                expr = f"[{','.join(ANN_RE.split(f_ann))}]"
            # fresh locals per eval so a walrus cannot leak between attributes
            subannotations = eval(expr, globalns, dict(markers))

            if NO_FIELD in subannotations:
                cls.__annotations__.pop(name)  # hide from dataclass processing
                if not keep_dc_ann:
                    subannotations.remove(NO_FIELD)
                no_fields[name] = type_str, subannotations
                continue

            _apply_field_markers(cls, name, subannotations, keep_dc_ann)
            cls.__annotations__[name] = _reannotated(
                type_str, subannotations, reannotate
            )

        dataclassed = _dataclass(cls, **kwargs)  # standard dataclass processing

        # dataclass(slots=True) returns a new class: restore the NO_FIELD
        # annotations on the class actually handed back to the caller
        for name, (type_str, subannotations) in no_fields.items():
            dataclassed.__annotations__[name] = _reannotated(
                type_str, subannotations, reannotate
            )
        return dataclassed

    if cls is None:
        return _annotifier  # used as @at_dataclass(...): return real decorator
    return _annotifier(cls)  # used as bare @at_dataclass
# With everything defined, export at_dataclass under the standard name. This
# rebinds the ``dataclass`` name brought in by ``from dataclasses import *``,
# so importing ``dataclass`` from this module yields at_dataclass.
dataclass = at_dataclass
# Small test
if __name__ == "__main__":
    from dataclasses import fields
    from typing import ClassVar

    class Dummy:
        """Arbitrary non-marker object used as an extra sub-annotation."""

    @at_dataclass
    class A:
        cv: ClassVar[str] = "classvar"
        a: int
        b: int @ KW_ONLY = 25
        c: int @ NO_INIT = 5
        d: list[str] @ NO_INIT_FACTORY = list
        e: int @ NO_INIT | Dummy() | Dummy() = 0
        f: int @ [NO_INIT, Dummy(), Dummy()] = 1
        g: int @ NO_FIELD = 7
        h: int @ NO_FIELD

    heavy_rule = "=" * 80
    light_rule = "-" * 70

    a = A(3)
    print(heavy_rule)
    print(f"{a.__annotations__ = }")
    print(heavy_rule)
    print(f"{a.a = }")
    print(f"{a.b = }")
    for f in fields(A):
        print(f"-- {light_rule}")
        print(f"{f = }")
    print(light_rule)

    # b is keyword-only, so passing it by name must be accepted
    try:
        b = A(1, b=2)
    except Exception as e:
        print(f"Exception: {e = }")
    else:
        print("b is a keyword argument. Ok")

    # Each call below should be rejected: b given positionally, while c and
    # d are init=False fields
    should_fail = (
        lambda: A(1, 2),
        lambda: A(1, c=2),
        lambda: A(1, d=2),
    )
    for attempt in should_fail:
        try:
            b = attempt()
        except Exception as e:
            print(f"Exception: {e = }")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment