Skip to content

Instantly share code, notes, and snippets.

@rnag
Last active July 28, 2023 06:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rnag/db6bf83d9ca19dfe897d6ccabd4e2570 to your computer and use it in GitHub Desktop.
Dataclass Type Validation
# dataclass type validation for fields on class instantiation -and- on assignment.
from dataclasses import dataclass

from validators import TypeValidator


@dataclass
class FooDC:
    # alternatively, like:
    #   number: int = TypeValidator(default_factory=int)
    number: int = TypeValidator()
    word: str = TypeValidator()


foo = FooDC(number=3, word='1')
print(foo)

# Validation on instantiation: a `str` passed for the `int` field must be
# rejected by the generated `__init__`.
try:
    _ = FooDC(number='test')
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')

print()

# Validation on assignment: an `int` assigned to the `str` field must be
# rejected. Checked the same way as above rather than letting the error
# escape, so the script runs to completion and the read-back below executes.
bar = FooDC()
try:
    bar.word = 2
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')

# Show the value `word` still holds after the rejected assignment.
print(repr(bar.word))
# dataclass type validation for fields on class instantiation -and- on assignment.
#
# NOTE: this example should work in Python 3.7+
from dataclasses import dataclass
from typing import Dict, Optional, Union

from validators import TypeValidator


@dataclass
class FooTest:
    map: Dict[str, float] = TypeValidator()
    num_or_str: Union[int, float, str] = TypeValidator()
    opt_word: Optional[str] = TypeValidator()


foo = FooTest()
print(foo)

# Validation on instantiation: `bytes` is none of `int`, `float`, `str`.
try:
    _ = FooTest(map={'key': 1.23}, num_or_str=b'byte string')
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')

# Validation on assignment: an `int` is not a `dict`. Verified the same way
# as the instantiation check instead of crashing the script.
bar = FooTest()
try:
    bar.map = 2
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')
# dataclass type validation for fields on class instantiation -and- on assignment.
#
# NOTE: this example should work in Python 3.10+
from dataclasses import dataclass

from validators import TypeValidator


@dataclass
class FooTest:
    map: dict[str, float] = TypeValidator()
    num_or_str: int | float | str = TypeValidator()
    opt_word: str | None = TypeValidator()


foo = FooTest()
print(foo)

# Validation on instantiation: `bytes` is none of `int`, `float`, `str`.
try:
    _ = FooTest(map={'key': 1.23}, num_or_str=b'byte string')
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')

# Validation on assignment: an `int` is not a `dict`. Verified the same way
# as the instantiation check instead of crashing the script.
bar = FooTest()
try:
    bar.map = 2
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')
from dataclasses import MISSING
from typing import Callable, Generic, NewType, TypeVar, Union
_NoneType = type(None)
_T = TypeVar('_T')
# Sentinel handed to `@dataclass` as the field default when a default factory
# is configured; `__set__` swaps it for a freshly built value per instance.
_UNSET = NewType('_UNSET', None)


class TypeValidator(Generic[_T]):
    """Data descriptor that type-checks a dataclass field on every assignment.

    Use it as the default value of an annotated field::

        @dataclass
        class Foo:
            number: int = TypeValidator()

    In ``__set_name__`` the field's annotation is reduced to something
    ``isinstance`` accepts:

    * a plain class (``int``) is used as-is;
    * a parameterized generic (``Dict[str, float]``, ``dict[str, float]``) is
      reduced to its origin class (``dict``); the type parameters themselves
      are not checked;
    * a union (``Union[...]``, ``Optional[...]``, ``X | Y``) becomes a tuple of
      its member types, each reduced the same way.

    Args:
        add_default: If true and neither ``default`` nor ``default_factory`` is
            given, infer one: ``None`` when the union includes ``None``;
            otherwise the type (for a union, its first member) is used as the
            default factory if it can be called with no arguments.
        default: Explicit default value for the field.
        default_factory: Zero-argument callable producing the default value.

    Raises:
        TypeError: When a value that is not an instance of the annotated
            type(s) is assigned, including by the dataclass ``__init__``.
    """

    __slots__ = ('add_default',
                 'private_name',
                 'default',
                 'default_factory',
                 'type',
                 )

    def __init__(self, *,
                 add_default=True,
                 default: _T = MISSING,
                 default_factory: Callable[[], _T] = MISSING):
        self.add_default = add_default
        self.default = default
        self.default_factory = default_factory

    def __set_name__(self, owner, name):
        # The value is stored on each instance under '_' + name, separate from
        # this descriptor, which lives on the class under `name`.
        self.private_name = '_' + name
        tp = owner.__annotations__[name]

        try:
            # types like `typing.Dict[str, int]` and `dict[str, int]` expose
            # their runtime class (`dict`) as `__origin__`.
            tps = tp.__origin__
            # `typing.Union[...]` / `typing.Optional[...]` report `Union` as
            # their origin; the member types are in `__args__`.
            if tps is Union:
                tps = tp.__args__
        except AttributeError:
            # `int | str` (Python 3.10+) has `__args__` but no `__origin__`;
            # plain classes such as `int` have neither.
            tps = getattr(tp, '__args__', tp)

        if isinstance(tps, tuple):
            # Union members may themselves be parameterized generics, e.g.
            # `Optional[List[int]]`. `isinstance` rejects those, so reduce
            # each member to its origin class (`list`), as done above for a
            # non-union generic.
            tps = tuple(getattr(member, '__origin__', member) for member in tps)

        self.type = tps

        if self.add_default and self.default is self.default_factory is MISSING:
            if isinstance(tps, tuple):  # a tuple of types
                if _NoneType in tps:
                    # if we see a `None` for a type, then `None` is a
                    # reasonable default to use.
                    self.default = None
                    return
                # a tuple isn't callable, so try the first member type as
                # the default factory instead.
                tps = tps[0]
            # check if the type can be used as a "default factory"
            try:
                tps()
            except TypeError:
                pass
            else:
                self.default_factory = tps

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Class-level access, which is how `@dataclass` reads the field's
            # default. With a factory configured, return the `_UNSET`
            # sentinel so `__set__` builds a fresh value per instance. With
            # neither default set this returns `MISSING` -- presumably so the
            # dataclass treats the field as required.
            return self.default if self.default_factory is MISSING else _UNSET
        # Instance-level access: return the stored, already-validated value.
        return getattr(obj, self.private_name)

    def __set__(self, obj, value):
        if value is _UNSET:
            # No value supplied: build the default. It is not validated.
            value = self.default_factory()
        else:
            # Raises before storing, so a rejected value is never written.
            self.validate(value)
        setattr(obj, self.private_name, value)

    def validate(self, value):
        """Raise ``TypeError`` if *value* is not an instance of ``self.type``."""
        if not isinstance(value, self.type):
            # `private_name` is '_' + the field name: drop exactly that one
            # character so names like `_secret` are reported intact.
            field_name = self.private_name[1:]
            msg = f'Expected {field_name} to be {self.type!r}, got {value!r}'
            raise TypeError(msg)
@srbharadwaj
Copy link

Nice.. thank you

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment