- saving graph traces (for future reexecution)
- full control over buffers
- breaking up the graph when it's not needed anymore
- POSSIBLY: lazy execution
- POSSIBLY: sub-graphs compilation (CUDA kernel fusion + no GIL execution)
```python
class Add(Function):

    @staticmethod
    def forward(ctx, input1, scalar, input2):  # supports arbitrary arguments
        ctx.scalar = scalar
        return input1 + scalar * input2

    @staticmethod
    def backward(ctx, grad_output_var):
        return grad_output_var, ctx.scalar * grad_output_var
```
gets transformed into
```python
class Function(CFunction):
    @classmethod
    def apply(cls, *args, **kwargs):  # this is used to evaluate the function
        ctx = new_ctx()
        outputs = cls.forward(ctx, *args, **kwargs)
        # AddBackward shown for concreteness; each Function gets its own backward class
        return cls.wrap_output(ctx, outputs, AddBackward)

class Add(Function):

    @staticmethod
    def forward(ctx, input1, scalar, input2):
        ctx.scalar = scalar
        return input1 + scalar * input2

    @staticmethod
    def backward(ctx, grad_output_var):
        return grad_output_var, ctx.scalar * grad_output_var

class AddBackward(CFunction):
    def __init__(self, ctx):
        self.ctx = ctx

    def apply(self, *args, **kwargs):
        return Add.backward(self.ctx, *args, **kwargs)

add = Add.apply  # can be used like a regular Python function
```
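For example, a rough usage sketch (x1 and x2 are assumed to be Variables of matching size; nothing below is part of the proposal itself):

```python
x1 = Variable(torch.randn(5), requires_grad=True)
x2 = Variable(torch.randn(5), requires_grad=True)

y = add(x1, 2.0, x2)       # Add.apply: creates a ctx, runs Add.forward, wraps the output
y.backward(torch.ones(5))  # AddBackward.apply feeds the saved ctx back into Add.backward
```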
NOTE: forward() won't recursively search the arguments - e.g. Variables inside a dict won't be found.
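To make that concrete, a hypothetical sketch (DictAdd and its dict argument are made up purely for illustration):

```python
class DictAdd(Function):
    @staticmethod
    def forward(ctx, inputs):  # inputs is a plain dict of Variables
        return inputs['a'] + inputs['b']
    # backward omitted

y = DictAdd.apply({'a': x1, 'b': x2})
# apply() only inspects its top-level arguments, so the Variables hidden inside
# the dict never get hooked into the graph and x1/x2 won't receive gradients.
```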
This is nice because it's very similar to the API used at the moment.
Functions whose backward operates on raw tensors (e.g. ones wrapping scipy) can be marked as differentiable only once:

```python
class Add(Function):
    @staticmethod
    def forward(ctx, input1, scalar, input2):
        return scipy.fn(input1, scalar, input2)

    @staticmethod
    @differentiable_once
    def backward(ctx, grad_output_tensor):
        return scipy.fn(grad_output_tensor)
```
possible implementation:

```python
class NonDifferentiable(CFunction):
    def apply(self, *args, **kwargs):
        raise RuntimeError("this function is only differentiable once; "
                           "remove @differentiable_once and implement backward "
                           "on Variables to differentiate it again")

def differentiable_once(fn):
    def wrapper(ctx, *args, **kwargs):
        targs, tkwargs = ...  # iterate over all arguments and unpack them into tensors
        outputs = fn(ctx, *targs, **tkwargs)
        return Function.wrap_output(None, outputs, NonDifferentiable)
    return wrapper
```
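For intuition, a sketch of the intended effect (assuming the scipy-backed Add above and suitable input Variables; the exact behaviour is an assumption about the proposal):

```python
y = Add.apply(x1, 2.0, x2)
y.backward(torch.ones(y.size()))  # first backward works: Add.backward runs on plain tensors
# The Variables produced by the wrapped backward carry NonDifferentiable as their
# grad_fn, so a second differentiation through them raises the RuntimeError above
# instead of silently computing wrong higher-order gradients.
```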
- `grad` - `None`, unless it holds a (sparse or dense) grad. In most cases freed after every forward in `model.zero_grad()` or `optimizer.zero_grad()`.
- `grad_fn` - graph trace of the derivative. `None` if `volatile=True`, or the Variable is a leaf.
- `creator` - `None`, unless the global switch is on. Contains a graph trace of the function that created the Variable. The computation could be replayed, analyzed, displayed or compiled.
```python
class Variable(...):
    def __init__(self, data, requires_grad=False, volatile=False):
        pass
```
```python
with torch.autograd.save_traces():
    ...  # all Variables in that context will have their .creators saved
```
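Putting these pieces together, a rough sketch of the intended semantics (the exact values are assumptions about the proposal, not a spec):

```python
x = Variable(torch.randn(3), requires_grad=True)
y = x * 2

x.grad_fn   # None -- x is a leaf
y.grad_fn   # trace of the derivative, used when y.backward() is called
y.creator   # None -- traces are not being saved by default

with torch.autograd.save_traces():
    z = x * 2
z.creator   # trace of the multiplication; could be replayed, analyzed or compiled

y.backward(torch.ones(3))
x.grad      # holds the accumulated gradient until model/optimizer .zero_grad() frees it
```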
- saving graph traces - `Variable.creator`
- full control over buffers - everything is contained in `ctx`.
- breaking up the graph when it's not needed anymore - no problems here
- POSSIBLY: lazy execution - see point 1.
- POSSIBLY: sub-graphs compilation (CUDA kernel fusion + no GIL execution) - see point 1.
- graph compression - if we have D -> S -> D computation (D=deterministic, S=stochastic subgraph), we can compress the second D part into a single node (in the backward graph) if it doesn't require grad
Context syntax:
```python
# forward
ctx += [t1, t2]
ctx.non_differentiable += (t3,)

# backward
t1, t2 = ctx
```
or a more verbose one
```python
# forward
ctx.t1 = t1
ctx.t2 = t2
ctx.mark_non_differentiable(t3)

# backward
return grad_output * ctx.t1 + ctx.t2
```
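For concreteness, a sketch of a whole Function written with the verbose syntax (MulAndSign and its outputs are made up for illustration):

```python
class MulAndSign(Function):
    @staticmethod
    def forward(ctx, input, weight):
        sign = input.sign()
        ctx.input = input
        ctx.weight = weight
        ctx.mark_non_differentiable(sign)  # the sign output gets no gradient
        return input * weight, sign

    @staticmethod
    def backward(ctx, grad_output, grad_sign):
        # grad_sign is presumably None, since sign was marked non-differentiable
        return grad_output * ctx.weight, grad_output * ctx.input
```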
Just an idea, but is it possible to just operate on pure Python/C functions instead of member functions of classes, since each Function has a static forward and backward anyway? There might be some speed savings, since we'd just be passing function pointers around.
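One possible shape of that idea, purely as a sketch (make_function is a hypothetical factory, not something the proposal defines):

```python
def add_forward(ctx, input1, scalar, input2):
    ctx.scalar = scalar
    return input1 + scalar * input2

def add_backward(ctx, grad_output):
    return grad_output, ctx.scalar * grad_output

# hypothetical factory that pairs the two callables and returns something
# equivalent to Add.apply, without ever defining a class
add = make_function(add_forward, add_backward)
y = add(x1, 2.0, x2)
```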