Skip to content

Instantly share code, notes, and snippets.

@dmontagu
Last active January 28, 2020 13:05
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dmontagu/69baaa9459bbfc471136ffd5d38bc6b6 to your computer and use it in GitHub Desktop.
Save dmontagu/69baaa9459bbfc471136ffd5d38bc6b6 to your computer and use it in GitHub Desktop.
Proof of concept for tracking NumPy array data types and shapes
from typing import TYPE_CHECKING, Any, Generic, Tuple, Type, TypeVar, Union
import numpy as np
from typing_extensions import Literal
T = TypeVar("T")
S = TypeVar("S")


class Dimension(Generic[T]):
    """Phantom marker describing the size of a single array axis.

    Type checkers see an ordinary one-parameter generic, so for example
    ``Dimension[Literal[1]]`` and ``Dimension[Literal[2]]`` are distinct static
    types. At runtime, subscription is short-circuited: ``Dimension[...]``
    evaluates to the bare ``Dimension`` class whatever the argument is.
    """

    if not TYPE_CHECKING:

        @classmethod
        def __class_getitem__(cls: Type[S], item: Any) -> Type[S]:
            # The subscript only matters to static analysis; discard it and
            # hand back the unparameterized class itself.
            return cls
# Type variables for the individual axes of an array's shape. Each is bound
# to ``Dimension[Any]``, so only ``Dimension``-based markers may fill a shape
# slot. They are written one per line on purpose: type checkers require every
# TypeVar to be assigned to a variable whose name equals its string argument.
D1 = TypeVar("D1", bound=Dimension[Any])
D2 = TypeVar("D2", bound=Dimension[Any])
D3 = TypeVar("D3", bound=Dimension[Any])
D4 = TypeVar("D4", bound=Dimension[Any])
D5 = TypeVar("D5", bound=Dimension[Any])
if TYPE_CHECKING:
    # Unrelated nominal classes, so a type checker treats each element type
    # as incompatible with the others. With the builtins, by contrast, an
    # ``int`` would be accepted wherever a ``float`` is expected.
    class NPFloat:
        pass

    class NPInt:
        pass

    class NPUInt8:
        pass

else:
    # ``np.int`` / ``np.float`` were only aliases of the builtins ``int`` /
    # ``float``. They were deprecated in NumPy 1.20 and removed in 1.24, where
    # merely accessing them raises AttributeError at import time. Binding the
    # builtins directly yields the same objects on old NumPy and works on new.
    NPInt = int
    NPUInt8 = np.uint8
    NPFloat = float

# Element types an array may be parameterized with (a constrained TypeVar:
# each use resolves to exactly one of these, never a mix).
DType = TypeVar("DType", NPInt, NPUInt8, NPFloat, bool)
if TYPE_CHECKING:

    class _TypedArray(Generic[DType]):
        """Static-only stand-in for ``numpy.ndarray``, parameterized by dtype.

        Type checkers see this minimal stub; at runtime the name is bound to
        ``numpy.ndarray`` itself (see the ``else`` branch below). The
        permissive ``__getattr__``/``__getitem__`` signatures make any other
        attribute access or index expression on an array type-check as
        ``Any``. ``Generic`` already provides ``__class_getitem__`` here, so no
        override is needed (this class never exists at runtime).
        """

        def __init__(
            self,
            shape: Any,
            # PEP 484 forbids implicit Optional: a ``None`` default must be
            # spelled out in the annotation.
            dtype: Union[Type[DType], None] = None,
            buffer: Any = None,
            offset: int = 0,
            strides: Any = None,
            order: Any = None,
        ) -> None:
            ...

        def __iter__(self) -> Any:
            ...

        @property
        def shape(self) -> Tuple[int, ...]:
            ...

        def __getitem__(self, item: Any) -> Any:
            ...

        def __getattr__(self, item: Any) -> Any:
            ...

        def __setattr__(self, item: Any, value: Any) -> None:
            ...

else:
    _TypedArray = np.ndarray
# Dimension markers used to fill shape slots. ``Zero``..``Three`` pin an axis
# to a literal size; ``N`` and ``M`` stand for a size not known statically.
# NOTE: ``N`` and ``M`` are the very same type (``Dimension[int]``), so a type
# checker cannot tell them apart -- the two names only help human readers.
# At runtime ``Dimension[...]`` returns the bare ``Dimension`` class, so all
# six names refer to that one class.
Zero = Dimension[Literal[0]]
One = Dimension[Literal[1]]
Two = Dimension[Literal[2]]
Three = Dimension[Literal[3]]
N = Dimension[int]
M = Dimension[int]
class Array1D(_TypedArray, Generic[DType, D1]):  # type: ignore[type-arg]
    """Rank-1 array tagged with an element dtype and one dimension marker.

    At runtime ``_TypedArray`` is ``numpy.ndarray``, so this is a plain
    ndarray subclass. The methods below are declared for static type checkers
    only: they state that an ordering comparison against another array or a
    scalar yields a ``bool`` array that keeps the dimension ``D1``.
    """

    if TYPE_CHECKING:

        def __lt__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array1D[bool, D1]":
            ...

        def __le__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array1D[bool, D1]":
            ...

        def __gt__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array1D[bool, D1]":
            ...

        def __ge__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array1D[bool, D1]":
            ...
class Array2D(_TypedArray, Generic[DType, D1, D2]):  # type: ignore[type-arg]
    """Rank-2 array tagged with an element dtype and two dimension markers.

    At runtime ``_TypedArray`` is ``numpy.ndarray``, so this is a plain
    ndarray subclass. The methods below exist for static type checkers only:
    an ordering comparison against another array or a scalar produces a
    ``bool`` array that keeps both dimensions ``D1`` and ``D2``.
    """

    if TYPE_CHECKING:

        def __lt__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array2D[bool, D1, D2]":
            ...

        def __le__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array2D[bool, D1, D2]":
            ...

        def __gt__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array2D[bool, D1, D2]":
            ...

        def __ge__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array2D[bool, D1, D2]":
            ...
class Array3D(_TypedArray, Generic[DType, D1, D2, D3]):  # type: ignore[type-arg]
    """Rank-3 array tagged with an element dtype and three dimension markers.

    At runtime this is a plain ``numpy.ndarray`` subclass. The comparison
    stubs match those declared on ``Array1D``/``Array2D``. The ``_TypedArray``
    stub this class inherits declares no ordering operators, so without them
    a type checker has no signature for ``arr >= 0`` on a 3-D array.
    """

    if TYPE_CHECKING:

        def __ge__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array3D[bool, D1, D2, D3]":
            ...

        def __le__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array3D[bool, D1, D2, D3]":
            ...

        def __gt__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array3D[bool, D1, D2, D3]":
            ...

        def __lt__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array3D[bool, D1, D2, D3]":
            ...
class Array4D(_TypedArray, Generic[DType, D1, D2, D3, D4]):  # type: ignore[type-arg]
    """Rank-4 array tagged with an element dtype and four dimension markers.

    At runtime this is a plain ``numpy.ndarray`` subclass. The comparison
    stubs match those declared on ``Array1D``/``Array2D``. The ``_TypedArray``
    stub this class inherits declares no ordering operators, so without them
    a type checker has no signature for ``arr >= 0`` on a 4-D array.
    """

    if TYPE_CHECKING:

        def __ge__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array4D[bool, D1, D2, D3, D4]":
            ...

        def __le__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array4D[bool, D1, D2, D3, D4]":
            ...

        def __gt__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array4D[bool, D1, D2, D3, D4]":
            ...

        def __lt__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array4D[bool, D1, D2, D3, D4]":
            ...
class Array5D(_TypedArray, Generic[DType, D1, D2, D3, D4, D5]):  # type: ignore[type-arg]
    """Rank-5 array tagged with an element dtype and five dimension markers.

    At runtime this is a plain ``numpy.ndarray`` subclass. The comparison
    stubs match those declared on ``Array1D``/``Array2D``. The ``_TypedArray``
    stub this class inherits declares no ordering operators, so without them
    a type checker has no signature for ``arr >= 0`` on a 5-D array.
    """

    if TYPE_CHECKING:

        def __ge__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array5D[bool, D1, D2, D3, D4, D5]":
            ...

        def __le__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array5D[bool, D1, D2, D3, D4, D5]":
            ...

        def __gt__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array5D[bool, D1, D2, D3, D4, D5]":
            ...

        def __lt__(
            self, other: Union["_TypedArray[Any]", float]
        ) -> "Array5D[bool, D1, D2, D3, D4, D5]":
            ...
# Rank-specific aliases with every type parameter (the dtype and each
# dimension) set to ``Any``, i.e. "an array of this rank; nothing else is
# checked".
AnyArray1D = Array1D[Any, Any]
AnyArray2D = Array2D[Any, Any, Any]
AnyArray3D = Array3D[Any, Any, Any, Any]
AnyArray4D = Array4D[Any, Any, Any, Any, Any]
AnyArray5D = Array5D[Any, Any, Any, Any, Any, Any]
"""
# Usage example:
def f(x: Array2D[NPUInt8, D1, D2]) -> Array1D[NPFloat, D1]:
y = x[0] / 255
return y
def g(x: Array1D[NPFloat, D1]) -> Array2D[NPFloat, One, D1]:
return x[np.newaxis, ...]
def h(x: Array2D[NPFloat, One, One]) -> float:
return float(x[0][0])
a: Array1D[NPInt, One] = Array1D(np.ones(shape=(1,), dtype=NPInt))
b: Array2D[NPFloat, M, N] = Array2D(np.ones(shape=(50, 10), dtype=NPFloat))
c: Array2D[NPUInt8, Two, N] = Array2D(np.ones(shape=(1, 10), dtype=NPUInt8))
d: Array2D[NPUInt8, One, N] = Array2D(np.ones(shape=(1, 10), dtype=NPUInt8))
f(a)
# Argument 1 to "f" has incompatible type "Array1D[NPInt, Dimension[Literal[1]]]";
# expected "Array2D[NPUInt8, <nothing>, <nothing>]" [arg-type]
f(b)
# Argument 1 to "f" has incompatible type "Array2D[NPFloat, Dimension[int], Dimension[int]]";
# expected "Array2D[NPUInt8, Dimension[int], Dimension[int]]"
f(c) # OK
f(d) # OK
g(f(c)) # OK
g(f(d)) # OK
h(g(f(c)))
# Argument 1 to "h" has incompatible type "Array2D[NPFloat, Dimension[Literal[1]], Dimension[Literal[2]]]";
# expected "Array2D[NPFloat, Dimension[Literal[1]], Dimension[Literal[1]]]"
h(g(f(d))) # OK
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment