@piojanu
Last active March 1, 2023 10:55
A wrapper around a PyTorch sparse tensor that allows quick access to the values along the first dimension.
import functools
from typing import Any, Optional, Sequence, Tuple, Union

import torch as th

HANDLED_FUNCTIONS = {}


# TODO: Properly type annotate.
def implements(torch_function: Any) -> Any:
    """Register a torch function override for QuickAccessSparseTensor."""

    # TODO: Properly type annotate.
    @functools.wraps(torch_function)
    def decorator(func: Any) -> Any:
        HANDLED_FUNCTIONS[torch_function] = func
        return func

    return decorator

@th.jit.script
def _index_select_torchscript(
    indices: th.Tensor, values: th.Tensor, index: th.Tensor, bounds: th.Tensor
) -> Tuple[th.Tensor, th.Tensor]:
    indices_list = []
    values_list = []
    for output_idx, input_idx in enumerate(index):
        start, end = bounds[input_idx], bounds[input_idx + 1]
        # Clone the slice: it is a view into `indices`, and rewriting its first row below
        # would otherwise mutate the source tensor in place.
        indices_ = indices[:, start:end].clone()
        indices_[0, :] = output_idx
        indices_list.append(indices_)
        values_list.append(values[start:end])
    return (
        th.cat(indices_list, dim=1) if indices_list else th.empty((2, 0), dtype=th.long).to(indices),
        th.cat(values_list, dim=0) if values_list else th.empty((0,), dtype=values.dtype).to(values),
    )
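
# Worked example (hypothetical values, taken from the test fixture at the bottom of this gist):
# given indices == [[0, 2, 2, 2], [0, 1, 2, 3]], row-pointer bounds == [0, 1, 1, 4]
# (see `_calculate_bounds` below), and index == [2, 0], the loop copies input row 2
# (positions 1:4) as output row 0 and input row 0 (positions 0:1) as output row 1,
# returning indices == [[0, 0, 0, 1], [1, 2, 3, 0]] alongside the four matching values.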

class QuickAccessSparseTensor:
    """Wrapper around a sparse tensor that allows quick access to the values along the first dimension.

    ASSUMPTIONS:
    1. The indices are coalesced. If you create it from `torch.sparse_coo_tensor` and then only use
       operations implemented here, it will remain coalesced.
    2. Indexing is supported only along the first dimension.
    """

    def __init__(self, indices: th.LongTensor, values: th.Tensor, shape: th.Size):
        self.indices = indices
        self.values = values
        self.shape = shape

        # Check if the indices are sorted by the first dimension and sort them if not.
        if not self.indices[0, 1:].ge(self.indices[0, :-1]).all():
            self._sort_indices_by_the_first_dimension()
        self._bounds = self._calculate_bounds()

    def _sort_indices_by_the_first_dimension(self) -> None:
        sorting_by_first_dim = self.indices[0, :].argsort()
        self.indices = self.indices[:, sorting_by_first_dim]
        self.values = self.values[sorting_by_first_dim]

    def _calculate_bounds(self) -> th.LongTensor:
        # `bounds[i]:bounds[i + 1]` is the span of positions in `indices`/`values` that belongs to row `i`.
        bounds = th.zeros(self.shape[0] + 1, dtype=th.long).to(self.indices)
        bounds[1:] = th.bincount(self.indices[0], minlength=self.shape[0])
        bounds.cumsum_(dim=0)
        return bounds
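
    # For intuition (hypothetical values, matching the test fixture at the bottom of this gist):
    # with `indices[0] == [0, 2, 2, 2]` and `shape[0] == 3`, `th.bincount` yields `[1, 0, 3]`,
    # so after the in-place cumulative sum `bounds == [0, 1, 1, 4]`: row 0 lives at positions
    # `0:1`, row 1 is empty (`1:1`), and row 2 occupies `1:4`, i.e. a CSR-style row-pointer layout.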

    def index_select(self, dim: int, index: th.LongTensor) -> "QuickAccessSparseTensor":
        assert dim == 0, "Only dim = 0 is supported"
        indices, values = _index_select_torchscript(self.indices, self.values, index, self._bounds)
        return QuickAccessSparseTensor(
            indices=indices,
            values=values,
            shape=th.Size([len(index), *self.shape[1:]]),
        )

    def size(self, dim: Optional[int] = None) -> Union[th.Size, int]:
        if dim is None:
            return self.shape
        else:
            return self.shape[dim]

    def to(self, *args: Any, **kwargs: Any) -> "QuickAccessSparseTensor":
        return QuickAccessSparseTensor(
            indices=self.indices.to(*args, **kwargs),
            values=self.values.to(*args, **kwargs),
            shape=self.shape,
        )

    def to_dense(self) -> th.Tensor:
        tensor = th.zeros(self.shape, dtype=self.values.dtype).to(self.values)
        tensor[tuple(self.indices)] = self.values
        return tensor

    def to_sparse(self) -> th.Tensor:
        return th.sparse_coo_tensor(self.indices, self.values, self.shape)

    @classmethod
    def from_sparse_coo_tensor(cls, sparse_tensor: th.Tensor) -> "QuickAccessSparseTensor":
        sparse_tensor = sparse_tensor.coalesce()
        return cls(
            sparse_tensor.indices(),
            sparse_tensor.values(),
            sparse_tensor.size(),
        )

    # TODO: Properly type annotate.
    @classmethod
    def __torch_function__(cls, func: Any, types: Any, args: Any = (), kwargs: Any = None) -> Any:
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(issubclass(t, QuickAccessSparseTensor) for t in types):
            # The `__torch_function__` protocol expects the `NotImplemented` sentinel,
            # not an exception instance.
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

@implements(th.cat)
def cat(
    tensors: Sequence[QuickAccessSparseTensor], dim: int = 0, *, out: Optional[th.Tensor] = None
) -> QuickAccessSparseTensor:
    assert out is None, "Output tensor is not supported."
    assert all(
        tensor.shape[:dim] == tensors[0].shape[:dim]
        and tensor.shape[dim + 1 :] == tensors[0].shape[dim + 1 :]  # noqa: E203
        for tensor in tensors
    ), "All tensors must have the same shape except for the dimension to concatenate."

    # Concatenate indices.
    indices_list = []
    for idx, tensor in enumerate(tensors):
        indices_ = tensor.indices.clone()
        # Shift the indices at the concatenating dimension into place in the new tensor.
        indices_[dim, :] += sum(t.shape[dim] for t in tensors[:idx])
        indices_list.append(indices_)
    indices = th.cat(indices_list, dim=1)
    # Concatenate values.
    values = th.cat([tensor.values for tensor in tensors], dim=0)
    # Calculate the new shape.
    shape = list(tensors[0].shape)
    shape[dim] = sum(t.shape[dim] for t in tensors)
    shape = th.Size(shape)
    return QuickAccessSparseTensor(indices, values, shape)

@implements(th.index_select)
def index_select(tensor: QuickAccessSparseTensor, dim: int, index: th.LongTensor) -> QuickAccessSparseTensor:
    return tensor.index_select(dim, index)
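
A minimal usage sketch (not part of the original gist), assuming the module above is saved as sparse_pytorch_tensor.py (the name the test file below imports it by). It shows the intended flow: wrap a coalesced COO tensor once, then let `th.index_select` and `th.cat` dispatch to the overrides via `__torch_function__`:

import torch as th

from sparse_pytorch_tensor import QuickAccessSparseTensor

# Build a small boolean sparse tensor and wrap it for fast first-dimension access.
sparse = th.sparse_coo_tensor(
    indices=[[0, 2, 2], [0, 1, 3]], values=th.ones(3, dtype=th.bool), size=(3, 4)
)
qa = QuickAccessSparseTensor.from_sparse_coo_tensor(sparse)

rows = th.index_select(qa, 0, th.tensor([2, 0]))  # routed through __torch_function__
both = th.cat([qa, qa], dim=0)                    # handled by the `cat` override
print(rows.to_dense())  # rows 2 and 0 of the dense form, in that order
print(both.size(0))     # 6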
import pytest
import torch as th

from sparse_pytorch_tensor import QuickAccessSparseTensor


@pytest.fixture
def sparse_tensor() -> th.Tensor:
    return th.sparse_coo_tensor(
        indices=[[0, 2, 2, 2], [0, 1, 2, 3]],
        values=th.ones(4, dtype=th.bool),
        size=(3, 4),
        dtype=th.bool,
    )
@pytest.mark.parametrize("index", [th.tensor([0]), th.tensor([1]), th.tensor([2]), th.tensor([0, 1, 2])])
def test_quick_access_sparse_tensor_index_select(sparse_tensor: th.sparse_coo_tensor, index: th.LongTensor) -> None:
# given
qa_sparse_tensor = QuickAccessSparseTensor.from_sparse_coo_tensor(sparse_tensor)
# then
assert th.equal(
th.index_select(qa_sparse_tensor, 0, index).to_dense(), th.index_select(sparse_tensor, 0, index).to_dense()
)
def test_quick_access_sparse_tensor_index_select_out_of_bounds(sparse_tensor: th.sparse_coo_tensor) -> None:
# given
qa_sparse_tensor = QuickAccessSparseTensor.from_sparse_coo_tensor(sparse_tensor)
# then
with pytest.raises(RuntimeError):
qa_sparse_tensor.index_select(0, th.tensor([3]))
@pytest.mark.parametrize("dim", [0, 1])
def test_quick_access_sparse_tensors_concatenation(sparse_tensor: th.sparse_coo_tensor, dim: int) -> None:
# given
qa_sparse_tensor = QuickAccessSparseTensor.from_sparse_coo_tensor(sparse_tensor)
# when
qa_sparse_tensors_concat = th.cat([qa_sparse_tensor, qa_sparse_tensor, qa_sparse_tensor], dim=dim)
sparse_tensors_concat = th.cat([sparse_tensor, sparse_tensor, sparse_tensor], dim=dim)
# then
assert qa_sparse_tensors_concat.shape == sparse_tensors_concat.size()
assert th.equal(qa_sparse_tensors_concat.to_dense(), sparse_tensors_concat.to_dense())