# Technically, even in the "easy case" of t._base.requires_grad == t.requires_grad,
# I need to perform two views to recreate that view authentically. Why?
# There are actually two things I need to recreate: (1) the autograd
# graph relationship and (2) the view relationship.
# The reason we don't handle this today is that this autograd connectivity
# information is not accessible during tracing, and hence not relevant to
# compile, in part because dynamo doesn't support grad_fn access.
import torch
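
# A minimal sketch of the "two views" recipe described above (the names
# base0/mid/recreated are just for this illustration): view once under
# no_grad to get a leaf alias, then view again with grad enabled, so the
# second view records the autograd edge (1) while _base still captures
# the view relationship (2).
base0 = torch.ones(3, requires_grad=True)
with torch.no_grad():
    mid = base0[:]                  # leaf view: aliases base0, no grad_fn
recreated = mid[:]                  # grad_fn edge points at mid ...
assert recreated.grad_fn.next_functions[0][0].variable is mid
assert recreated._base is base0     # ... while _base is the original base
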
for requires_grad_1, requires_grad_2 in [(True, False), (False, True), (True, True), (False, False)]:
    a = torch.tensor([1.], requires_grad=requires_grad_1)  # leaf
    b = a[:]       # view
    c = a.clone()  # non-view (a non-leaf only when requires_grad_1)
    for x in (a, b, c):
        with torch.no_grad():
            x_view = x[:]
        assert x_view._is_view()
        assert x_view.is_leaf
        assert x_view.requires_grad == requires_grad_1
        rg_before = x_view.requires_grad
        # !! Doing requires_grad_ here has very strange behavior
        x_view.requires_grad_(requires_grad_2)
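        # Observed behavior (per the asserts below): requires_grad_(True)
        # sticks, but requires_grad_(False) is silently ignored when the
        # view already requires grad; the flag can be turned on, not off.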
        if rg_before:
            assert x_view.requires_grad == rg_before
        else:
            assert x_view.requires_grad == requires_grad_2
        x_view_view = x_view[:]
        if requires_grad_2:
            # (1) autograd info: points to the intermediate view
            assert x_view_view.grad_fn.next_functions[0][0].variable is x_view
        elif requires_grad_1:
            assert x_view_view.grad_fn is not None
            assert x_view_view.grad_fn.next_functions[0][0] is None
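            # (Note: the backward edge above is empty rather than an
            # AccumulateGrad node for x_view, even though x_view still
            # requires grad in this case.)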
        else:
            assert x_view_view.grad_fn is None
        # (2) view info: points to the original base
        if x._is_view():
            assert x_view_view._base is a
        else:
            assert x_view_view._base is x
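
# Follow-up sketch (names base1/mid1/out1 are just for this illustration):
# because the intermediate view is created under no_grad, gradients flow
# back to that intermediate leaf rather than to the original base tensor.
base1 = torch.ones(2, requires_grad=True)
with torch.no_grad():
    mid1 = base1[:]      # leaf view: aliases base1, but no edge to its graph
out1 = mid1[:] * 2       # autograd records an edge to mid1 only
out1.sum().backward()
assert mid1.grad is not None
assert base1.grad is None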