Skip to content

Instantly share code, notes, and snippets.

@mattwthompson
Created November 23, 2020 21:04
Show Gist options
  • Save mattwthompson/3a7d9959d20d7bbb10b06c894c5e5acc to your computer and use it in GitHub Desktop.
Save mattwthompson/3a7d9959d20d7bbb10b06c894c5e5acc to your computer and use it in GitHub Desktop.
In:
```
from typing import Any, Dict, Type, TypeVar, Union
import numpy as np
from pint import Quantity, UnitRegistry
from pydantic import BaseModel
# Use the real flag: False at runtime, True under static type checkers, so
# checkers see ``BaseArray`` as a plain ``np.ndarray`` (hard-coding False
# hid that branch from them).
from typing import TYPE_CHECKING

# Single pint registry used for every Quantity created in this module.
u = UnitRegistry()
class _ArrayMeta(type):
def __getitem__(self, t):
return type("BaseArray", (BaseArray,), {"__dtype__": t})
if TYPE_CHECKING:
    # Static type checkers see a plain ndarray; the runtime class below only
    # exists to hook into pydantic validation.
    BaseArray = np.ndarray
else:

    class BaseArray(np.ndarray, metaclass=_ArrayMeta):
        """
        Pydantic field type for arrays that carry a physical unit.

        Subclasses set ``base_unit`` to a unit string for the module-level
        pint registry ``u``. Validated values are returned as ``u.Quantity``
        objects expressed in that unit. A dtype, and optionally a shape, can
        be requested by subscripting, e.g. ``LengthArray[float]`` or
        ``LengthArray[float, (3,)]``.

        TODO:
        * JSON encoder
        * Can .base_unit be protected?
        * Should this fundamentally be np.ndarray or u.Quantity?
        """

        base_unit = "not implemented"

        @classmethod
        def __get_validators__(cls):
            """Yield the validator callables pydantic runs for this type."""
            yield cls.validate

        @classmethod
        def validate(
            cls, v: Union[int, str, list, tuple, np.ndarray, u.Quantity]
        ) -> u.Quantity:
            """Validate ``v`` and return it as a ``u.Quantity`` in ``base_unit``.

            Accepted inputs:

            * ``list``, ``tuple`` or ``np.ndarray``: values are taken to be
              magnitudes already expressed in ``base_unit``.
            * pint ``Quantity``: converted to ``base_unit`` (pint conversion
              errors propagate unchanged).

            The magnitude is cast to the dtype stored in ``__dtype__`` (set by
            subscripting) and, when a shape was given, reshaped to it.

            Raises:
                TypeError: for ``int``/``str`` inputs and any other
                    unsupported type.
                ValueError: if the magnitude cannot be cast to the requested
                    dtype or reshaped to the requested shape.
            """
            if isinstance(v, (int, str)):
                raise TypeError("not implemented")

            # Plain sequences carry no unit: cast to an array before deciding
            # how to interpret them.
            if isinstance(v, (list, tuple)):
                v = np.asarray(v)

            if isinstance(v, Quantity):
                magnitude = v.to(cls.base_unit).m
            elif isinstance(v, np.ndarray):
                # Unitless arrays are interpreted as already in base_unit.
                magnitude = v
            else:
                # Previously fell through and crashed with UnboundLocalError.
                raise TypeError(
                    f"Cannot validate object of type {type(v).__name__} "
                    f"as {cls.__name__}"
                )

            dtype = getattr(cls, "__dtype__", None)
            if isinstance(dtype, tuple):
                dtype, shape = dtype
            else:
                shape = ()

            try:
                # np.asarray copies only when needed. np.array(..., copy=False)
                # raises under NumPy >= 2 whenever a copy is required (e.g. an
                # int -> float cast). ndmin is unnecessary since any requested
                # shape is applied by reshape.
                result = np.asarray(magnitude, dtype=dtype)
                if shape:
                    result = result.reshape(shape)
            except ValueError as err:
                raise ValueError(
                    "Could not cast {} to NumPy Array!".format(v)
                ) from err
            return u.Quantity(result, cls.base_unit)

        @classmethod
        def __repr__(cls):
            """Return the class followed by its ``base_unit``.

            Being a classmethod, it uses only class-level data, never the
            array contents.
            """
            return str(cls) + " " + cls.base_unit
class LengthArray(BaseArray):
    """Unit-aware array field whose validated values are in angstrom."""

    # Target unit handed to the inherited validator.
    base_unit = "angstrom"
class MassArray(BaseArray):
    """Unit-aware array field whose validated values are in dalton."""

    # Target unit handed to the inherited validator.
    base_unit = "dalton"
class Model(BaseModel):
    """Example model mixing length-valued and mass-valued array fields."""

    # Length arrays: validated and expressed in angstrom.
    p1: LengthArray
    p2: LengthArray
    p3: LengthArray
    # Mass array: validated and expressed in dalton.
    p4: MassArray
# Exercise each supported input form: a plain list, a unitless float array,
# a pint Quantity in another unit (converted to angstrom), and an int list.
m = Model(
    p1=[0, 4.0],
    p2=np.array([1, 2, 3, 4], dtype=np.float64),
    p3=u.Quantity([-1, 1], "nanometer"),
    p4=[12, 1, 1, 1, 1],
)

for field_name in ("p1", "p2", "p3", "p4"):
    print(getattr(m, field_name))
```
Out:
```
[0.0 4.0] angstrom
[1.0 2.0 3.0 4.0] angstrom
[-10.0 10.0] angstrom
[12 1 1 1 1] dalton
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment