Skip to content

Instantly share code, notes, and snippets.

@mattwthompson
Created June 5, 2024 21:55
Show Gist options
  • Save mattwthompson/8375acaaf7e1eedaf1dc6f889cf1df9a to your computer and use it in GitHub Desktop.
Save mattwthompson/8375acaaf7e1eedaf1dc6f889cf1df9a to your computer and use it in GitHub Desktop.
pydantic v2 + Pint
"""
Simple, self-contained example showing how to get Pydantic and Pint to work together.
This requires Pydantic v2; Pydantic v1 is not supported.
Arrays and OpenMM objects are not handled yet; supporting them should be tractable, and the
places where that work would go are marked with comments below.
No explicit effort is made to preserve whether a magnitude is a float or an int, so one may
be coerced into the other.
"""
import json
from typing import Annotated

from pint import Quantity, UnitRegistry
from pint.errors import PintError
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    ValidationInfo,
    ValidatorFunctionWrapHandler,
    WrapSerializer,
    WrapValidator,
)
# Module-wide unit registry; quantity strings are parsed with ``unit.Quantity``.
unit: UnitRegistry = UnitRegistry()
def _parse_quantity(text: str) -> Quantity:
    """Parse a string such as ``"70 kg"`` with the module registry ``unit``."""
    try:
        return unit.Quantity(text)
    except PintError as err:
        # pydantic only turns ValueError/AssertionError into a ValidationError
        raise ValueError(f"Could not parse {text!r} as a Quantity: {err}") from err


def _quantity_from_mapping(mapping: dict) -> Quantity:
    """Build a ``unit``-registry Quantity from a ``{"value": ..., "unit": ...}`` mapping."""
    # this is coupled to how quantity_json_serializer writes a Quantity
    if "value" not in mapping or "unit" not in mapping:
        raise ValueError(
            f"Quantity mapping needs 'value' and 'unit' keys, got keys {list(mapping)}"
        )
    try:
        # Use the module registry, NOT the bare pint.Quantity (which is bound to
        # pint's application registry), so every validated Quantity is compatible
        return unit.Quantity(mapping["value"], mapping["unit"])
    except (PintError, TypeError) as err:
        raise ValueError(f"Invalid Quantity mapping {mapping!r}: {err}") from err


def _decode_json_object(text: str) -> dict | None:
    """Return *text* decoded as a JSON object, or None if it is not one."""
    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return None
    return decoded if isinstance(decoded, dict) else None


def quantity_validator(
    value: str | Quantity | dict,
    handler: ValidatorFunctionWrapHandler,
    info: ValidationInfo,
) -> Quantity:
    """Take Quantity-like objects and convert them to Quantity objects.

    Accepted inputs:
      * an existing ``Quantity`` (returned unchanged),
      * a string such as ``"70 kg"``, parsed by the module registry ``unit``,
      * a mapping ``{"value": ..., "unit": ...}`` as written by
        ``quantity_json_serializer``,
      * in JSON mode, additionally a string containing that mapping JSON-encoded.

    Every Quantity constructed here comes from the module-level ``unit``
    registry. That lets values arriving through different input paths be
    combined arithmetically, because pint refuses to mix registries.

    ``handler`` is required by the ``WrapValidator`` signature but is not
    called.

    Raises:
        ValueError: for unsupported input types, mappings missing the
            ``"value"``/``"unit"`` keys, or strings/units pint cannot parse.
            Pydantic converts this into a ``ValidationError``.
    """
    if isinstance(value, Quantity):
        return value
    if isinstance(value, str):
        if info.mode == "json":
            # A string in JSON input may hold a JSON-encoded {"value", "unit"} object
            decoded = _decode_json_object(value)
            if decoded is not None:
                return _quantity_from_mapping(decoded)
        # Python, "strings" and plain JSON strings are all quantity expressions
        return _parse_quantity(value)
    if isinstance(value, dict):
        return _quantity_from_mapping(value)
    # some more work is needed with arrays, lists, tuples, etc.
    # here is where special cases, like for OpenMM, would go
    raise ValueError(f"Invalid type {type(value)} for Quantity")
def quantity_json_serializer(
    quantity: Quantity,
    nxt,
) -> dict:
    """Serialize a Quantity as a ``{"value": magnitude, "unit": units}`` mapping.

    ``nxt`` (the next serializer in pydantic's wrap chain) is accepted to match
    the ``WrapSerializer`` signature but is never called.
    """
    # Arrays still need more work: the bare magnitude (Quantity.m) is emitted
    # as-is rather than being converted to something array-friendly.
    magnitude = quantity.m
    units_label = str(quantity.units)
    return dict(value=magnitude, unit=units_label)
# Pydantic v2 attaches custom validation and serialization to a type through
# Annotated metadata, see
# https://docs.pydantic.dev/latest/concepts/validators/#annotated-validators
_Quantity = Annotated[
    Quantity,
    WrapValidator(func=quantity_validator),
    WrapSerializer(func=quantity_json_serializer, when_used="always"),
]
# Simple model using Quantity
class Person(BaseModel):
    # Quantity is not a type pydantic has a schema for, so it is an
    # "arbitrary" type that has to be explicitly allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # NOTE: annotate with the _Quantity alias (not bare Quantity) so the custom
    # validator and serializer are attached to this field; no default => required.
    mass: _Quantity
# Model nesting another model which uses Quantity
class Roster(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Each nested Person.mass is handled by Person's own _Quantity annotation,
    # so nothing extra is needed here for the nested quantities.
    people: dict[str, Person] = Field(default={})
    foo: _Quantity
# Demonstration: build models from quantity strings, then check round-trips.
susie = Person(mass="70 kg")
bob = Person(mass="100 kg")
# Both masses were parsed from strings by the same registry, so they can be divided
assert susie.mass / bob.mass == 0.7
roster = Roster(people={"Susie": susie, "Bob": bob}, foo=unit.Quantity(1.008, "amu"))

# both dict (model_validate) and JSON (model_validate_json) roundtrips work
for person in [bob, susie]:
    for roundtripped in [
        Person.model_validate(person),
        Person.model_validate(person.model_dump()),
        Person.model_validate_json(person.model_dump_json()),
    ]:
        assert str(roundtripped.mass.units) == "kilogram"
        # Compare with this person's own mass: a membership test against
        # (70, 100) would not notice Bob's mass coming back as Susie's.
        assert roundtripped.mass == person.mass

for roundtripped in [
    Roster.model_validate(roster),
    Roster.model_validate(roster.model_dump()),
    Roster.model_validate_json(roster.model_dump_json()),
]:
    assert len(roundtripped.people) == 2
    assert roundtripped.foo == unit.Quantity(1.008, "amu")
    assert roundtripped.people["Susie"].mass == unit.Quantity(70, "kilogram")
    assert roundtripped.people["Bob"].mass == unit.Quantity(100, "kilogram")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment