"""Numpy-like tensor implementation written in pure Python."""

from __future__ import annotations
import itertools
import math
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import partial, reduce
from operator import mul
from types import EllipsisType
from typing import Generic, Type, TypeAlias, TypeVar, Union, cast, overload
Numeric: TypeAlias = bool | int | float | complex
DType = TypeVar("DType", bool, int, float, complex)
T_DType = TypeVar("T_DType", bool, int, float, complex)
class Tensor(Generic[DType]):
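    """A minimal, immutable, NumPy-like tensor.

    Elements are stored as a flat tuple in row-major (C) order alongside a
    shape tuple; every operation returns a new tensor rather than mutating.
    """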
def __init__(
self,
data: DType | Sequence[DType],
shape: Sequence[int] | None = None,
) -> None:
self._data: tuple[DType, ...] = tuple(data) if isinstance(data, Sequence) else (data,)
self._shape: tuple[int, ...] = (
tuple(shape) if shape is not None else (len(data),) if isinstance(data, Sequence) else ()
)
if not self.shape:
if self.data and len(self.data) != 1:
raise ValueError("shape must be specified if data has more than one element")
if any(k < 0 for k in self.shape):
raise ValueError("shape must be non-negative")
else:
if len(self.data) != reduce(mul, self.shape, 1):
raise ValueError("data size does not match shape")
@property
def data(self) -> tuple[DType, ...]:
return self._data
@property
def shape(self) -> tuple[int, ...]:
return self._shape
@property
def ndim(self) -> int:
        return len(self.shape)
@property
def size(self) -> int:
return len(self.data)
@property
def value(self) -> DType:
if self.size != 1:
raise ValueError("tensor must have exactly one element")
return self.data[0]
@staticmethod
def get_nested_index(flattened_index: int, shape: Sequence[int]) -> tuple[int, ...]:
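        """Convert a flat row-major index into a per-dimension index tuple.

        For example, flattened index 5 in shape (2, 3) maps to (1, 2).
        """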
index = flattened_index
multi_index = []
for dim_size in reversed(shape):
multi_index.append(index % dim_size)
index //= dim_size
return tuple(reversed(multi_index))
@staticmethod
def get_flattened_index(multi_index: Sequence[int], shape: Sequence[int]) -> int:
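        """Convert a per-dimension index tuple into a flat row-major index.

        The inverse of get_nested_index: (1, 2) in shape (2, 3) maps to 5.
        """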
return sum(dim_index * reduce(mul, shape[i + 1 :], 1) for i, dim_index in enumerate(multi_index))
@staticmethod
def normalize_index(
index: tuple[int | slice | Sequence[int] | EllipsisType | None, ...], shape: Sequence[int]
) -> tuple[int | slice | Sequence[int] | None, ...]:
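        """Expand an index tuple so that it covers every dimension of `shape`.

        A single Ellipsis is replaced by as many full slices as needed, and
        any trailing unspecified dimensions are padded with full slices.
        None entries (new axes) are passed through unchanged.
        """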
ndim = len(shape)
num_specified_indices = sum(i not in (None, Ellipsis) for i in index)
if num_specified_indices > ndim:
raise IndexError("too many indices for tensor")
ellipsis_count = index.count(Ellipsis)
if ellipsis_count == 0:
num_slices_for_rest_dims = ndim - num_specified_indices
index = index + (slice(None),) * num_slices_for_rest_dims
elif ellipsis_count == 1:
ellipsis_index = index.index(Ellipsis)
num_slices_for_ellipsis = ndim - num_specified_indices
index = index[:ellipsis_index] + (slice(None),) * num_slices_for_ellipsis + index[ellipsis_index + 1 :]
else:
raise IndexError("only one ellipsis allowed")
        return cast(tuple[int | slice | Sequence[int] | None, ...], index)
def __repr__(self) -> str:
def format_tensor(level: int, shape: Sequence[int], data: Sequence[Numeric]) -> str:
if not shape:
return str(data[0])
dim = shape[0]
extra_size = reduce(mul, shape[1:], 1)
sub_tensors = [
format_tensor(level + 1, shape[1:], data[i * extra_size : (i + 1) * extra_size]) for i in range(dim)
]
if level == self.ndim - 1:
return f"[{', '.join(sub_tensors)}]"
newline = "\n"
indent = " " * level
return f"[{newline}{(',' + newline).join(indent + ' ' + sub for sub in sub_tensors)}{newline + indent}]"
return f"Tensor({format_tensor(0, self.shape, self.data)})"
@overload
def __add__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
...
@overload
def __add__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
...
@overload
def __add__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
...
@overload
def __add__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
...
@overload
def __add__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
...
def __add__(self, other: Numeric | Tensor) -> Tensor:
if isinstance(other, (int, float, complex)):
return Tensor(tuple(a + other for a in self.data), self.shape)
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a + b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented # type: ignore[unreachable]
@overload
def __radd__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
...
@overload
def __radd__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
...
@overload
def __radd__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
...
@overload
def __radd__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
...
@overload
def __radd__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
...
def __radd__(self, other: Numeric | Tensor) -> Tensor:
return self + other
@overload
def __sub__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
...
@overload
def __sub__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
...
@overload
def __sub__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
...
@overload
def __sub__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
...
@overload
def __sub__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
...
def __sub__(self, other: Numeric | Tensor) -> Tensor:
if isinstance(other, (int, float, complex, Tensor)):
return self + (-other)
return NotImplemented # type: ignore[unreachable]
@overload
def __rsub__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
...
@overload
def __rsub__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
...
@overload
def __rsub__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
...
@overload
def __rsub__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
...
@overload
def __rsub__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
...
def __rsub__(self, other: Numeric | Tensor) -> Tensor:
return -(self - other)
@overload
def __mul__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
...
@overload
def __mul__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
...
@overload
def __mul__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
...
@overload
def __mul__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
...
@overload
def __mul__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
...
def __mul__(self, other: Numeric | Tensor) -> Tensor:
if isinstance(other, (int, float, complex)):
return Tensor(tuple(a * other for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a * b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented # type: ignore[unreachable]
@overload
def __rmul__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
...
@overload
def __rmul__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
...
@overload
def __rmul__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
...
@overload
def __rmul__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
        ...
    @overload
    def __rmul__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
def __rmul__(self, other: Numeric | Tensor) -> Tensor:
return self * other
@overload
def __truediv__(
self: Tensor[bool] | Tensor[int] | Tensor[float],
other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float],
) -> Tensor[float]:
...
@overload
def __truediv__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
...
@overload
def __truediv__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
...
def __truediv__(self, other: Numeric | Tensor) -> Tensor:
if isinstance(other, (int, float, complex)):
return Tensor(tuple(a / other for a in self.data), self.shape)
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a / b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented # type: ignore[unreachable]
@overload
def __rtruediv__(
self: Tensor[bool] | Tensor[int] | Tensor[float],
other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float],
) -> Tensor[float]:
...
@overload
def __rtruediv__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
...
@overload
def __rtruediv__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
...
def __rtruediv__(self, other: Numeric | Tensor) -> Tensor:
if isinstance(other, (int, float, complex)):
return Tensor(tuple(other / a for a in self.data), self.shape)
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(other, self)
return Tensor(tuple(a / b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented # type: ignore[unreachable]
@overload
def __floordiv__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
...
@overload
def __floordiv__(
self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
) -> Tensor[float]:
...
def __floordiv__(
self: Tensor[bool] | Tensor[int] | Tensor[float], other: Numeric | Tensor
) -> Tensor[int] | Tensor[float]:
if isinstance(other, (int, float)):
return cast(Union[Tensor[int], Tensor[float]], Tensor(tuple(a // other for a in self.data), self.shape))
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return cast(
Union[Tensor[int], Tensor[float]],
Tensor(tuple(a // b for a, b in zip(left.data, right.data)), left.shape),
)
return NotImplemented
@overload
def __rfloordiv__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
...
@overload
def __rfloordiv__(
self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
) -> Tensor[float]:
...
def __rfloordiv__(
self: Tensor[bool] | Tensor[int] | Tensor[float], other: Numeric | Tensor
) -> Tensor[int] | Tensor[float]:
if isinstance(other, (int, float)):
return cast(Union[Tensor[int], Tensor[float]], Tensor(tuple(other // a for a in self.data), self.shape))
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(other, self)
return cast(
Union[Tensor[int], Tensor[float]],
                Tensor(tuple(a // b for a, b in zip(left.data, right.data)), left.shape),
)
return NotImplemented
@overload
def __mod__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
...
@overload
def __mod__(
self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
) -> Tensor[float]:
...
def __mod__(
self: Tensor[bool] | Tensor[int] | Tensor[float], other: Numeric | Tensor
) -> Tensor[int] | Tensor[float]:
if isinstance(other, (int, float)):
return cast(Union[Tensor[int], Tensor[float]], Tensor(tuple(a % other for a in self.data), self.shape))
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return cast(
Union[Tensor[int], Tensor[float]],
Tensor(tuple(a % b for a, b in zip(left.data, right.data)), left.shape),
)
return NotImplemented
@overload
def __rmod__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
...
@overload
def __rmod__(
self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
) -> Tensor[float]:
...
def __rmod__(
self: Tensor[bool] | Tensor[int] | Tensor[float], other: Numeric | Tensor
) -> Tensor[int] | Tensor[float]:
if isinstance(other, (int, float)):
return cast(Union[Tensor[int], Tensor[float]], Tensor(tuple(other % a for a in self.data), self.shape))
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(other, self)
return cast(
Union[Tensor[int], Tensor[float]],
                Tensor(tuple(a % b for a, b in zip(left.data, right.data)), left.shape),
)
return NotImplemented
@overload
def __pow__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
...
@overload
def __pow__(
self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
) -> Tensor[float]:
...
@overload
def __pow__(self: Tensor[complex], other: Numeric | Tensor[complex]) -> Tensor[complex]:
...
@overload
def __pow__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
...
def __pow__(self, other: Numeric | Tensor) -> Tensor:
if isinstance(other, (int, float, complex)):
return Tensor(tuple(a**other for a in self.data), self.shape)
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a**b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented # type: ignore[unreachable]
@overload
def __rpow__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
...
@overload
def __rpow__(
self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
) -> Tensor[float]:
...
@overload
def __rpow__(self: Tensor[complex], other: Numeric | Tensor[complex]) -> Tensor[complex]:
...
@overload
def __rpow__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
...
def __rpow__(self, other: Numeric | Tensor) -> Tensor:
if isinstance(other, (int, float, complex)):
return Tensor(tuple(other**a for a in self.data), self.shape)
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(other, self)
return Tensor(tuple(a**b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented # type: ignore[unreachable]
@overload
def __matmul__(self: Tensor[bool] | Tensor[int], other: Tensor[bool] | Tensor[int]) -> Tensor[int]:
...
@overload
def __matmul__(self: Tensor[float], other: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
...
@overload
def __matmul__(self: Tensor[complex], other: Tensor) -> Tensor[complex]:
...
@overload
def __matmul__(self, other: Tensor[complex]) -> Tensor[complex]:
...
def __matmul__(self, other: Tensor) -> Tensor:
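        """Matrix-multiply two tensors of at most 2 dimensions.

        1-D operands are temporarily promoted to 2-D (a row vector on the
        left, a column vector on the right) and multiplied with a naive
        triple loop; the promoted dimensions are kept in the result.
        """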
if not isinstance(other, Tensor):
return NotImplemented # type: ignore[unreachable]
        if self.ndim > 2 or other.ndim > 2:
            raise ValueError("matmul requires tensors with at most 2 dimensions")
left, right = self, other
if left.ndim == 1:
left = left.reshape((1, left.shape[0]))
if right.ndim == 1:
right = right.reshape((right.shape[0], 1))
if left.shape[1] != right.shape[0]:
raise ValueError(f"shape mismatch: {self.shape} @ {other.shape}")
output_shape = (left.shape[0], right.shape[1])
output_data: list[Numeric] = [0] * reduce(mul, output_shape, 1)
for i in range(output_shape[0]):
for j in range(output_shape[1]):
for k in range(left.shape[1]):
                    output_data[i * output_shape[1] + j] += (
                        left.data[i * left.shape[1] + k] * right.data[k * right.shape[1] + j]
                    )
return Tensor(output_data, output_shape)
def __and__(self: Tensor[bool], other: Tensor[bool]) -> Tensor[bool]:
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a and b for a, b in zip(left.data, right.data)), left.shape)
def __or__(self: Tensor[bool], other: Tensor[bool]) -> Tensor[bool]:
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a or b for a, b in zip(left.data, right.data)), left.shape)
def __xor__(self: Tensor[bool], other: Tensor[bool]) -> Tensor[bool]:
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a ^ b for a, b in zip(left.data, right.data)), left.shape)
def __eq__(self, other: Numeric | Tensor) -> Tensor[bool]: # type: ignore[override]
if isinstance(other, (int, float, complex)):
return Tensor(tuple(a == other for a in self.data), self.shape)
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a == b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented
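    # Note: Python's data model has no __req__/__rne__ hooks; reflected
    # comparisons dispatch to the other operand's __eq__/__ne__, so the two
    # methods below are provided for symmetry only.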
def __req__(self, other: Numeric | Tensor) -> Tensor[bool]: # type: ignore[override]
if isinstance(other, (int, float, complex)):
return Tensor(tuple(other == a for a in self.data), self.shape)
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(other, self)
return Tensor(tuple(b == a for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented
def __ne__(self, other: Numeric | Tensor) -> Tensor[bool]: # type: ignore[override]
if isinstance(other, (int, float, complex)):
return Tensor(tuple(a != other for a in self.data), self.shape)
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a != b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented
def __rne__(self, other: Numeric | Tensor) -> Tensor[bool]: # type: ignore[override]
if isinstance(other, (int, float, complex)):
return Tensor(tuple(other != a for a in self.data), self.shape)
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(other, self)
return Tensor(tuple(b != a for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented
@overload
def __lt__(
self: Tensor[bool] | Tensor[int] | Tensor[float], other: Tensor[bool] | Tensor[int] | Tensor[float]
) -> Tensor[bool]:
...
@overload
def __lt__(self: Tensor[complex], other: Tensor[complex]) -> Tensor[bool]:
...
def __lt__(self, other: Numeric | Tensor) -> Tensor[bool]:
if isinstance(other, (int, float, complex)):
return Tensor(tuple(a < other for a in self.data), self.shape) # type: ignore[operator]
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a < b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented
@overload
def __le__(
self: Tensor[bool] | Tensor[int] | Tensor[float], other: Tensor[bool] | Tensor[int] | Tensor[float]
) -> Tensor[bool]:
...
@overload
def __le__(self: Tensor[complex], other: Tensor[complex]) -> Tensor[bool]:
...
def __le__(self, other: Numeric | Tensor) -> Tensor[bool]:
if isinstance(other, (int, float, complex)):
return Tensor(tuple(a <= other for a in self.data), self.shape) # type: ignore[operator]
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a <= b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented
@overload
def __gt__(
self: Tensor[bool] | Tensor[int] | Tensor[float], other: Tensor[bool] | Tensor[int] | Tensor[float]
) -> Tensor[bool]:
...
@overload
def __gt__(self: Tensor[complex], other: Tensor[complex]) -> Tensor[bool]:
...
def __gt__(self, other: Numeric | Tensor) -> Tensor[bool]:
if isinstance(other, (int, float, complex)):
return Tensor(tuple(a > other for a in self.data), self.shape) # type: ignore[operator]
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a > b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented
@overload
def __ge__(
self: Tensor[bool] | Tensor[int] | Tensor[float], other: Tensor[bool] | Tensor[int] | Tensor[float]
) -> Tensor[bool]:
...
@overload
def __ge__(self: Tensor[complex], other: Tensor[complex]) -> Tensor[bool]:
...
def __ge__(self, other: Numeric | Tensor) -> Tensor[bool]:
if isinstance(other, (int, float, complex)):
return Tensor(tuple(a >= other for a in self.data), self.shape) # type: ignore[operator]
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a >= b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented
def __lshift__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
if isinstance(other, (int, bool)):
return Tensor(tuple(a << other for a in self.data), self.shape)
if isinstance(other, Tensor):
left, right = self.broadcast_tensors(self, other)
return Tensor(tuple(a << b for a, b in zip(left.data, right.data)), left.shape)
return NotImplemented
    def __rshift__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
        if isinstance(other, (int, bool)):
            return Tensor(tuple(a >> other for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a >> b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
@overload
def __pos__(self: Tensor[bool]) -> Tensor[int]:
...
@overload
def __pos__(self: Tensor[int]) -> Tensor[int]:
...
@overload
def __pos__(self: Tensor[float]) -> Tensor[float]:
...
@overload
def __pos__(self: Tensor[complex]) -> Tensor[complex]:
...
def __pos__(self) -> Tensor:
return Tensor(tuple(+a for a in self.data), self.shape)
@overload
def __neg__(self: Tensor[bool]) -> Tensor[int]:
...
@overload
def __neg__(self: Tensor[int]) -> Tensor[int]:
...
@overload
def __neg__(self: Tensor[float]) -> Tensor[float]:
...
@overload
def __neg__(self: Tensor[complex]) -> Tensor[complex]:
...
def __neg__(self) -> Tensor:
return Tensor(tuple(-a for a in self.data), self.shape)
def __invert__(self: Tensor[bool]) -> Tensor[bool]:
return Tensor(tuple(not a for a in self.data), self.shape)
def __getitem__(
self,
idx: int | slice | Sequence[int] | Tensor[bool] | tuple[int | slice | Sequence[int] | EllipsisType | None, ...],
) -> Tensor[DType]:
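        """Index the tensor with an int, slice, int sequence, boolean mask
        tensor, or a tuple mixing ints, slices, sequences, Ellipsis, and None
        (new axis), returning a new tensor built from the selected elements.
        """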
if not self.shape:
raise IndexError("scalar tensor cannot be indexed")
flattened_indices: Iterable[int]
output_shape: tuple[int, ...]
extra_size = reduce(mul, self.shape[1:], 1)
        if isinstance(idx, int):
            if idx < 0:
                idx += self.shape[0]
            if not 0 <= idx < self.shape[0]:
                raise IndexError("index out of range")
            flattened_indices = range(idx * extra_size, (idx + 1) * extra_size)
output_shape = self.shape[1:]
elif isinstance(idx, slice):
start, stop, step = idx.indices(self.shape[0])
flattened_indices = itertools.chain.from_iterable(
range(i * extra_size, (i + 1) * extra_size) for i in range(start, stop, step)
)
output_shape = (math.ceil((stop - start) / step),) + self.shape[1:]
elif isinstance(idx, tuple):
current_dim = 0
output_shape = ()
index_iterators: list[Iterable[int]] = []
            for index in self.normalize_index(idx, self.shape):
                if index is None:
                    output_shape = output_shape + (1,)
                    continue
                elif isinstance(index, int):
                    if index < 0:
                        index += self.shape[current_dim]
                    if not 0 <= index < self.shape[current_dim]:
                        raise IndexError("index out of range")
                    index_iterators.append((index,))
elif isinstance(index, slice):
start, stop, step = index.indices(self.shape[current_dim])
index_iterators.append(range(start, stop, step))
output_shape = output_shape + (math.ceil((stop - start) / step),)
elif isinstance(index, Sequence):
index_iterators.append(index)
output_shape = output_shape + (len(index),)
else:
raise IndexError("index must be int, slice, Ellipsis, or None, but got " + repr(index))
current_dim += 1
flattened_indices = map(
partial(self.get_flattened_index, shape=self.shape),
itertools.product(*index_iterators),
)
            # All-int indexing consumes every dimension, leaving output_shape
            # as () and hence a scalar result; no extra collapsing is needed.
elif isinstance(idx, Sequence):
            idx = [i + self.shape[0] if i < 0 else i for i in cast(Sequence[int], idx)]
            if any(not 0 <= i < self.shape[0] for i in idx):
                raise IndexError("index out of range")
flattened_indices = itertools.chain.from_iterable(range(i * extra_size, (i + 1) * extra_size) for i in idx)
output_shape = (len(idx),) + self.shape[1:]
elif isinstance(idx, Tensor):
if idx.shape != self.shape:
raise ValueError("shape mismatch")
flattened_indices = [i for i in range(self.size) if idx.data[i]]
output_shape = (len(flattened_indices),)
return Tensor([self.data[i] for i in flattened_indices], output_shape)
def __iter__(self) -> Iterator[Tensor[DType]]:
if not self.shape:
raise TypeError("'Tensor' object is not iterable")
for i in range(self.shape[0]):
yield self[i]
def copy(self) -> Tensor[DType]:
return Tensor(self.data, self.shape)
def reshape(self, shape: Sequence[int]) -> Tensor[DType]:
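        """Return a tensor with the same data viewed under a new shape.

        At most one dimension may be negative, in which case its size is
        inferred from the remaining dimensions.
        """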
neg_count = 0
neg_index = -1
for i, s in enumerate(shape):
if s < 0:
neg_count += 1
neg_index = i
if neg_count > 1:
raise ValueError("can only specify one unknown dimension")
elif neg_count == 1:
shape = list(shape)
shape[neg_index] = -1
shape[neg_index] = self.size // -reduce(mul, shape, 1)
if reduce(mul, shape, 1) != self.size:
raise ValueError("cannot reshape tensor of size {} into shape {}".format(self.shape, shape))
return Tensor(self.data, shape)
def transpose(self, dim0: int, dim1: int) -> Tensor[DType]:
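        """Return a copy with dimensions `dim0` and `dim1` swapped."""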
if not (0 <= dim0 < self.ndim and 0 <= dim1 < self.ndim):
raise ValueError("dimension out of range")
if dim0 == dim1:
return self.copy()
transposed_shape = list(self.shape)
transposed_shape[dim0], transposed_shape[dim1] = transposed_shape[dim1], transposed_shape[dim0]
def get_transposed_index(i: int) -> int:
multi_index = list(self.get_nested_index(i, self.shape))
multi_index[dim0], multi_index[dim1] = multi_index[dim1], multi_index[dim0]
transposed_index = self.get_flattened_index(multi_index, transposed_shape)
return transposed_index
transposed_data = list(self.data)
for original_index, transposed_index in enumerate(map(get_transposed_index, range(self.size))):
transposed_data[transposed_index] = self.data[original_index]
return Tensor(transposed_data, transposed_shape)
def is_equal(self, other: object) -> bool:
if not isinstance(other, Tensor):
            return False
return self.shape == other.shape and self.data == other.data
@overload
def sum(self: Tensor[bool] | Tensor[int], dim: int | None = ...) -> Tensor[int]:
...
@overload
def sum(self: Tensor[float], dim: int | None = ...) -> Tensor[float]:
...
@overload
def sum(self: Tensor[complex], dim: int | None = ...) -> Tensor[complex]:
...
def sum(self, dim: int | None = None) -> Tensor:
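        """Sum all elements, or reduce along `dim` if one is given."""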
if dim is None:
return Tensor(sum(self.data))
if not (0 <= dim < self.ndim):
raise ValueError("dimension out of range")
output: Tensor = Tensor(0)
for i in range(self.shape[dim]):
index: list[slice | int] = [slice(None)] * self.ndim
index[dim] = i
output += self[tuple(index)]
return output
def apply(self, fn: Callable[[DType], T_DType]) -> Tensor[T_DType]:
return Tensor(tuple(map(fn, self.data)), self.shape)
def exp(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
return self.apply(math.exp)
def expm1(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
return self.apply(math.expm1)
def log(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
return self.apply(math.log)
def log10(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
return self.apply(math.log10)
def log1p(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
return self.apply(math.log1p)
def log2(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
return self.apply(math.log2)
@overload
def astype(self, dtype: Type[bool]) -> Tensor[bool]: # type: ignore[misc]
...
@overload
def astype(self: Tensor[bool] | Tensor[int] | Tensor[float], dtype: Type[int]) -> Tensor[int]:
...
@overload
def astype(self: Tensor[bool] | Tensor[int] | Tensor[float], dtype: Type[float]) -> Tensor[float]:
...
@overload
def astype(self, dtype: Type[complex]) -> Tensor[complex]:
...
def astype(self, dtype: Type[Numeric]) -> Tensor[bool] | Tensor[int] | Tensor[float] | Tensor[complex]:
return Tensor(tuple(dtype(a) for a in self.data), self.shape) # type: ignore[arg-type, misc, call-overload]
def broadcast_to(self, shape: Sequence[int]) -> Tensor[DType]:
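        """Materialize this tensor at `shape` by repeating size-1 dimensions."""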
if self.shape == shape:
return self.copy()
if len(self.shape) > len(shape):
raise ValueError("cannot broadcast tensor of shape {} to shape {}".format(self.shape, shape))
if self.shape == ():
return Tensor([self.data[0]] * reduce(mul, shape, 1), shape)
original_shape = [1] * (len(shape) - len(self.shape)) + list(self.shape)
broadcasted_shape = shape
index_iterators: list[Iterable[int]] = []
for original_dim_size, broadcasted_dim_size in zip(original_shape, broadcasted_shape):
if original_dim_size == broadcasted_dim_size:
index_iterators.append(range(original_dim_size))
elif original_dim_size == 1:
if broadcasted_dim_size < 1:
raise ValueError("cannot broadcast tensor of shape {} to shape {}".format(self.shape, shape))
index_iterators.append([0] * broadcasted_dim_size)
else:
raise ValueError("cannot broadcast tensor of shape {} to shape {}".format(self.shape, shape))
broadcasted_indices = map(
partial(self.get_flattened_index, shape=self.shape),
itertools.product(*index_iterators),
)
return Tensor([self.data[i] for i in broadcasted_indices], broadcasted_shape)
@staticmethod
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]:
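        """Compute the common NumPy-style broadcast shape of the given shapes.

        Shapes are left-padded with 1s to the same rank; corresponding
        dimensions must then be equal or 1.
        """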
max_ndim = max(len(shape) for shape in shapes)
shapes = tuple((1,) * (max_ndim - len(shape)) + shape for shape in shapes)
output_shape = []
for dim_sizes in zip(*shapes):
if len(set(dim_sizes) - {1}) > 1:
raise ValueError("cannot broadcast shapes {} to a common shape".format(shapes))
output_shape.append(max(dim_sizes))
return tuple(output_shape)
@staticmethod
def broadcast_tensors(*tensors: Tensor) -> tuple[Tensor, ...]:
output_shape = Tensor.broadcast_shapes(*[tensor.shape for tensor in tensors])
return tuple(tensor.broadcast_to(output_shape) for tensor in tensors)
@classmethod
def zeros(cls, shape: Sequence[int]) -> Tensor[int]:
return Tensor([0] * reduce(mul, shape, 1), shape)
@classmethod
def ones(cls, shape: Sequence[int]) -> Tensor[int]:
return Tensor([1] * reduce(mul, shape, 1), shape)
@classmethod
def eye(cls, n: int) -> Tensor[int]:
return Tensor([1 if i == j else 0 for i in range(n) for j in range(n)], (n, n))
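

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original gist) exercising the
    # API defined above: construction, broadcasting, indexing, matmul,
    # reduction, and transposition.
    a = Tensor([1, 2, 3, 4, 5, 6], (2, 3))
    b = Tensor([10, 20, 30], (3,))

    print(a + b)                    # (2, 3) + (3,) broadcasts to (2, 3)
    print(a[:, 1])                  # second column -> Tensor([2, 5])
    print(a @ b.reshape((3, 1)))    # matrix product -> shape (2, 1)
    print(a.sum())                  # Tensor(21)
    print(a.transpose(0, 1).shape)  # (3, 2)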