@jxmorris12
Created March 4, 2024 21:34
PyTorch sparse tensor slice
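import torch


def slice_sparse_tensor_rows(t: torch.Tensor, min_row: int, max_row: int) -> torch.Tensor:
    """Return rows [min_row, max_row) of a 2-D sparse COO tensor as a new sparse COO tensor."""
    # indices() and values() raise on an uncoalesced tensor, so coalesce first
    t = t.coalesce()
    row_idxs = t.indices()[0]
    # Keep only the nonzeros whose row lies in the half-open range [min_row, max_row)
    index_mask = (min_row <= row_idxs) & (row_idxs < max_row)
    num_rows = max_row - min_row
    num_cols = t.shape[1]
    idxs = t.indices()[:, index_mask]  # boolean-mask indexing returns a copy
    vals = t.values()[index_mask]
    # Shift row indices so the slice starts at row 0 of the new tensor
    idxs[0] -= min_row
    return torch.sparse_coo_tensor(idxs, vals, size=(num_rows, num_cols)).coalesce()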
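One way to sanity-check a row slice, assuming a PyTorch build whose `torch.index_select` accepts sparse COO inputs: select the same rows with the built-in op and compare the result against slicing the dense matrix. The example matrix and the `min_row`/`max_row` bounds are illustrative only.

```python
import torch

# Small dense matrix converted to sparse COO (to_sparse() defaults to COO layout)
dense = torch.tensor([[0., 1., 0.],
                      [2., 0., 0.],
                      [0., 0., 3.],
                      [4., 0., 5.]])
sparse = dense.to_sparse()

# Illustrative half-open row range [min_row, max_row)
min_row, max_row = 1, 3
rows = torch.arange(min_row, max_row)

# Select rows 1 and 2 along dim 0 of the sparse tensor
sliced = torch.index_select(sparse, 0, rows)

# The sparse result should match ordinary dense slicing
assert torch.equal(sliced.to_dense(), dense[min_row:max_row])
print(sliced.to_dense())
```

`index_select` takes an explicit list of row indices, so it also handles non-contiguous rows, whereas `slice_sparse_tensor_rows` only covers a contiguous `[min_row, max_row)` range.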