Skip to content

Instantly share code, notes, and snippets.

@Zomatree
Last active December 28, 2022 22:49
Show Gist options
  • Save Zomatree/5380dab7fd2bf9b5e7fab331cdd1f79e to your computer and use it in GitHub Desktop.
Save Zomatree/5380dab7fd2bf9b5e7fab331cdd1f79e to your computer and use it in GitHub Desktop.
from __future__ import annotations
from dataclasses import dataclass
from types import NoneType, UnionType
from typing import Annotated, Any, Callable, Generic, Iterable, TypeGuard, TypeVar, Union, cast, get_origin as _get_origin, get_args, Self
# Generic value type shared by Type, InternalType and the field helper functions.
T = TypeVar("T")
# isinstance doesn't allow its second parameter to be a parameterized generic, so we must use the origin; however,
# typing.get_origin doesn't return the class if it's not generic (it returns None), so we fall back to the type itself.
# isinstance(..., list[str]) -> Error (can't use isinstance with parameterized generics)
# isinstance(..., typing.get_origin(list[str])) -> Works (get_origin returns list)
# isinstance(..., typing.get_origin(list)) -> Error (get_origin returned None)
def get_origin(ty: Any) -> Any:
    """Return the un-parameterized origin of ``ty``, or ``ty`` itself when it has none.

    The result is always usable as the second argument of ``isinstance`` for
    plain classes and parameterized builtin generics:

    >>> get_origin(list[str]) is list
    True
    >>> get_origin(int) is int
    True
    """
    origin = _get_origin(ty)
    return origin if origin else ty
class Type(Generic[T]):
    """Fluent builder describing how a single model field is parsed and checked.

    Every setter stores one option and returns ``self`` so calls can be chained,
    e.g. ``Type[int]().values(range(10)).rename("n")``. A builder is placed in
    ``Annotated[...]`` metadata and converted with :meth:`_to_internal`.
    """

    def __init__(self):
        self._type: type[T] | None = None  # filled in later from the annotation
        self._literals: Iterable[T] | None = None  # allowed values; None disables the check
        self._validator: Callable[[T], bool] = lambda v: True
        # Default check: isinstance against the un-parameterized origin (list[int] -> list).
        self._type_checker: Callable[[T, type[T]], bool] = lambda obj, ty: isinstance(obj, get_origin(ty))
        self._rename: str | None = None  # input-data key to use instead of the attribute name
        self._key: str | None = None
        self._default: Callable[[], T] | None = None  # factory used when the key is missing

    def values(self, literals: Iterable[T]) -> Self:
        """Restrict the field to the values in ``literals``.

        One-shot iterators (e.g. generator expressions) are materialized into a
        tuple: the membership test runs once per model instance and would
        otherwise exhaust the iterator after the first use. Re-iterable
        containers (ranges, sets, lists, ...) are stored unchanged.
        """
        if iter(literals) is literals:
            literals = tuple(literals)
        self._literals = literals
        return self

    def rename(self, name: str) -> Self:
        """Read this field from ``name`` in the input data instead of its attribute name."""
        self._rename = name
        return self

    def validator(self, f: Callable[[T], bool]) -> Self:
        """Set an extra predicate; a falsy result makes validation fail."""
        self._validator = f
        return self

    def default(self, f: Callable[[], T]) -> Self:
        """Set a factory called to produce the value when the input lacks the key."""
        self._default = f
        return self

    def type_checker(self, f: Callable[[T, type], bool]) -> Self:
        """Replace the type check; ``f`` is called as ``f(value, type)``."""
        self._type_checker = f
        return self

    def _to_internal(self) -> InternalType[T]:
        """Snapshot this builder into an ``InternalType``.

        ``key`` is cast through unchanged (it is assigned afterwards by
        ``convert_to_type``) and ``generics`` starts out empty.
        """
        return InternalType(cast(type[T], self._type), self._literals, self._validator, self._rename, cast(str, self._key), self._default, self._type_checker, ())
def evaluate_annotation(cls: type, annot: Any) -> Any:
if not isinstance(annot, str):
return annot
return eval(annot)
@dataclass
class InternalType(Generic[T]):
    """Resolved description of one model field.

    Produced by ``Type._to_internal`` and completed by ``convert_to_type``;
    read by ``check_types`` and by ``Model`` when instances are built.
    """

    # The annotated type; may be a parameterized alias (e.g. list[int]) or,
    # for Optional fields, a (X, NoneType) tuple.
    type: type[T]
    # Allowed values (None disables the check).
    literals: Iterable[T] | None
    # Extra predicate on the checked value; a falsy result fails validation.
    validator: Callable[[T], bool]
    # Input-data key to read instead of `key`, if set.
    rename: str | None
    # Attribute name on the model instance.
    key: str
    # Factory called when the input data lacks the key; None makes the key required.
    default: Callable[[], T] | None
    # Called as type_checker(value, type); a falsy result raises InvalidType.
    type_checker: Callable[[T, type[T]], bool]
    # Resolved type arguments of container annotations (element / key+value types).
    generics: tuple[InternalType[Any], ...]
def rename(key: str) -> Type[Any]:
    """Return a new ``Type`` that reads the field from ``key`` in the input data."""
    spec: Type[Any] = Type[Any]()
    return spec.rename(key)
def values(values: Iterable[T]) -> Type[T]:
    """Return a new ``Type`` restricted to the given allowed ``values``."""
    spec: Type[T] = Type[T]()
    return spec.values(values)
def default(f: Callable[[], T]) -> Type[T]:
    """Return a new ``Type`` that calls ``f()`` for a value when the input lacks the key."""
    spec: Type[T] = Type[T]()
    return spec.default(f)
def type_checker(f: Callable[[T, type], bool]) -> Type[T]:
    """Return a new ``Type`` whose type check is ``f(value, type)``."""
    spec: Type[T] = Type[T]()
    return spec.type_checker(f)
class ModelError(Exception):
    """Base class for this module's model errors; also raised for unsupported unions."""


class MissingRequiredKey(ModelError):
    """A key without a default is absent from the input data."""


class InvalidType(ModelError):
    """A value failed its field's type check."""


class ValidatorFailed(ModelError):
    """A value is not among the allowed values, or its validator returned a falsy result."""
def is_model_subclass(ty: Any) -> TypeGuard[type[Model]]:
    """Return True when ``ty`` is a class that is ``Model`` or derives from it."""
    if not isinstance(ty, type):
        return False
    return issubclass(ty, Model)
def check_types(cls: type, key: str, ty: InternalType[Any], value: Any) -> Any:
    """Type-check ``value`` against ``ty`` and return it, converted where needed.

    Nested models are built from raw data (an existing instance is kept as-is).
    The elements of parameterized ``list``/``set``/``tuple``/``dict`` values
    are checked and converted recursively.

    Args:
        cls: The model class being populated; passed through recursive calls.
        key: Field path used in error messages (``"field.subkey"`` for dicts).
        ty: The field's resolved description.
        value: The raw value to check.

    Returns:
        The value, with nested models built and containers rebuilt.

    Raises:
        InvalidType: If ``value`` fails its type check.
    """
    if is_model_subclass(ty.type):
        return value if isinstance(value, ty.type) else ty.type(value)

    if not ty.type_checker(value, ty.type):
        raise InvalidType(f"Key {key} expected {ty.type} but found {type(value)}")

    # ty.type is often a parameterized alias (list[int] == list is False), so
    # compare its origin. Without resolved generics (e.g. a type coming from
    # Annotated[...]) there is nothing to recurse into.
    origin = get_origin(ty.type)

    if origin in (list, set, tuple) and ty.generics:
        inner_ty = ty.generics[0]
        # Rebuild with the container's real class (list/set/tuple).
        return origin(check_types(cls, key, inner_ty, inner_value) for inner_value in value)

    if origin is dict and len(ty.generics) == 2:
        key_ty, value_ty = ty.generics
        new_dict: dict[Any, Any] = {}
        for inner_key, inner_value in value.items():
            path = f"{key}.{inner_key}"
            new_key = check_types(cls, path, key_ty, inner_key)
            new_dict[new_key] = check_types(cls, path, value_ty, inner_value)
        return new_dict

    return value
