Skip to content

Instantly share code, notes, and snippets.

@amjames
Created November 29, 2022 15:39
Show Gist options
  • Save amjames/446a8d515acf4b32392781a64e10eea2 to your computer and use it in GitHub Desktop.
Save amjames/446a8d515acf4b32392781a64e10eea2 to your computer and use it in GitHub Desktop.
"""
Before checks were added that raise an error when to_dense is called on hybrid tensors
(sparse compressed tensors that have dense dims), all of these examples triggered a segfault.
The commented-out lines now trigger that error; previously they segfaulted.
"""
import torch
# Build a small dense matrix and zero out every entry that is not strictly
# positive, so the sparse conversions below have a non-trivial sparsity pattern.
dense = torch.randn(3, 4, dtype=torch.float32)
dense = dense * dense.relu().bool()

# --- Hybrid CSR -------------------------------------------------------------
# Reuse the index structure of the plain CSR tensor, but give every specified
# element a trailing (3, 3) block of values so the result has two dense dims.
csr = dense.to_sparse_csr()
hybrid_values = torch.randn(csr.values().shape + (3, 3), dtype=torch.float32)
hybrid_csr = torch.sparse_compressed_tensor(
    csr.crow_indices(),
    csr.col_indices(),
    hybrid_values,
    csr.shape + (3, 3),
    layout=torch.sparse_csr,
    dtype=torch.float32,
)
print("hybrid_csr:")
print(hybrid_csr)
# Left disabled on purpose (see the module docstring): this is the call under
# investigation.
# hybrid_csr.to_dense()

# --- Hybrid CSC -------------------------------------------------------------
# The CSR index tensors are passed as the compressed/plain indices of a CSC
# tensor; the two sparse dims of the shape are reversed to match.
hybrid_csc = torch.sparse_compressed_tensor(
    csr.crow_indices(),
    csr.col_indices(),
    hybrid_values,
    csr.shape[::-1] + (3, 3),
    layout=torch.sparse_csc,
    dtype=torch.float32,
)
print("hybrid_csc:")
print(hybrid_csc)
# hybrid_csc.to_dense()

# --- Hybrid BSR -------------------------------------------------------------
# Same construction with a blocked layout: 1x1 blocks, each block entry again
# extended by a trailing (3, 3) of dense dims.
bsr = dense.to_sparse_bsr((1, 1))
hybrid_blocked_values = torch.randn(
    bsr.values().shape + (3, 3), dtype=torch.float32
)
hybrid_bsr = torch.sparse_compressed_tensor(
    bsr.crow_indices(),
    bsr.col_indices(),
    hybrid_blocked_values,
    bsr.shape + (3, 3),
    layout=torch.sparse_bsr,
    dtype=torch.float32,
)
print("hybrid_bsr:")
print(hybrid_bsr)
# hybrid_bsr.to_dense()

# --- Hybrid BSC -------------------------------------------------------------
# BSR index tensors reused for a BSC tensor, with the sparse dims of the shape
# reversed (mirrors the CSR -> CSC construction above).
hybrid_bsc = torch.sparse_compressed_tensor(
    bsr.crow_indices(),
    bsr.col_indices(),
    hybrid_blocked_values,
    bsr.shape[::-1] + (3, 3),
    layout=torch.sparse_bsc,
    dtype=torch.float32,
)
print("hybrid_bsc:")
print(hybrid_bsc)
# hybrid_bsc.to_dense()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment