Created
June 5, 2024 21:55
-
-
Save mattwthompson/8375acaaf7e1eedaf1dc6f889cf1df9a to your computer and use it in GitHub Desktop.
pydantic v2 + Pint
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Simple, self-contained example showing how to get Pydantic and Pint to play nicely together.
This uses Pydantic v2 (only).
This doesn't yet handle arrays or OpenMM objects, but that should be tractable and is called out below.
No (explicit) effort is made to ensure floats or ints keep their original type and are not silently coerced.
""" | |
from pint import UnitRegistry, Quantity | |
from typing import Annotated | |
from pydantic import ( | |
WrapValidator, | |
BaseModel, | |
ConfigDict, | |
Field, | |
WrapSerializer, | |
ValidatorFunctionWrapHandler, | |
ValidationInfo, | |
) | |
import json | |
# Module-level Pint registry: used to parse string inputs in the validator and to
# build the example quantities further down.
unit = UnitRegistry()
def _quantity_from_mapping(mapping: dict) -> Quantity:
    """Build a Quantity from the ``{"value": ..., "unit": ...}`` layout the serializer emits."""
    if not isinstance(mapping, dict) or not {"value", "unit"} <= mapping.keys():
        # ValueError (unlike KeyError/TypeError) is reported by pydantic as a validation error
        raise ValueError('Quantity mapping must contain "value" and "unit" keys')

    # Build through the module registry so the result can be combined with quantities
    # parsed from strings; quantities from different registries cannot be operated on together.
    return unit.Quantity(mapping["value"], mapping["unit"])


def quantity_validator(
    value: str | Quantity | dict,
    handler: ValidatorFunctionWrapHandler,
    info: ValidationInfo,
) -> Quantity:
    """Take Quantity-like objects and convert them to Quantity objects.

    In JSON mode the input must be a ``{"value", "unit"}`` mapping, or a string
    containing that mapping encoded as JSON. In Python mode the input may be an
    existing ``Quantity`` (returned unchanged), a string the registry can parse
    (e.g. ``"70 kg"``), or a ``{"value", "unit"}`` dict.

    ``handler`` is part of the wrap-validator signature but is never called;
    this function performs the whole conversion itself.

    Raises:
        ValueError: if the input type or validation mode is unsupported, or a
            mapping is missing the ``"value"``/``"unit"`` keys.
    """
    if info.mode == "json":
        if isinstance(value, str):
            value = json.loads(value)

        # this is coupled to how a Quantity looks in JSON
        return _quantity_from_mapping(value)

    # some more work is needed with arrays, lists, tuples, etc.
    if info.mode != "python":
        # explicit raise instead of assert so the check survives `python -O`
        raise ValueError(f"Unsupported validation mode {info.mode!r}")

    if isinstance(value, Quantity):
        return value
    if isinstance(value, str):
        return unit.Quantity(value)
    if isinstance(value, dict):
        return _quantity_from_mapping(value)

    # here is where special cases, like for OpenMM, would go
    raise ValueError(f"Invalid type {type(value)} for Quantity")
def quantity_json_serializer(
    quantity: Quantity,
    nxt,
) -> dict:
    """Flatten a Quantity into a ``{"value": magnitude, "unit": unit string}`` dict.

    ``nxt`` is accepted to satisfy the serializer signature but is not used.
    """
    # Arrays will need more care than simply taking the bare magnitude.
    magnitude = quantity.m
    unit_name = str(quantity.units)
    return {"value": magnitude, "unit": unit_name}
# Pydantic v2 likes to marry validators and serializers to types with Annotated
# https://docs.pydantic.dev/latest/concepts/validators/#annotated-validators
# Fields annotated with _Quantity are validated by quantity_validator and
# serialized by quantity_json_serializer.
_Quantity = Annotated[
    Quantity,
    WrapValidator(quantity_validator),
    WrapSerializer(quantity_json_serializer),
]
# Simple model using Quantity
class Person(BaseModel):
    # Pydantic treats anything non-stdlib as "arbitrary", so it must be allowed explicitly
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # NOTE: the field must be annotated with our Annotated alias (_Quantity),
    # not bare Quantity, so the validator and serializer come along with it
    mass: _Quantity = Field(...)
# Model nesting another model which uses Quantity
class Roster(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Person.mass carries its own validator/serializer through _Quantity,
    # so the nested models need no extra handling here
    people: dict[str, Person] = Field(default={})
    foo: _Quantity = Field(...)
# Demo / smoke checks: build two people and a roster, then confirm that
# in-memory, dict, and JSON roundtrips all reproduce the original quantities.
susie = Person(mass="70 kg")
bob = Person(mass="100 kg")

assert susie.mass / bob.mass == 0.7

roster = Roster(people={"Susie": susie, "Bob": bob}, foo=unit.Quantity(1.008, "amu"))

# both dict (model_validate) and JSON (model_validate_json) roundtrips work
for person in [bob, susie]:
    for roundtripped in [
        Person.model_validate(person),
        Person.model_validate(person.model_dump()),
        Person.model_validate_json(person.model_dump_json()),
    ]:
        assert str(roundtripped.mass.units) == "kilogram"
        # compare with this person's own magnitude; a membership test against
        # (70, 100) would also pass if the two masses were swapped
        assert roundtripped.mass.m == person.mass.m

for roundtripped in [
    Roster.model_validate(roster),
    Roster.model_validate(roster.model_dump()),
    Roster.model_validate_json(roster.model_dump_json()),
]:
    assert len(roundtripped.people) == 2
    assert roundtripped.foo == unit.Quantity(1.008, "amu")
    assert roundtripped.people["Susie"].mass == unit.Quantity(70, "kilogram")
    assert roundtripped.people["Bob"].mass == unit.Quantity(100, "kilogram")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment