Skip to content

Instantly share code, notes, and snippets.

@dmontagu
Last active August 7, 2019 09:23
Show Gist options
  • Save dmontagu/381d03126c3b35d58274798a19fc12fa to your computer and use it in GitHub Desktop.
Save dmontagu/381d03126c3b35d58274798a19fc12fa to your computer and use it in GitHub Desktop.
Pydantic mypy plugin for signature checking
[mypy]
plugins = pydanticmypy.py
follow_imports = silent
strict_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
disallow_any_generics = True
check_untyped_defs = True
ignore_missing_imports = True
disallow_untyped_defs = True
from collections import OrderedDict
from typing import Callable, Dict, List, Set, TypeVar
from mypy.nodes import (
ARG_NAMED,
ARG_NAMED_OPT,
ARG_STAR2,
MDEF,
Any,
Argument,
AssignmentStmt,
CallExpr,
ClassDef,
EllipsisExpr,
JsonDict,
MemberExpr,
NameExpr,
PlaceholderNode,
StrExpr,
SymbolTableNode,
TempNode,
TypeInfo,
Var,
)
from mypy.plugin import ClassDefContext, FunctionContext, Plugin
from mypy.plugins.common import add_method
from mypy.server.trigger import make_wildcard_trigger
from mypy.types import AnyType, NoneTyp, Optional, Type, TypeOfAny, UnionType
MYPY = False # we should support Python 3.5.1 and cases where typing_extensions is not available.
if MYPY:
from typing_extensions import Type as TypingType
T = TypeVar("T")
CB = Optional[Callable[[T], None]]
BASEMODEL_NAME = "pydantic.main.BaseModel"
class BasicPydanticPlugin(Plugin):
"""
TODO: Add docstring describing capabilities
"""
def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]:
# TODO: handle create_model in here
return None
def get_base_class_hook(self, fullname: str) -> "CB[ClassDefContext]":
# This function is called on any class that is the base class for another class
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
if is_model(sym.node):
return pydantic_model_class_maker_callback
return None
def is_model(info: TypeInfo) -> bool:
for base in info.mro:
if base.fullname() == BASEMODEL_NAME:
return True
return False
class PydanticModelField:
def __init__(self, name: str, has_default: bool, line: int, column: int) -> None:
self.name = name
self.has_default = has_default
self.line = line
self.column = column
def to_argument(self, info: TypeInfo) -> Argument:
return Argument(
variable=self.to_var(info),
type_annotation=info[self.name].type,
initializer=None,
kind=ARG_NAMED_OPT if self.has_default else ARG_NAMED,
)
def to_var(self, info: TypeInfo) -> Var:
return Var(self.name, info[self.name].type)
def serialize(self) -> JsonDict:
return {"name": self.name, "has_default": self.has_default, "line": self.line, "column": self.column}
@classmethod
def deserialize(cls, info: TypeInfo, data: JsonDict) -> "PydanticModelField":
return cls(**data)
class PydanticModelTransformer:
def __init__(self, ctx: ClassDefContext) -> None:
self._ctx = ctx
def add_basemodel_init(self, attributes: List[PydanticModelField], config: Dict[str, Any]):
ctx = self._ctx
init_arguments = [attribute.to_argument(ctx.cls.info) for attribute in attributes]
if config.get("extra") is not False:
var = Var("kwargs")
init_arguments.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2))
add_method(ctx, "__init__", init_arguments, NoneTyp())
def transform(self) -> None:
"""Apply all the necessary transformations to the underlying
dataclass so as to ensure it is fully type checked according
to the rules in PEP 557.
"""
ctx = self._ctx
info = self._ctx.cls.info
attributes = self.collect_attributes()
config = self.collect_config()
if ctx.api.options.new_semantic_analyzer:
# Check if attribute types are ready.
for attr in attributes:
if info[attr.name].type is None:
# TODO: Figure out why this is necessary
if not ctx.api.final_iteration:
ctx.api.defer()
return
self.add_basemodel_init(attributes, config)
if config.get("allow_mutation") is False:
self._freeze(attributes)
info.metadata["pydanticmodel"] = {
"attributes": OrderedDict((attr.name, attr.serialize()) for attr in attributes),
"config": config,
}
def collect_config(self) -> Dict[str, Any]:
ctx = self._ctx
cls = self._ctx.cls
config_fields = ("extra", "allow_mutation", "use_enum_values", "arbitrary_types_allowed", "orm_mode")
config = {}
for stmt in cls.defs.body:
if not isinstance(stmt, ClassDef):
continue
for substmt in stmt.defs.body:
if not isinstance(substmt, AssignmentStmt):
continue
lhs = substmt.lvalues[0]
if not isinstance(lhs, NameExpr):
continue
if lhs.name not in config_fields:
continue
if lhs.name == "extra":
if isinstance(substmt.rvalue, StrExpr):
config[lhs.name] = substmt.rvalue.value != "forbid"
elif isinstance(substmt.rvalue, MemberExpr):
config[lhs.name] = substmt.rvalue.name != "forbid"
continue
if substmt.rvalue.fullname in ("builtins.True", "builtins.False"):
config[lhs.name] = substmt.rvalue.fullname == "builtins.True"
for info in cls.info.mro[1:-1]: # 0 is the current class, -1 is object
if "pydanticmodel" not in info.metadata:
continue
# Each class depends on the set of attributes in its dataclass ancestors.
ctx.api.add_plugin_dependency(make_wildcard_trigger(info.fullname()))
for name, value in info.metadata["pydanticmodel"]["config"].items():
if name not in config:
config[name] = value
return config
def collect_attributes(self):
# First, collect attributes belonging to the current class.
ctx = self._ctx
cls = self._ctx.cls
attrs = [] # type: List[PydanticModelField]
known_attrs = set() # type: Set[str]
for stmt in cls.defs.body:
if not isinstance(stmt, AssignmentStmt): # `and stmt.new_syntax` to require annotation
continue
lhs = stmt.lvalues[0]
if not isinstance(lhs, NameExpr):
continue
if lhs.name == "__config__": # BaseConfig not well handled; I'm not sure why yet
continue
sym = cls.info.names.get(lhs.name)
if sym is None:
assert ctx.api.options.new_semantic_analyzer
continue
node = sym.node
if isinstance(node, PlaceholderNode):
# This node is not ready yet.
continue
assert isinstance(node, Var)
# x: ClassVar[int] is ignored by dataclasses.
if node.is_classvar:
continue
has_default = self._get_has_default(cls, lhs, stmt)
known_attrs.add(lhs.name)
attrs.append(PydanticModelField(name=lhs.name, has_default=has_default, line=stmt.line, column=stmt.column))
all_attrs = attrs.copy()
for info in cls.info.mro[1:-2]: # 0 is the current class, -2 is BaseModel, -1 is object
if "pydanticmodel" not in info.metadata:
continue
super_attrs = []
# Each class depends on the set of attributes in its dataclass ancestors.
ctx.api.add_plugin_dependency(make_wildcard_trigger(info.fullname()))
for name, data in info.metadata["pydanticmodel"]["attributes"].items():
if name not in known_attrs:
attr = PydanticModelField.deserialize(info, data)
known_attrs.add(name)
super_attrs.append(attr)
else:
(attr,) = [a for a in all_attrs if a.name == name]
all_attrs.remove(attr)
super_attrs.append(attr)
all_attrs = super_attrs + all_attrs
return all_attrs
def _get_has_default(self, cls, lhs, stmt) -> bool:
if not isinstance(stmt.rvalue, TempNode):
if not isinstance(stmt, AssignmentStmt):
return False
if isinstance(stmt.rvalue, CallExpr):
callee_is_schema = stmt.rvalue.callee.fullname == "pydantic.schema.Schema"
arg_is_ellipsis = len(stmt.rvalue.args) > 0 and type(stmt.rvalue.args[0]) is EllipsisExpr
if callee_is_schema and arg_is_ellipsis:
return False
return True
if type(cls.info[lhs.name].type) is UnionType:
item_types = [type(item) for item in cls.info[lhs.name].type.items]
if NoneTyp in item_types:
# Optional has default of None
return True
return False
def _freeze(self, attributes: List[PydanticModelField]) -> None:
"""Converts all attributes to @property methods in order to
emulate frozen classes.
"""
info = self._ctx.cls.info
for attr in attributes:
sym_node = info.names.get(attr.name)
if sym_node is not None:
var = sym_node.node
assert isinstance(var, Var)
var.is_property = True
else:
var = attr.to_var(info)
var.info = info
var.is_property = True
var._fullname = info.fullname() + "." + var.name()
info.names[var.name()] = SymbolTableNode(MDEF, var)
def pydantic_model_class_maker_callback(ctx: ClassDefContext) -> None:
"""Hooks into the class typechecking process to add support for dataclasses.
"""
transformer = PydanticModelTransformer(ctx)
transformer.transform()
def plugin(version: str) -> "TypingType[Plugin]":
return BasicPydanticPlugin
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment