Skip to content

Instantly share code, notes, and snippets.

@bendichter
Created January 8, 2024 20:04
Show Gist options
  • Save bendichter/257051674f8d6ea97fbb732f778dda35 to your computer and use it in GitHub Desktop.
constrainedarray
from typing import Tuple, Type, Any, Union, Optional
from pydantic import BaseModel
import numpy as np
def get_shape(data):
    """
    Return the shape of an array-like object or of a nested Python list.

    Anything that exposes a ``shape`` attribute (NumPy arrays, h5py datasets, Zarr arrays, ...)
    has that attribute returned unchanged; plain Python lists are measured by walking their
    nesting.

    Parameters
    ----------
    data : {np.ndarray, h5py.Dataset, zarr.core.Array, list}
        The object whose shape is wanted.

    Returns
    -------
    tuple
        The shape of ``data``.

    Raises
    ------
    TypeError
        If ``data`` has no ``shape`` attribute and is not a list.

    Examples
    --------
    >>> import numpy as np
    >>> get_shape(np.array([[1, 2], [3, 4]]))
    (2, 2)
    >>> get_shape([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    (2, 2, 2)
    """
    if not hasattr(data, "shape"):
        # Without a ``shape`` attribute, only lists can be measured (by hand).
        if isinstance(data, list):
            return _get_list_shape(data)
        raise TypeError("Unsupported data type for getting shape.")
    return data.shape
def _get_list_shape(lst):
"""
Get the shape of a nested list of arbitrary depth.
Parameters
----------
lst : list
The nested list whose shape is to be determined.
Returns
-------
tuple
A tuple representing the shape of the nested list.
Notes
-----
This function assumes that the nested list is regular, i.e., each sub-list at each level has the same length.
"""
if not lst:
return ()
if isinstance(lst[0], list):
return (len(lst),) + _get_list_shape(lst[0])
else:
return (len(lst),)
class _hybridmethod:
    """
    Descriptor that binds the wrapped function to whatever it is accessed through.

    Accessed through an instance, the first argument is the instance (like a regular method).
    Accessed through the class, the first argument is the class (like a ``classmethod``).
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, objtype=None):
        target = objtype if obj is None else obj
        return self.func.__get__(target)


class ConstrainedArray:
    """
    A custom Pydantic type (``__get_validators__`` hook) for validating arrays with specific
    shapes and data types.

    Attributes
    ----------
    shapes : Optional[Union[Tuple[Optional[int], ...], Tuple[Tuple[Optional[int], ...], ...]]]
        Either a single expected shape, e.g. ``(2, None)``, or a sequence of alternative shapes,
        e.g. ``[(2, 2), (3, 3)]``. ``None`` in any position within a shape allows any size for
        that dimension. The number of dimensions must match exactly. ``None`` (the default)
        disables the shape check.
    dtype : Optional[Union[Type, Tuple[Type, ...]]]
        The allowed data type(s) for the array's elements. Each is matched with
        ``np.issubdtype``, so abstract types such as ``np.floating`` are accepted. Can be a
        single data type or a tuple of data types. ``None`` (the default) disables the check.

    Examples
    --------
    >>> class Model(BaseModel):
    ...     array: ConstrainedArray(shapes=[(2, 2), (3, 3)], dtype=(np.float32, np.float64))
    ...
    >>> model = Model(array=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
    """

    # Class-level defaults mean "no constraint". Subclasses may override these instead of
    # passing them to ``__init__``; instance attributes set by ``__init__`` take precedence.
    shapes = None
    dtype = None

    def __init__(
        self,
        shapes: Optional[Union[Tuple[Optional[int], ...], Tuple[Tuple[Optional[int], ...], ...]]] = None,
        dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
    ):
        self.shapes = shapes
        self.dtype = dtype

    @_hybridmethod
    def __get_validators__(self):
        """
        Yield the validator Pydantic should run for this type.

        ``self`` is the instance when ``ConstrainedArray(...)`` is used as an annotation, or the
        class when the class (or a subclass) itself is used.
        """
        # Bound to the instance (not the class) so the constraints given to ``__init__`` are
        # the ones that get checked.
        yield self.validate

    @_hybridmethod
    def validate(self, value: Any) -> Any:
        """
        Check ``value`` against the configured dtype and shape constraints.

        Parameters
        ----------
        value : Any
            An object with ``dtype``/``shape`` attributes (e.g. a NumPy array) or a nested list.

        Returns
        -------
        Any
            ``value`` unchanged, if it satisfies every constraint.

        Raises
        ------
        TypeError
            If the element dtype is not one of the allowed dtypes, or ``value``'s shape cannot
            be determined.
        ValueError
            If the shape matches none of the allowed shapes.
        """
        if self.dtype is not None:
            self._validate_dtype(value, self.dtype)
        if self.shapes is not None:
            self._validate_shape(value, self.shapes)
        return value

    @classmethod
    def _validate_dtype(cls, value: Any, allowed: Union[Type, Tuple[Type, ...]]) -> None:
        """Raise ``TypeError`` unless ``value``'s dtype is a sub-dtype of one of ``allowed``."""
        # Arrays and datasets expose ``dtype``; for plain (nested) lists, which ``get_shape``
        # also accepts, fall back to NumPy's dtype inference.
        actual = value.dtype if hasattr(value, "dtype") else np.asarray(value).dtype
        if isinstance(allowed, tuple):
            if not any(np.issubdtype(actual, dt) for dt in allowed):
                allowed_dtypes = ', '.join(cls._dtype_name(dt) for dt in allowed)
                raise TypeError(f'Expected array with dtype(s) {allowed_dtypes}, got {actual.name}')
        elif not np.issubdtype(actual, allowed):
            raise TypeError(f'Expected array with dtype {cls._dtype_name(allowed)}, got {actual.name}')

    @staticmethod
    def _dtype_name(dt: Any) -> str:
        """Return a readable name for a dtype spec (type, ``np.dtype`` instance, or string)."""
        # ``np.dtype`` instances and dtype strings have no ``__name__``.
        return getattr(dt, "__name__", str(dt))

    @classmethod
    def _validate_shape(cls, value: Any, shapes: Any) -> None:
        """Raise ``ValueError`` unless ``value``'s shape matches ``shapes``."""
        this_shape = tuple(get_shape(value))
        if cls._is_single_shape(shapes):
            if not cls._check_single_shape(this_shape, shapes):
                raise ValueError(
                    f'The provided array shape: {this_shape} does not match the expected shape: {shapes}'
                )
        elif not cls._check_multi_shape(this_shape, shapes):
            allowed_shapes = ', '.join(str(s) for s in shapes)
            raise ValueError(
                f'None of the expected shapes {allowed_shapes} match the provided array shape {this_shape}'
            )

    @staticmethod
    def _is_single_shape(shapes: Any) -> bool:
        """Return True if ``shapes`` is one shape rather than a collection of shapes."""
        # A single shape holds ints/None, while a collection of shapes holds tuples/lists.
        # ``()`` is treated as the single shape of a 0-d array.
        return len(shapes) == 0 or not isinstance(shapes[0], (tuple, list))

    @classmethod
    def _check_multi_shape(cls, shape: Tuple[int, ...], list_of_allowable_shapes: Tuple[Tuple[Union[int, None], ...]]):
        """Return True if ``shape`` matches at least one of ``list_of_allowable_shapes``."""
        return any(cls._check_single_shape(shape, allowed_shape) for allowed_shape in list_of_allowable_shapes)

    @classmethod
    def _check_single_shape(cls, shape: Tuple[int, ...], allowed_shape: Tuple[Union[int, None], ...]):
        """Return True if ``shape`` has the same rank as ``allowed_shape`` and agrees on every
        dimension that is not ``None``."""
        # The rank check matters: ``zip`` alone would silently ignore extra dimensions.
        return len(shape) == len(allowed_shape) and all(
            expected is None or actual == expected for actual, expected in zip(shape, allowed_shape)
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment