@amjames
Created November 29, 2022 23:18
import torch
import torch._dynamo as torchdynamo

torchdynamo.config.verbose = True
torchdynamo.config.suppress_errors = True


class TestCasePlaceholder:
    # Minimal stand-in for unittest.TestCase so the test body can run standalone.
    def assertTrue(self, v):
        assert v, "Expected True"

    def assertFalse(self, v):
        assert not v, "Expected False"


DEVICE = 'cpu'
DTYPE = torch.float64
SELF = TestCasePlaceholder()


@torchdynamo.optimize()
def test_coalesce_reference_cycle(self, device, dtype):
    # Test coalesce doesn't create autograd graph cycles (gh-52253)

    # Sanity check that the helper class works as expected
    t = torch.rand(2)
    t_ref = torch._C._WeakTensorRef(t)
    self.assertFalse(t_ref.expired())
    del t
    self.assertTrue(t_ref.expired())

    def test_sparse_sum():
        i = torch.tensor([[0], [4]], dtype=torch.long, device=device)
        v = torch.tensor([[[-0.4567, -1.8797, 0.0380, 1.4316]]],
                         dtype=dtype, device=device)
        S = torch.sparse_coo_tensor(i, v)
        S = S.coalesce()
        S.requires_grad_(True)
        S2 = S.coalesce()
        self.assertTrue(S2.is_coalesced())
        # Return only a weak reference so the sparse tensor itself can be freed
        # once this inner function returns.
        return torch._C._WeakTensorRef(S2)

    ref = test_sparse_sum()
    # If coalesce created an autograd reference cycle, S2 would still be alive
    # here and this assertion would fail.
    self.assertTrue(ref.expired())


test_coalesce_reference_cycle(SELF, DEVICE, DTYPE)
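
For comparison, the same check can also be run without the torchdynamo decorator to see whether the weak reference expires under plain eager execution as well. The sketch below reuses the SELF, DEVICE, and DTYPE placeholders defined above; the eager_coalesce_reference_cycle name is just an illustrative helper, not part of the original gist.

def eager_coalesce_reference_cycle(self, device, dtype):
    # Identical body to the compiled version above, but executed eagerly.
    def test_sparse_sum():
        i = torch.tensor([[0], [4]], dtype=torch.long, device=device)
        v = torch.tensor([[[-0.4567, -1.8797, 0.0380, 1.4316]]],
                         dtype=dtype, device=device)
        S = torch.sparse_coo_tensor(i, v)
        S = S.coalesce()
        S.requires_grad_(True)
        S2 = S.coalesce()
        self.assertTrue(S2.is_coalesced())
        return torch._C._WeakTensorRef(S2)

    ref = test_sparse_sum()
    # In eager mode the sparse tensor should be collected once test_sparse_sum
    # returns, so the weak reference is expected to have expired.
    self.assertTrue(ref.expired())


eager_coalesce_reference_cycle(SELF, DEVICE, DTYPE)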