
@apaszke
Created April 5, 2017 21:19
From 62adfff8c43c3dcd8a175fc7072f3d76d40b0a10 Mon Sep 17 00:00:00 2001
From: Adam Paszke <adam.paszke@gmail.com>
Date: Wed, 5 Apr 2017 14:17:49 -0700
Subject: [PATCH] tmp
---
torch/autograd/function.py | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 8ad6805..f34af71 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -187,9 +187,15 @@ def once_differentiable(fn):
         tensor_args = [arg.data if isinstance(arg, Variable) else arg
                        for arg in args]
         outputs = fn(ctx, *tensor_args)
+        volatile = any(arg.volatile if isinstance(arg, Variable) else False
+                       for arg in args)
+        if volatile:
+            kwargs = {'volatile': True}
+        else:
+            kwargs = {'_grad_fn': Error()}
         if not isinstance(outputs, tuple):
-            return Variable(outputs, _grad_fn=Error()) if outputs is not None else None
-        return tuple([Variable(o, _grad_fn=Error()) if o is not None else None
+            return Variable(outputs, **kwargs) if outputs is not None else None
+        return tuple([Variable(o, **kwargs) if o is not None else None
                       for o in outputs])
     return wrapper
--
2.9.3
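
What the patch changes, as read from the diff: `once_differentiable` wraps a backward function and, before this patch, always attached an error-raising `_grad_fn` (`Error()`) to every output Variable, so differentiating through it a second time would fail. After the patch, if any input Variable is volatile (the 0.1.x-era inference flag), the outputs are instead created volatile, and the error grad_fn is only attached in the non-volatile case. The following is a minimal, stand-alone sketch of that patched logic; the `Variable` and `Error` classes here are hypothetical stubs standing in for the 2017-era torch.autograd internals, not the real torch API.

import functools

class Error:
    """Stand-in for the internal grad_fn that raises on a second backward."""

class Variable:
    """Minimal stub of the 0.1.x-era torch.autograd.Variable."""
    def __init__(self, data, volatile=False, _grad_fn=None):
        self.data = data
        self.volatile = volatile
        self.grad_fn = _grad_fn

def once_differentiable(fn):
    @functools.wraps(fn)
    def wrapper(ctx, *args):
        # Unwrap Variables to raw data before calling the user's backward.
        tensor_args = [arg.data if isinstance(arg, Variable) else arg
                       for arg in args]
        outputs = fn(ctx, *tensor_args)
        # The patched logic: volatile inputs produce volatile outputs;
        # otherwise outputs get the error-raising grad_fn as before.
        volatile = any(arg.volatile if isinstance(arg, Variable) else False
                       for arg in args)
        kwargs = {'volatile': True} if volatile else {'_grad_fn': Error()}
        if not isinstance(outputs, tuple):
            return Variable(outputs, **kwargs) if outputs is not None else None
        return tuple(Variable(o, **kwargs) if o is not None else None
                     for o in outputs)
    return wrapper

@once_differentiable
def backward(ctx, grad_output):
    return grad_output * 2

print(backward(None, Variable(3.0, volatile=True)).volatile)  # True
print(backward(None, Variable(3.0)).grad_fn)                  # an Error instance

In this sketch the volatile flag propagates exactly as in the patched wrapper: volatile inputs skip graph bookkeeping entirely, while non-volatile inputs still get outputs whose grad_fn would raise on a second differentiation.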