-
-
Save mattwthompson/3a7d9959d20d7bbb10b06c894c5e5acc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
In:
```python
from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union

import numpy as np
from pint import Quantity, UnitRegistry
from pydantic import BaseModel
# ``TYPE_CHECKING`` is imported from ``typing`` (False at runtime, True for
# static type checkers) instead of being redefined locally as a constant.

# Module-level pint unit registry; ``u.Quantity`` builds unit-carrying values.
u = UnitRegistry()
class _ArrayMeta(type): | |
def __getitem__(self, t): | |
return type("BaseArray", (BaseArray,), {"__dtype__": t}) | |
if TYPE_CHECKING:
    # Static type checkers see a plain ndarray alias; the runtime class below
    # is what is actually used for validation.
    BaseArray = np.ndarray
else:

    class BaseArray(np.ndarray, metaclass=_ArrayMeta):
        """Field type for a NumPy array carrying a fixed physical unit.

        Subclasses set ``base_unit`` to a unit string for the module-level
        pint registry ``u``. ``validate`` converts its input into that unit
        and returns a pint Quantity, not an instance of this class.

        If the class has a ``__dtype__`` attribute (set by subscription via
        ``_ArrayMeta``), it is either a dtype or a ``(dtype, shape)`` tuple.
        The validated magnitude is cast/reshaped accordingly.

        TODO:
        * JSON encoder
        * Can .base_unit be protected?
        * Should this fundamentally be np.ndarray or u.Quantity?
        """

        base_unit = "not implemented"

        @classmethod
        def __get_validators__(cls):
            # NOTE(review): this is the pydantic v1 custom-type hook; confirm
            # the pinned pydantic version still calls it.
            yield cls.validate

        @classmethod
        def validate(
            cls, v: Union[int, str, list, tuple, np.ndarray, Quantity]
        ) -> Quantity:
            """Convert ``v`` into a Quantity expressed in ``cls.base_unit``.

            Args:
                v: A list/tuple of numbers or a NumPy array (taken to already
                    be in ``cls.base_unit``), or a pint Quantity (converted to
                    ``cls.base_unit``).

            Returns:
                ``u.Quantity`` whose magnitude is an array cast to the
                subscripted dtype (if any) and reshaped to the subscripted
                shape (if any).

            Raises:
                TypeError: ``v`` is an ``int``/``str`` or any other
                    unsupported type (e.g. a float or NumPy scalar).
                ValueError: the magnitude cannot be cast to the requested
                    dtype or reshaped to the requested shape.
            """
            if isinstance(v, (int, str)):
                raise TypeError("not implemented")
            # Sequences are cast into an array before the unit is attached.
            if isinstance(v, (list, tuple)):
                v = np.asarray(v)

            if isinstance(v, Quantity):
                # Unit-aware input: convert into this class's base unit.
                magnitude = v.to(cls.base_unit).m
            elif isinstance(v, np.ndarray):
                # Bare array: interpreted as already being in base_unit.
                magnitude = v
            else:
                # Reject anything else explicitly instead of failing later
                # on an unbound local.
                raise TypeError(
                    "Cannot interpret {} as a {} array".format(
                        type(v).__name__, cls.base_unit
                    )
                )

            dtype = getattr(cls, "__dtype__", None)
            if isinstance(dtype, tuple):
                dtype, shape = dtype
            else:
                shape = ()

            try:
                # np.asarray copies only when needed. ``np.array(copy=False)``
                # raises under NumPy 2 whenever a dtype cast needs a copy.
                result = np.asarray(magnitude, dtype=dtype)
                if shape:
                    result = result.reshape(shape)
            except ValueError as err:
                raise ValueError(
                    "Could not cast {} to NumPy Array!".format(v)
                ) from err
            return u.Quantity(result, cls.base_unit)

        @classmethod
        def __repr__(cls):
            # NOTE(review): as a classmethod defined on this class, this only
            # affects repr() of instances (which ``validate`` never creates);
            # repr() of the class itself is governed by the metaclass.
            # Confirm whether this was meant to live on ``_ArrayMeta``.
            return str(cls) + " " + cls.base_unit
class LengthArray(BaseArray):
    """Array field whose validated values are Quantities in angstrom."""

    # Target unit for conversion in ``BaseArray.validate``.
    base_unit = "angstrom"
class MassArray(BaseArray):
    """Array field whose validated values are Quantities in dalton."""

    # Target unit for conversion in ``BaseArray.validate``.
    base_unit = "dalton"
class Model(BaseModel):
    """Demo model combining unit-carrying array fields.

    Each field is validated by its array type, so stored values are pint
    Quantities in that type's base unit.
    """

    # Length arrays (angstrom).
    p1: LengthArray
    p2: LengthArray
    p3: LengthArray
    # Mass array (dalton).
    p4: MassArray
# Exercise each accepted input form: plain list, bare ndarray, and a pint
# Quantity in a different (but compatible) unit.
m = Model(
    p1=[0, 4.0],  # list -> taken as angstrom
    p2=np.array([1, 2, 3, 4], dtype=np.float64),  # ndarray -> angstrom
    p3=u.Quantity([-1, 1], "nanometer"),  # converted nm -> angstrom
    p4=[12, 1, 1, 1, 1],  # list -> taken as dalton
)

# One validated field per output line.
print(m.p1, m.p2, m.p3, m.p4, sep="\n")
```
Out:
```
[0.0 4.0] angstrom
[1.0 2.0 3.0 4.0] angstrom
[-10.0 10.0] angstrom
[12 1 1 1 1] dalton
```
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment