Skip to content

Instantly share code, notes, and snippets.

@kylebarron
Forked from danielhfrank/pydantic.py
Created October 4, 2021 17:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kylebarron/b331a14fdde42769471924ec8c4b93f0 to your computer and use it in GitHub Desktop.
Save kylebarron/b331a14fdde42769471924ec8c4b93f0 to your computer and use it in GitHub Desktop.
Pydantic with Numpy
from typing import Generic, TypeVar
import numpy as np
from pydantic.fields import ModelField
# Encoder table intended for a pydantic ``Config.json_encoders`` setting:
# numpy arrays are emitted as (possibly nested) plain Python lists.
JSON_ENCODERS = {
    np.ndarray: lambda array: array.tolist(),
}

# Type parameter of ``TypedArray``; it carries the declared dtype.
DType = TypeVar("DType")
class TypedArray(np.ndarray, Generic[DType]):
    """Numpy array annotation whose dtype is enforced during pydantic validation.

    This can be used in place of a numpy array, but when used in a pydantic
    BaseModel or with pydantic.validate_arguments, the value is converted with
    ``np.array`` and its dtype is *coerced* at runtime to the declared type,
    e.g. ``TypedArray[Literal['float32']]``.

    A bare ``TypedArray`` annotation (no type parameter) still converts the
    value to an array but lets numpy infer the dtype.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validator callables pydantic should run for this type."""
        yield cls.validate

    @classmethod
    def validate(cls, val, field: ModelField):
        """Convert ``val`` into a numpy array of the declared dtype.

        Args:
            val: Any value accepted by ``np.array`` (sequence, scalar, array).
            field: The pydantic field being validated; its first sub-field
                holds the type parameter given to ``TypedArray[...]``.

        Returns:
            A plain ``np.ndarray`` built from ``val``.

        Raises:
            ValueError, TypeError: Propagated from numpy when ``val`` cannot be
                converted or the declared dtype is not understood.
        """
        if not field.sub_fields:
            # No type parameter was provided, so there is no dtype to enforce.
            return np.array(val)

        declared = field.sub_fields[0].type_
        # ``Literal['float32']`` keeps the dtype name in ``__args__``; a type
        # parameter without ``__args__`` (e.g. ``float``) is handed to numpy
        # as the dtype directly instead of failing with AttributeError.
        type_args = getattr(declared, "__args__", None)
        actual_dtype = type_args[0] if type_args else declared

        # If numpy cannot create an array with the requested dtype, an error
        # will be raised and bubbled up to the caller.
        return np.array(val, dtype=actual_dtype)
from typing_extensions import Literal
import numpy as np
import pydantic
import pytest
from .pydantic import TypedArray, JSON_ENCODERS
class Model(pydantic.BaseModel):
    """Test model with a single array field declared as ``float32``."""

    # Declared with the 'float32' dtype parameter of TypedArray.
    x: TypedArray[Literal["float32"]]

    class Config:
        # Encoder table used when the model is dumped to JSON.
        json_encoders = JSON_ENCODERS
class InvalidModel(pydantic.BaseModel):
    """Test model whose array field declares a dtype name numpy does not know."""

    # 'asdfasdf' is deliberately not a valid numpy dtype string.
    x: TypedArray[Literal["asdfasdf"]]
def test_array():
    """A list input becomes a float32 ndarray; a non-numeric string is rejected."""
    coerced = Model(x=[1, 2]).x
    assert isinstance(coerced, np.ndarray)
    assert coerced.dtype == np.dtype("float32")

    # A string that cannot be converted to float32 is expected to fail validation.
    with pytest.raises(pydantic.error_wrappers.ValidationError):
        Model(x="asdfa")
def test_invalid():
    """Instantiating a model with an unknown declared dtype raises ValidationError."""
    expected_error = pydantic.error_wrappers.ValidationError
    with pytest.raises(expected_error):
        InvalidModel(x="boom")
def test_serde():
    """The array field serializes to JSON as a list of floats."""
    serialized = Model(x=[1, 2]).json()
    assert serialized == '{"x": [1.0, 2.0]}'
# Using validate_arguments here will _coerce_ an array into the correct dtype
# (float32, per the annotation on ``arr``) before the function body runs.
@pydantic.validate_arguments
def square(arr: TypedArray[Literal['float32']]) -> np.ndarray:
    """Return the element-wise square of ``arr``.

    The return annotation is ``np.ndarray`` (the array type); ``np.array`` is
    a factory function and not a valid type annotation.
    """
    return arr ** 2
def test_validation_decorator():
    """An int32 array passed through ``square`` yields a float32 result."""
    int_input = np.array([1, 2, 3], dtype="int32")
    squared = square(int_input)
    assert squared.dtype == np.dtype("float32")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment