@persiyanov · Created February 8, 2019 12:41
PyTorch masked matmul with a sparse mask: compute only the entries of a @ b selected by a 2 x L COO index mask, returned as a sparse tensor.
import torch
import torch.autograd


class MaskedSpMatmul(torch.autograd.Function):
    CHUNK_SIZE = 10000

    @staticmethod
    def forward(ctx, a, b, mask):
        """
        a: tensor N x M
        b: tensor M x K
        mask: int64 tensor 2 x L of (row, column) indices
        output: sparse N x K tensor where only the L values specified by mask are nonzero
        """
        N, M, K, L = a.shape[0], a.shape[1], b.shape[1], mask.shape[1]
        ctx.save_for_backward(a, b, mask)
        # Allocate on the inputs' device so the function also works on GPU.
        values = torch.zeros(L, dtype=a.dtype, device=a.device)
        # Walk the mask in chunks to bound peak memory: each step gathers the
        # needed rows of `a` and columns of `b` and takes row-wise dot products.
        for idx in range(0, L, MaskedSpMatmul.CHUNK_SIZE):
            batch_indices = mask[:, idx:idx + MaskedSpMatmul.CHUNK_SIZE]
            a_batch = torch.index_select(a, 0, batch_indices[0, :])
            b_batch = torch.index_select(b, 1, batch_indices[1, :]).t()
            # dot_prods[i] = a_batch[i] @ b_batch[i]
            dot_prods = torch.einsum('ij,ij->i', a_batch, b_batch)
            values[idx:idx + MaskedSpMatmul.CHUNK_SIZE] = dot_prods
        return torch.sparse_coo_tensor(mask, values, size=(N, K), dtype=values.dtype)
    @staticmethod
    def backward(ctx, grad_output):
        """
        grad_output: tensor N x K
        mask_dense: 0/1 tensor N x K built from the saved 2 x L mask
        grad_a = (grad_output * mask_dense).mm(b.t()) : tensor N x M
        grad_b = a.t().mm(grad_output * mask_dense) : tensor M x K
        """
        a, b, mask = ctx.saved_tensors
        N, M, K = a.shape[0], a.shape[1], b.shape[1]
        # Materialize the mask as a dense 0/1 matrix on the inputs' device.
        mask_dense = torch.sparse_coo_tensor(
            mask, torch.ones(mask.shape[1], dtype=a.dtype, device=a.device),
            size=(N, K),
        ).to_dense()
        # Only the masked entries of the output depend on a and b, so zero the
        # upstream gradient outside the mask before the dense matmul gradients.
        grad_a = (grad_output * mask_dense).mm(b.t())
        grad_b = a.t().mm(grad_output * mask_dense)
        # `mask` is an integer index tensor and gets no gradient.
        return grad_a, grad_b, None
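
# Aside (a hedged sketch, not part of the original gist): the chunked loop in
# forward() is equivalent to the unchunked one-liner below; chunking only
# bounds peak memory when L is large. The helper name is illustrative.
def masked_matmul_reference(a, b, mask):
    # Gather the mask[0]-th rows of `a` and the mask[1]-th columns of `b`,
    # then take row-wise dot products:
    # values[l] = a[mask[0, l], :] @ b[:, mask[1, l]]
    values = (a[mask[0]] * b[:, mask[1]].t()).sum(dim=1)
    return torch.sparse_coo_tensor(mask, values, size=(a.shape[0], b.shape[1]))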

class MaskedSpMatmulForTest(MaskedSpMatmul):
    @staticmethod
    def forward(ctx, a, b, mask):
        # torch.autograd.gradcheck can't work with sparse tensors
        return MaskedSpMatmul.forward(ctx, a, b, mask).to_dense()

def test_backward_correctness():
    func = MaskedSpMatmulForTest.apply
    a = torch.randn((50, 30), dtype=torch.float64, requires_grad=True)
    b = torch.randn((30, 60), dtype=torch.float64, requires_grad=True)
    # Indices < 30 are valid both as row indices (< 50) and column indices (< 60).
    mask = torch.randint(30, (2, 15))
    # gradcheck compares the analytic backward against numerical gradients.
    assert torch.autograd.gradcheck(func, (a, b, mask), atol=1e-4)

def test_forward_correctness():
    N, M, K = 50, 30, 60
    nnz = 50
    a = torch.randn((N, M), dtype=torch.float32)
    b = torch.randn((M, K), dtype=torch.float32)
    # Indices < min(N, K) are valid both as row and column indices.
    mask = torch.randint(min(N, K), (2, nnz))
    mask_dense = torch.sparse_coo_tensor(
        mask, torch.ones(mask.shape[1]), size=(N, K), dtype=torch.float32
    ).to_dense()
    expected = torch.mm(a, b) * mask_dense
    got = MaskedSpMatmulForTest.apply(a, b, mask)
    assert torch.all(torch.lt(torch.abs(got - expected), 1e-5))
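
# Example usage (a hedged sketch, not part of the original gist; the shapes,
# sizes, and names below are illustrative assumptions):
def example_usage():
    a = torch.randn(1000, 128)
    b = torch.randn(128, 2000)
    # Request 5000 of the 2,000,000 entries of a @ b.
    rows = torch.randint(1000, (5000,))
    cols = torch.randint(2000, (5000,))
    result = MaskedSpMatmul.apply(a, b, torch.stack([rows, cols]))
    # `result` is a 1000 x 2000 sparse COO tensor with 5000 stored values
    # (duplicate (row, col) pairs, if any, are summed when coalesced).
    return result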

if __name__ == '__main__':
    test_backward_correctness()
    test_forward_correctness()