Skip to content

Instantly share code, notes, and snippets.

@soulitzer
Created November 3, 2023 00:51
Show Gist options
  • Save soulitzer/b543a23b99673a9bb8613031f8ef5065 to your computer and use it in GitHub Desktop.
Edge case if we try to patch inference-ness in ADInplaceOrView
import torch
class T(torch.Tensor):
    """Minimal wrapper tensor subclass reproducing an inference-mode edge case.

    The wrapper keeps a reference to an inner tensor in ``self.elem`` and
    implements exactly one op, ``aten.detach.default``.  That op deliberately
    returns the detached *inner* tensor without re-wrapping it, so the
    result's inference-ness follows ``elem`` rather than the wrapper.  Any
    other op reaching ``__torch_dispatch__`` raises ``NotImplementedError``.
    """

    def __new__(cls, elem):
        # Mirror the inner tensor's metadata on the wrapper.  Forwarding
        # ``device`` keeps the wrapper from silently defaulting to CPU when
        # ``elem`` lives on a different device.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        # ``elem`` is the tensor the dispatch handler actually operates on.
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        """Handle ``aten.detach.default`` and reject every other op.

        Args:
            func: the ATen overload being dispatched.
            types: subclass types participating in the call (unused).
            args: positional arguments; for detach, a single ``T`` instance.
            kwargs: keyword arguments (unused).

        Returns:
            For detach, ``aten.detach.default(self.elem)`` -- a tensor derived
            from the inner tensor and intentionally *not* re-wrapped in ``T``.

        Raises:
            NotImplementedError: for any ``func`` other than
                ``aten.detach.default``.
        """
        if func is torch.ops.aten.detach.default:
            (self,) = args
            # Return without wrapping!  This is the edge case under test: the
            # output comes from ``self.elem``, not from the wrapper.
            return torch.ops.aten.detach.default(self.elem)
        raise NotImplementedError(f"{func}")
if __name__ == "__main__":
    # Inner tensor is created outside inference mode, so it is non-inference.
    inner = torch.randn(3)
    with torch.inference_mode():
        # Subclass wrapper is constructed inside inference mode, so the
        # wrapper itself is an inference tensor.
        x = T(inner)
        # T.__torch_dispatch__ returns the detached inner tensor unwrapped,
        # so the result does not pick up the wrapper's inference-ness.
        y = x.detach()
    print(y.is_inference())  # False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment