Skip to content

Instantly share code, notes, and snippets.

@dmontagu
Created June 14, 2019 10:08
Show Gist options
  • Save dmontagu/8214b6395d5346e55b365758e12faf1a to your computer and use it in GitHub Desktop.
Save dmontagu/8214b6395d5346e55b365758e12faf1a to your computer and use it in GitHub Desktop.
pydantic generics
from typing import Any, Dict, Generic, Optional, Tuple, Type, TypeVar, Union, get_type_hints, no_type_check

from pydantic import BaseModel, ValidationError, create_model, validator
from pydantic.class_validators import gather_validators
class BaseGenericModel:
    """Mixin that turns an annotated ``typing.Generic`` class into a pydantic-model factory.

    Subclass it together with ``Generic[...]``. Subscripting the subclass
    (e.g. ``Result[Data, Error]``) builds a concrete pydantic model with
    ``create_model``:

    * fields come from the class annotations, with the TypeVars replaced
      by the supplied parameters;
    * defaults come from same-named class attributes (``...`` = required);
    * the nested ``Config`` (if any) and the class's validators are carried over.

    Concrete models are cached per ``(class, parameters)``, so repeated
    subscriptions return the same class object.
    """

    # Concrete models keyed by (generic class, parameter tuple). Deliberately
    # NOT annotated: get_type_hints(cls) below would otherwise turn it into a field.
    _generic_types_cache = {}

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Accepts and ignores any arguments. Presumably this exists so that
        # static type checkers accept ``Result[...](field=...)`` calls.
        # NOTE(review): confirm intent.
        pass

    @no_type_check
    def __class_getitem__(cls, params: Union[Any, Tuple[Any, ...]]) -> Type[BaseModel]:
        """Return the concrete pydantic model for ``cls[params]``.

        :param params: one type parameter or a tuple of them.
        :return: a pydantic model class (cached per ``(cls, params)``).
        :raises TypeError: from ``typing.Generic`` for invalid parameters.
        """
        if not isinstance(params, tuple):
            params = (params,)
        # Let typing.Generic validate the parameters (count, must be types)
        # and expose the TypeVars they bind to.
        generic_alias = super().__class_getitem__(params)

        cache_key = (cls, params)
        cached_model = BaseGenericModel._generic_types_cache.get(cache_key)
        if cached_model is not None:
            return cached_model

        typevars_map = dict(zip(generic_alias.__origin__.__parameters__, params))
        new_type_hints = resolve_generic_type_hints(get_type_hints(cls), typevars_map)
        # A class attribute with the field's name is its default; ``...`` marks it required.
        fields = {name: (hint, getattr(cls, name, ...)) for name, hint in new_type_hints.items()}

        # The model name is passed positionally: pydantic v1 names this parameter
        # ``__model_name``, so a ``model_name=`` keyword would be taken as a field.
        created_model = create_model(
            f"{cls.__name__}[{generic_name(params)}]",
            __module__=cls.__module__,
            __config__=getattr(cls, "Config", None),
            __validators__=gather_validators(cls),
            **fields,
        )
        BaseGenericModel._generic_types_cache[cache_key] = created_model
        return created_model
def generic_name(params: Union[Type[Any], Tuple[Type[Any], ...]]) -> str:
    """Build the comma-separated parameter list used in a concrete model's name.

    Examples: ``(Data, Error)`` -> ``"Data, Error"``;
    ``(List[int],)`` -> ``"typing.List[int]"``; ``int`` -> ``"int"``.

    :param params: a single type parameter or a tuple of them, as passed to ``Model[...]``.
    :return: the display names of the parameters joined with ``", "``.
    """
    if not isinstance(params, tuple):
        # Accept a single parameter, as the annotation advertises.
        params = (params,)
    param_names = []
    for param in params:
        if hasattr(param, "__origin__"):
            # Parametrized aliases (List[int], Optional[X], ...): keep the full
            # representation. Checked before __name__ because recent Python
            # versions give aliases a bare __name__ (e.g. "List") that drops the args.
            param_names.append(str(param))
        elif hasattr(param, "__name__"):
            param_names.append(param.__name__)
        else:
            # Values such as None (or typing.Any on older Pythons) have no
            # __name__; use their string form rather than failing.
            param_names.append(str(param))
    return ", ".join(param_names)
def resolve_generic_type_hints(generic_type_hints: Dict[str, Any], typevars_map: Dict[Any, Any]) -> Dict[str, Any]:
    """Substitute TypeVars in every hint of a field-name -> type-hint mapping.

    :param generic_type_hints: field names mapped to (possibly generic) type hints.
    :param typevars_map: TypeVar -> concrete type to substitute.
    :return: a new dict with the same keys, in the same order, holding resolved hints.
    """
    return {
        field_name: resolve_type_hint(hint, typevars_map)
        for field_name, hint in generic_type_hints.items()
    }
def resolve_type_hint(type_: Any, typevars_map: Dict[Any, Any]) -> Type[Any]:
    """Return ``type_`` with every TypeVar found in ``typevars_map`` substituted, recursively.

    Examples: ``List[T]`` with ``{T: int}`` -> ``List[int]``;
    ``Optional[T]`` -> ``Optional[int]``; a bare ``T`` -> ``int``.
    Hints containing no mapped TypeVar are returned unchanged (same object).

    :param type_: a type, TypeVar, or parametrized typing alias.
    :param typevars_map: TypeVar -> replacement.
    :return: the resolved hint.
    """
    import collections.abc  # local imports: the module only imports names from typing
    import typing

    args = getattr(type_, "__args__", None)
    if getattr(type_, "__origin__", None) is None or not args:
        # Leaf (plain type, TypeVar, literal value) or an unparametrized alias.
        return typevars_map.get(type_, type_)

    new_args = tuple(resolve_type_hint(arg, typevars_map) for arg in args)
    if all(new is old for new, old in zip(new_args, args)):
        # Nothing substituted: keep the original object instead of rebuilding it.
        return type_

    origin = type_.__origin__
    if origin is collections.abc.Callable:
        # __args__ flattens Callable[[A, B], R] to (A, B, R) and Callable[..., R]
        # to (Ellipsis, R); rebuild in the nested form Callable expects.
        *arg_types, return_type = new_args
        if arg_types == [Ellipsis]:
            return typing.Callable[..., return_type]
        return typing.Callable[arg_types, return_type]

    # typing.List[int].__origin__ is the builtin ``list``. Subscripting it fails
    # before Python 3.9, and on 3.9+ yields ``list[int]`` rather than
    # ``List[int]``. Rebuild through the matching typing alias when one exists.
    alias_name = getattr(type_, "_name", None)
    typing_alias = getattr(typing, alias_name, None) if alias_name else None
    if getattr(typing_alias, "__origin__", None) is origin:
        return typing_alias[new_args]
    return origin[new_args]
# ---------------------------------------------------------------------------
# Usage example
# ---------------------------------------------------------------------------

# Type parameters of Result: the success payload type and the error payload type.
DataT, ErrorT = TypeVar("DataT"), TypeVar("ErrorT")
class Result(BaseGenericModel, Generic[DataT, ErrorT]):
    """Generic envelope holding exactly one of ``data`` / ``error``, plus a non-negative number.

    Subscript it to get a concrete pydantic model, e.g. ``Result[Data, Error]``.
    """

    data: Optional[DataT]
    error: Optional[ErrorT]
    positive_number: int

    @validator("error", always=True)
    def validate_error(cls, v: Optional[ErrorT], values: Dict[str, Any]) -> Optional[ErrorT]:
        """Enforce that exactly one of ``data`` and ``error`` is provided.

        ``always=True`` makes this run even when ``error`` is omitted, so the
        "neither provided" case is caught. ``values`` holds the previously
        validated fields; ``data`` is declared first, so it is present unless
        its own validation failed.
        """
        data = values.get("data")
        if data is None and v is None:
            raise ValueError("Must provide data or error")
        if data is not None and v is not None:
            raise ValueError("Must not provide both data and error")
        return v

    @validator("positive_number")
    def validate_positive_number(cls, v: int) -> int:
        """Reject negative values.

        NOTE(review): zero is accepted despite the field name; use ``v <= 0``
        if zero should be rejected too.
        """
        if v < 0:
            # Without a message the error report shows an empty explanation.
            raise ValueError("positive_number must not be negative")
        return v
class Error(BaseModel):
    """Error payload, used as the ``ErrorT`` parameter in the example below."""

    message: str
class Data(BaseModel):
    """Success payload, used as the ``DataT`` parameter in the example below."""

    number: int
    text: str
success1 = Result[Data, Error](data=Data(number=1, text="a"), positive_number=1)
success2 = Result[Data, Error](error=Error(message="error"), positive_number=1)
print(success1)
print(success1.dict())
print("\n-----\n")
print(success2)
print(success2.dict())

# Each of these inputs must fail validation; print pydantic's error report.
# Only ValidationError is caught so that unrelated bugs still surface.
invalid_inputs = [
    dict(error=Error(message="error"), positive_number=-1),  # error -- positive number
    dict(data=Data(number=1, text="a"), error=Error(message="error"), positive_number=1),  # error both
    dict(positive_number=1),  # error neither
]
for invalid_kwargs in invalid_inputs:
    print("\n-----\n")
    try:
        Result[Data, Error](**invalid_kwargs)
    except ValidationError as e:
        print(e)
    else:
        print(f"unexpectedly valid: {invalid_kwargs}")

# Output recorded with an early (2019) pydantic release; wording and layout
# differ between pydantic versions. The final section is the expected result
# of the "neither" case and has not been re-captured.
"""
Prints:
Result[Data, Error] data=<Data number=1 text='a'> error=None positive_number=1
{'data': {'number': 1, 'text': 'a'}, 'error': None, 'positive_number': 1}

-----

Result[Data, Error] data=None error=<Error message='error'> positive_number=1
{'data': None, 'error': {'message': 'error'}, 'positive_number': 1}

-----

1 validation error
positive_number
  (type=value_error)

-----

2 validation errors
error
  Must not provide both data and error (type=value_error)
error
  value is not none (type=type_error.none.allowed)

-----

1 validation error
error
  Must provide data or error (type=value_error)
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment