Created
June 14, 2019 10:08
-
-
Save dmontagu/8214b6395d5346e55b365758e12faf1a to your computer and use it in GitHub Desktop.
pydantic generics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Dict, Generic, Optional, Tuple, Type, TypeVar, Union, get_type_hints, no_type_check

from pydantic import BaseModel, ValidationError, create_model, validator
from pydantic.class_validators import gather_validators
class BaseGenericModel:
    """Mixin that turns a ``typing.Generic`` class into a factory of concrete pydantic models.

    Subclass it together with ``Generic[...]`` and declare fields whose annotations
    may reference the class's type variables. Subscripting the class
    (``Result[Data, Error]``) substitutes the type variables in every annotation and
    builds a real ``BaseModel`` with ``pydantic.create_model``. The class's
    ``Config``, its validators and any class-level default values are carried over.

    Concrete models are cached per ``(generic class, parameters)``. Repeated
    subscripts therefore return the same class object instead of rebuilding it.
    """

    # Already-built concrete models, keyed by (generic class, params tuple).
    # Deliberately NOT annotated: get_type_hints(cls) below would otherwise pick
    # it up and turn it into a model field.
    _generic_types_cache = {}  # type: Dict[Tuple[Any, Tuple[Any, ...]], Type[BaseModel]]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Only runs if the *unparametrized* class is instantiated directly, and
        # then every argument is ignored. The concrete models built below are
        # plain BaseModel subclasses and never reach this. Presumably it exists so
        # type checkers accept ``Result[Data, Error](data=...)`` -- TODO confirm.
        pass

    @no_type_check
    def __class_getitem__(cls, params: Union[Any, Tuple[Any, ...]]) -> Type[BaseModel]:
        """Build (or fetch from cache) the concrete pydantic model for ``cls[params]``.

        Args:
            params: one type parameter or a tuple of them, one per type variable.

        Returns:
            A ``BaseModel`` subclass named e.g. ``"Result[Data, Error]"``.

        Raises:
            TypeError: from ``Generic.__class_getitem__`` if the parameters do not
                match the class's type variables.
        """
        if not isinstance(params, tuple):
            params = (params,)

        cache_key = (cls, params)
        try:
            cached = BaseGenericModel._generic_types_cache.get(cache_key)
        except TypeError:
            # An unhashable parameter: still build the model, just don't cache it.
            cache_key, cached = None, None
        if cached is not None:
            return cached

        # Let typing validate the parameters (count, TypeVar-ness) for us.
        generic_alias = super().__class_getitem__(params)
        typevars_map = dict(zip(generic_alias.__origin__.__parameters__, params))

        type_hints = resolve_generic_type_hints(get_type_hints(cls), typevars_map)
        # A class attribute with the field's name is its default; otherwise the
        # field is declared with ``...`` (pydantic's "no default" marker).
        fields = {name: (hint, getattr(cls, name, ...)) for name, hint in type_hints.items()}

        created_model = create_model(
            model_name=f"{cls.__name__}[{generic_name(params)}]",
            __module__=cls.__module__,
            __config__=getattr(cls, "Config", None),
            __validators__=gather_validators(cls),
            **fields,
        )
        if cache_key is not None:
            BaseGenericModel._generic_types_cache[cache_key] = created_model
        return created_model
def generic_name(params: Union[Type[Any], Tuple[Type[Any], ...]]) -> str:
    """Render type parameters for a concrete model's name, e.g. ``"Data, Error"``.

    Args:
        params: a single type parameter or a tuple of them, as passed to ``cls[...]``.

    Returns:
        The parameter names joined with ``", "``.

        - Parametrized aliases such as ``List[int]`` are rendered in full
          (``"typing.List[int]"``), so ``Result[List[int], ...]`` and
          ``Result[List[str], ...]`` get distinct names.
        - Classes and TypeVars use their ``__name__``.
        - Anything else (e.g. ``None``) falls back to ``repr``.
    """
    if not isinstance(params, tuple):
        params = (params,)
    param_names = []
    for param in params:
        # Check __origin__ before __name__: newer Python versions also give
        # parametrized aliases a __name__ (``list[int].__name__ == "list"``),
        # which would silently drop the type arguments from the name.
        if hasattr(param, "__origin__"):
            param_names.append(str(param))
        elif hasattr(param, "__name__"):
            param_names.append(param.__name__)
        else:
            param_names.append(repr(param))
    return ", ".join(param_names)
def resolve_generic_type_hints(generic_type_hints: Dict[str, Any], typevars_map: Dict[Any, Any]) -> Dict[str, Any]:
    """Apply ``resolve_type_hint`` to every annotation of a generic class.

    Args:
        generic_type_hints: field name -> annotation, as returned by ``get_type_hints``.
        typevars_map: TypeVar -> concrete type to substitute for it.

    Returns:
        A new dict with the same keys (in the same order) and resolved annotations.
    """
    return {
        field_name: resolve_type_hint(annotation, typevars_map)
        for field_name, annotation in generic_type_hints.items()
    }
def resolve_type_hint(type_: Any, typevars_map: Dict[Any, Any]) -> Type[Any]:
    """Substitute concrete types for the type variables appearing in ``type_``.

    Args:
        type_: an annotation. This is a bare ``TypeVar``, a plain class, or a
            parametrized alias such as ``Optional[DataT]`` or ``Dict[str, List[DataT]]``.
        typevars_map: maps each ``TypeVar`` to the type that should replace it.

    Returns:
        ``type_`` with every mapped type variable replaced. Type variables missing
        from ``typevars_map`` are left in place. Annotations without free type
        variables are returned unchanged.
    """
    if hasattr(type_, "__origin__"):
        # Let the alias substitute its own free type variables
        # (``List[T][int]`` -> ``List[int]``). The alternative of rebuilding it as
        # ``type_.__origin__[args]`` fails in two cases:
        # - aliases whose origin is a builtin (``List[T].__origin__`` is ``list``,
        #   not subscriptable before Python 3.9 and a different alias type after);
        # - ``Callable``, whose ``__args__`` are flattened.
        free_typevars = getattr(type_, "__parameters__", ())
        if not free_typevars:
            return type_
        return type_[tuple(typevars_map.get(tv, tv) for tv in free_typevars)]
    return typevars_map.get(type_, type_)
# #################
# ##### Usage #####
# #################

# Type parameters of ``Result``: the payload type on success and on failure.
DataT, ErrorT = TypeVar("DataT"), TypeVar("ErrorT")
class Result(BaseGenericModel, Generic[DataT, ErrorT]):
    """Generic envelope carrying either a ``data`` payload or an ``error`` payload.

    Subscript it (``Result[Data, Error]``) to obtain a concrete pydantic model.
    Exactly one of ``data`` / ``error`` must be set, and ``positive_number``
    must not be negative.
    """

    # Keep ``data`` declared before ``error``: ``validate_error`` reads it from
    # ``values``, which (per pydantic's validator contract) only holds fields
    # validated earlier.
    data: Optional[DataT]
    error: Optional[ErrorT]
    positive_number: int

    @validator("error", always=True)
    def validate_error(cls, v: Optional[ErrorT], values: Dict[str, Any]) -> Optional[ErrorT]:
        # always=True so this also runs when ``error`` is omitted entirely,
        # which is how the "neither provided" case gets caught.
        has_data = values.get("data") is not None
        has_error = v is not None
        if not (has_data or has_error):
            raise ValueError("Must provide data or error")
        if has_data and has_error:
            raise ValueError("Must not provide both data and error")
        return v

    @validator("positive_number")
    def validate_positive_number(cls, v: int) -> int:
        # Zero is accepted; only negative numbers are rejected.
        if v >= 0:
            return v
        raise ValueError
class Error(BaseModel):
    """Example failure payload, used as the ``ErrorT`` argument of ``Result``."""

    message: str  # human-readable description of what went wrong
class Data(BaseModel):
    """Example success payload, used as the ``DataT`` argument of ``Result``."""

    number: int
    text: str
# Build the concrete model once and reuse it for every example below.
DataResult = Result[Data, Error]

success1 = DataResult(data=Data(number=1, text="a"), positive_number=1)
success2 = DataResult(error=Error(message="error"), positive_number=1)
print(success1)
print(success1.dict())
print("\n-----\n")
print(success2)
print(success2.dict())
print("\n-----\n")
try:
    DataResult(error=Error(message="error"), positive_number=-1)  # error -- positive number
except ValidationError as e:
    print(e)
print("\n-----\n")
try:
    DataResult(data=Data(number=1, text="a"), error=Error(message="error"), positive_number=1)  # error both
except ValidationError as e:
    print(e)
print("\n-----\n")
try:
    # Neither ``data`` nor ``error``: previously this example passed both by
    # mistake, so it just repeated the "both" case above.
    DataResult(positive_number=1)  # error neither
except ValidationError as e:
    print(e)
"""
Output recorded when this gist was written (June 2019); exact error formatting
differs between pydantic versions:
Result[Data, Error] data=<Data number=1 text='a'> error=None positive_number=1
{'data': {'number': 1, 'text': 'a'}, 'error': None, 'positive_number': 1}
-----
Result[Data, Error] data=None error=<Error message='error'> positive_number=1
{'data': None, 'error': {'message': 'error'}, 'positive_number': 1}
-----
1 validation error
positive_number
(type=value_error)
-----
2 validation errors
error
Must not provide both data and error (type=value_error)
error
value is not none (type=type_error.none.allowed)
-----
(not recorded: the original third example accidentally repeated the second one;
presumably it reports a validation error on `error` reading "Must provide data or error")
"""
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment