# Created November 3, 2023 00:51
# Save soulitzer/b543a23b99673a9bb8613031f8ef5065 to your computer and use it in GitHub Desktop.
# Edge case that arises if we try to patch inference-ness in ADInplaceOrView
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
import torch
class T(torch.Tensor):
    """Wrapper tensor subclass whose ``detach`` returns the *unwrapped* inner tensor.

    The wrapper only holds ``elem`` and supports a single aten op,
    ``aten.detach.default``. That op is answered with
    ``aten.detach.default(elem)``, a plain tensor rather than a ``T``. Every
    other op raises ``NotImplementedError``.

    In the repro below, the wrapper is created under
    ``torch.inference_mode()`` (an inference tensor), while ``elem`` is not.
    So the tensor returned by ``x.detach()`` is not an inference tensor.
    """

    # The inherited Tensor.__torch_function__ re-wraps tensor results into
    # ``cls`` via ``as_subclass``. That would turn the deliberately unwrapped
    # detach result back into a ``T`` with no ``.elem``. Disable it so the
    # value returned by __torch_dispatch__ reaches the caller as-is.
    __torch_function__ = torch._C._disabled_torch_function_impl

    def __new__(cls, elem):
        # The wrapper mirrors elem's shape, dtype and device; the actual data
        # is only reachable through ``self.elem``.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        if func is torch.ops.aten.detach.default:
            (wrapper,) = args
            # Return without wrapping! The caller receives a plain tensor
            # derived from ``elem``, not a ``T``.
            return torch.ops.aten.detach.default(wrapper.elem)
        raise NotImplementedError(f"{func}")
# Inner tensor is non-inference (created outside inference mode).
inner = torch.randn(3)
with torch.inference_mode():
    # Subclass wrapper is inference tensor.
    x = T(inner)
# NOTE(review): the original indentation was lost when this gist was scraped;
# whether the two lines below ran inside or outside the ``with`` block is a
# reconstruction — confirm against the original gist.
# T's detach returns a tensor derived from the non-inference ``inner`` without
# rewrapping it, so the result does not inherit the wrapper's inference-ness.
y = x.detach()
print(y.is_inference())  # False
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.