def convert_to_type(cls: type, key: Any, value: Any) -> InternalType[Any]:
    """Resolve one class annotation into an ``InternalType``.

    Handles ``Annotated[X, Type(...)]``, ``X | None`` / ``Optional[X]`` (None in
    any position), ``list``/``set``/``tuple``/``dict`` (bare or parameterized)
    and plain types.

    Args:
        cls: The class that declared the annotation; used to evaluate string
            annotations.
        key: The attribute name; stored on the result as ``key``.
        value: The annotation, possibly a string under postponed evaluation.
            An ``InternalType`` is returned unchanged.

    Returns:
        The resolved field description.

    Raises:
        ModelError: For unions other than a single type combined with None.
    """
    if isinstance(value, InternalType):
        return value

    value = evaluate_annotation(cls, value)
    origin = get_origin(value)
    args = get_args(value)

    if origin is Annotated:
        # Use the first Type(...) builder in the metadata; without one the field
        # behaves like a plain annotation of the wrapped type.
        spec = next((meta for meta in args[1:] if isinstance(meta, Type)), Type[Any]())
        ty = spec._to_internal()
        ty.type = args[0]
    elif origin in (Union, UnionType):
        members = [arg for arg in args if arg is not NoneType]
        # Only Optional[X] is supported; None may come first or second.
        if len(args) != 2 or len(members) != 1:
            raise ModelError("Union is not supported")
        inner_ty = convert_to_type(cls, key, members[0])
        ty = Type[Any]()._to_internal()
        ty.default = lambda: None
        ty.type = (members[0], NoneType)
        # Delegate to the inner type's checker: isinstance() rejects
        # parameterized generics such as list[int] in (list[int], NoneType).
        ty.type_checker = lambda obj, _ty: obj is None or inner_ty.type_checker(obj, inner_ty.type)
    elif origin in (list, set, tuple) and args:
        # NOTE(review): only the first type argument is used, so tuple[int, str]
        # checks every element against int.
        internal_ty = convert_to_type(cls, key, args[0])
        ty = Type[Any]()._to_internal()
        ty.type = value
        ty.type_checker = lambda obj, container_ty: isinstance(obj, get_origin(container_ty)) and all(
            internal_ty.type_checker(v, internal_ty.type) for v in obj
        )
        ty.generics = (internal_ty,)
    elif origin is dict and len(args) == 2:
        internal_key_ty = convert_to_type(cls, key, args[0])
        internal_value_ty = convert_to_type(cls, key, args[1])
        ty = Type[Any]()._to_internal()
        ty.type = value
        ty.type_checker = lambda obj, _ty: isinstance(obj, dict) and all(
            internal_key_ty.type_checker(k, internal_key_ty.type)
            and internal_value_ty.type_checker(v, internal_value_ty.type)
            for k, v in obj.items()
        )
        ty.generics = (internal_key_ty, internal_value_ty)
    else:
        # Plain types, Model subclasses and bare (unparameterized) containers.
        ty = Type[Any]()._to_internal()
        ty.type = value

    ty.key = key
    return ty
class Model:
    """Base class for declarative, validated data models.

    Subclasses declare their fields as class annotations. When a subclass is
    created, each annotation is resolved into an ``InternalType``. Instances
    are built from a plain ``dict`` keyed by field name (or by the name given
    with ``rename(...)``), validated, and exposed as attributes.
    """

    # Input key (renamed if configured) -> resolved field description.
    _items: dict[str, InternalType[Any]]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Resolve the subclass's own annotations into ``cls._items``.

        Fields from a parent model are inherited; a field redeclared in the
        subclass replaces the parent's definition.
        """
        super().__init_subclass__(**kwargs)
        own: dict[str, InternalType[Any]] = {}
        for attr, annotation in cls.__annotations__.items():
            ty = convert_to_type(cls, attr, annotation)
            own[ty.rename or ty.key] = ty
        redeclared = {ty.key for ty in own.values()}
        # Model itself only annotates _items, so direct subclasses inherit nothing.
        inherited = getattr(cls, "_items", {})
        cls._items = {
            **{name: ty for name, ty in inherited.items() if ty.key not in redeclared},
            **own,
        }

    def __init__(self, data: dict[str, Any]):
        """Validate ``data`` and set one attribute per declared field.

        Raises:
            MissingRequiredKey: A key without a default is absent from ``data``.
            InvalidType: A value fails its type check.
            ValidatorFailed: A value is not among the allowed values, or its
                validator returns a falsy result.
        """
        rebuilt: dict[str, Any] = {}
        for name, ty in self._items.items():
            if name in data:
                value = data[name]
            elif ty.default is not None:
                value = ty.default()
            else:
                # Report the key the caller must supply (the renamed one, if any).
                raise MissingRequiredKey(f"Missing required key `{name}`")
            value = check_types(self.__class__, ty.key, ty, value)
            # `is not None`, so an empty collection of allowed values rejects everything.
            if (ty.literals is not None and value not in ty.literals) or not ty.validator(value):
                raise ValidatorFailed(f"Key {ty.key}'s validator failed")
            rebuilt[ty.key] = value
        # Attributes are assigned only after every field has validated.
        for attr, value in rebuilt.items():
            setattr(self, attr, value)

    @staticmethod
    def _dump(value: Any) -> Any:
        """Convert models, including ones nested in lists, tuples and dicts, to plain data."""
        if isinstance(value, Model):
            return value.to_dict()
        if isinstance(value, list):
            return [Model._dump(item) for item in value]
        if isinstance(value, tuple):
            return tuple(Model._dump(item) for item in value)
        if isinstance(value, dict):
            return {k: Model._dump(v) for k, v in value.items()}
        return value

    def to_dict(self) -> dict[str, Any]:
        """Return the instance as a dict keyed like the input data (renamed keys)."""
        return {renamed: self._dump(getattr(self, item.key)) for renamed, item in self._items.items()}

    def __repr__(self) -> str:
        fields = " ".join(f"{item.key}={getattr(self, item.key)!r}" for item in self._items.values())
        return f"<{self.__class__.__name__} {fields}>"
class Foo(Model):
    """Example model: ``x`` must be an int in 0..9; ``y`` is read from the input key ``"z"``."""
    x: Annotated[int, values(range(0, 10))]
    y: Annotated[str, rename("z")]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment