@apaszke / autograd_refactor.md
Last active Jul 16, 2017

Wanted functionality:

  • saving graph traces (for future re-execution)
  • full control over buffers
  • breaking up the graph when it's not needed anymore
  • POSSIBLY: lazy execution
  • POSSIBLY: sub-graph compilation (CUDA kernel fusion + no-GIL execution)

Function implementation

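from torch.autograd import Function  # the new-style Function base class proposed here

class Add(Function):
    @staticmethod
    def forward(ctx, input1, scalar, input2): # supports arbitrary arguments
        ctx.scalar = scalar
        return input1 + scalar * input2

    @staticmethod
    def backward(ctx, grad_output_var):
        # one gradient per forward() input; the non-Variable scalar gets None
        return grad_output_var, None, ctx.scalar * grad_output_var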

Behind the scenes, this gets transformed into roughly the following:

class Function(CFunction):

    @classmethod
    def apply(cls, *args, **kwargs): # this is used to evaluate the function
        ctx = new_ctx()
        outputs = cls.forward(ctx, *args, **kwargs)
        # wrap the outputs in Variables whose grad_fn is an instance of the
        # backward class generated for cls (here: Add.AddBackward)
        return cls.wrap_output(ctx, outputs, cls.AddBackward)


class Add(Function):

    @staticmethod
    def forward(ctx, input1, scalar, input2):
        ctx.scalar = scalar
        return input1 + scalar * input2

    @staticmethod
    def backward(ctx, grad_output_var):
        return grad_output_var, None, ctx.scalar * grad_output_var

    class AddBackward(CFunction):  # generated backward-graph node

        def __init__(self, ctx):
            self.ctx = ctx

        def apply(self, *args, **kwargs):
            # use the ctx saved on this node, not a free variable
            return Add.backward(self.ctx, *args, **kwargs)


add = Add.apply # can be used like a regular Python function

NOTE: forward() won't recursively search its arguments; e.g. Variables nested inside a dict won't be found.
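The transformation can be modelled in a few lines of plain Python. This is a sketch under stated assumptions, not the proposed C implementation: `Variable`, `Ctx` and `BackwardNode` are hypothetical stand-ins, floats play the role of tensors, and only positional arguments with a single output are handled. It also reproduces the NOTE: only top-level arguments are scanned for Variables.

```python
class Variable:
    """Hypothetical stand-in for an autograd Variable: data plus a grad_fn."""
    def __init__(self, data, grad_fn=None):
        self.data = data
        self.grad_fn = grad_fn


class Ctx:
    """Plain attribute bag shared between forward() and backward()."""
    pass


class BackwardNode:
    """Backward-graph node: remembers the Function class, its ctx and
    which top-level arguments were Variables."""
    def __init__(self, fn_cls, ctx, inputs):
        self.fn_cls = fn_cls
        self.ctx = ctx
        self.inputs = inputs

    def apply(self, grad_output):
        return self.fn_cls.backward(self.ctx, grad_output)


class Function:
    @classmethod
    def apply(cls, *args):
        ctx = Ctx()
        # Shallow scan: only top-level positional arguments are checked, so a
        # Variable nested inside a dict or list is passed through untouched.
        inputs = [a for a in args if isinstance(a, Variable)]
        raw = [a.data if isinstance(a, Variable) else a for a in args]
        out = cls.forward(ctx, *raw)
        return Variable(out, grad_fn=BackwardNode(cls, ctx, inputs))


class Add(Function):
    @staticmethod
    def forward(ctx, input1, scalar, input2):
        ctx.scalar = scalar
        return input1 + scalar * input2

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient per forward() input; the scalar gets None
        return grad_output, None, ctx.scalar * grad_output


add = Add.apply  # used like a regular Python function
```

For example, `add(Variable(2.0), 3.0, Variable(4.0))` holds `14.0`, and calling `.grad_fn.apply(1.0)` on it returns `(1.0, None, 3.0)`.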


Functions differentiable only once

This is nice because it's very similar to the current API.

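class Add(Function):
    @staticmethod
    def forward(ctx, input1, scalar, input2):
        # scipy.fn is a placeholder for any computation autograd can't trace
        return scipy.fn(input1, scalar, input2)

    @staticmethod
    @differentiable_once
    def backward(ctx, grad_output_tensor):
        # receives plain tensors, so it can call non-differentiable code too
        return scipy.fn(grad_output_tensor)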

A possible implementation:

import functools


class NonDifferentiable(CFunction):
    def apply(self, *args, **kwargs):
        raise RuntimeError("non differentiable function! remove @differentiable_once!")


def differentiable_once(fn):
    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        targs, tkwargs = ... # iterate over all arguments and unpack them into tensors
        outputs = fn(ctx, *targs, **tkwargs)
        # outputs get a grad_fn that raises if anyone differentiates them again
        return Function.wrap_output(None, outputs, NonDifferentiable)
    return wrapper
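A self-contained toy version of the decorator shows the intended behavior: backward() sees raw values, and its outputs carry a grad_fn that refuses a second differentiation. It assumes a hypothetical `Variable` class with a `.data` field and unpacks only flat (non-nested) arguments; `square_backward` is a made-up example.

```python
import functools


class Variable:
    """Hypothetical stand-in for an autograd Variable."""
    def __init__(self, data, grad_fn=None):
        self.data = data
        self.grad_fn = grad_fn


class NonDifferentiable:
    """grad_fn attached to the outputs of a differentiable-once backward()."""
    def apply(self, *args, **kwargs):
        raise RuntimeError("non differentiable function! remove @differentiable_once!")


def _unwrap(x):
    # Replace a Variable by its raw data; leave everything else alone.
    return x.data if isinstance(x, Variable) else x


def differentiable_once(fn):
    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        targs = [_unwrap(a) for a in args]
        tkwargs = {k: _unwrap(v) for k, v in kwargs.items()}
        outputs = fn(ctx, *targs, **tkwargs)
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        # Every output gets a grad_fn that raises on a second differentiation.
        wrapped = tuple(Variable(o, grad_fn=NonDifferentiable()) for o in outputs)
        return wrapped if len(wrapped) > 1 else wrapped[0]
    return wrapper


@differentiable_once
def square_backward(ctx, grad_output):
    # ctx.x would have been saved by forward(); grad_output arrives as a raw float
    return 2 * ctx.x * grad_output
```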

Variable fields

  • grad - None unless it holds a (sparse or dense) gradient. In most cases it is freed once per iteration, by model.zero_grad() or optimizer.zero_grad().
  • grad_fn - graph trace of the derivative. None if volatile=True or if the Variable is a leaf.
  • creator - None unless the global trace-saving switch is on, in which case it holds a graph trace of the function that created the Variable. That computation can then be replayed, analyzed, displayed or compiled.

class Variable(...):
    def __init__(self, data, requires_grad=False, volatile=False):
        pass

with torch.autograd.save_traces():
    ... # all Variables created in this context will have their .creator saved
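One plausible way to implement the global switch is a module-level flag toggled by a context manager. In this sketch `save_traces`, `Variable` and `mul` are hypothetical, and `creator` is simply an `(op name, inputs)` tuple; a real trace would store richer node objects, and a plain global is not thread-safe (thread-local state would be the obvious refinement).

```python
import contextlib

_save_traces = False  # hypothetical global switch


@contextlib.contextmanager
def save_traces():
    """Turn trace recording on for the duration of the with-block."""
    global _save_traces
    previous = _save_traces
    _save_traces = True
    try:
        yield
    finally:
        _save_traces = previous  # restore, so nested blocks behave


class Variable:
    def __init__(self, data, creator=None):
        self.data = data
        # Only keep the trace of the producing op while the switch is on.
        self.creator = creator if _save_traces else None


def mul(a, b):
    return Variable(a.data * b.data, creator=("mul", a, b))
```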

Wanted functionality, revisited:

  • saving graph traces - Variable.creator
  • full control over buffers - everything is contained in ctx
  • breaking up the graph when it's not needed anymore - no problems here
  • POSSIBLY: lazy execution - builds on point 1 (saved traces)
  • POSSIBLY: sub-graph compilation (CUDA kernel fusion + no-GIL execution) - builds on point 1
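The last two items both rely on the saved trace. As a sketch of why, the recursive replay below re-executes a recorded graph with new leaf values; a lazy or compiling backend would walk the same structure but could defer or fuse the ops instead of running them one by one. All names (`Variable`, `apply_op`, `replay`, `OPS`) are hypothetical, and shared subexpressions are recomputed rather than memoized to keep it short.

```python
import operator

OPS = {"add": operator.add, "mul": operator.mul}


class Variable:
    """Hypothetical Variable that always records its creator as (op, inputs)."""
    def __init__(self, data, creator=None):
        self.data = data
        self.creator = creator


def apply_op(name, *inputs):
    # Run the op eagerly and remember how the result was produced.
    return Variable(OPS[name](*(v.data for v in inputs)), creator=(name, inputs))


def replay(var, new_leaf_values):
    """Re-execute the recorded trace, substituting new values for leaves.

    new_leaf_values maps id(leaf Variable) -> replacement data; leaves not
    in the mapping keep their original data.
    """
    if var.creator is None:
        return new_leaf_values.get(id(var), var.data)
    name, inputs = var.creator
    return OPS[name](*(replay(v, new_leaf_values) for v in inputs))
```

For example, with `a, b = Variable(2), Variable(5)` and `out = apply_op("add", apply_op("mul", a, b), a)`, `out.data` is `12`, and `replay(out, {id(a): 3})` recomputes `3 * 5 + 3`, giving `18`.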

Possible extensions:

  • graph compression - if we have a D -> S -> D computation (D = deterministic, S = stochastic subgraph), we can compress the second D part into a single node (in the backward graph) if it doesn't require grad.

Context syntax:

# forward
ctx += [t1, t2]
ctx.non_differentiable += (t3,)

# backward
t1, t2 = ctx

or a more verbose one:

# forward
ctx.t1 = t1
ctx.t2 = t2
ctx.mark_non_differentiable(t3)

# backward
return grad_output * ctx.t1 + ctx.t2
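Both syntaxes are cheap to support on a single ctx object. The `Ctx` class below is a hypothetical sketch: `+=` appends to a list of saved tensors via `__iadd__`, iteration unpacks them in order via `__iter__`, and plain attribute assignment covers the verbose form. Floats stand in for tensors.

```python
class Ctx:
    """Hypothetical ctx supporting both proposed syntaxes."""

    def __init__(self):
        self._saved = []
        self.non_differentiable = ()

    # concise syntax: ctx += [t1, t2]
    def __iadd__(self, tensors):
        self._saved.extend(tensors)
        return self  # `ctx += ...` rebinds the name to this same object

    # concise syntax: t1, t2 = ctx
    def __iter__(self):
        return iter(self._saved)

    # verbose syntax: ctx.mark_non_differentiable(t3)
    def mark_non_differentiable(self, *tensors):
        self.non_differentiable += tensors


def forward(ctx, t1, t2, t3):
    ctx += [t1, t2]
    ctx.non_differentiable += (t3,)
    return t1 + t2


def backward(ctx, grad_output):
    t1, t2 = ctx
    return grad_output * t1, grad_output * t2
```

The concise form keeps saved tensors in a single ordered container the implementation can inspect or free as a group; the verbose form is more self-documenting but scatters them across arbitrary attributes.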

ebetica commented Mar 7, 2017

Just an idea, but is it possible to operate on pure Python/C functions instead of member functions of classes, since each Function has a static forward and backward anyway? There might be some speed savings, since we'd just be passing function pointers around.


apaszke commented Mar 25, 2017

Sorry, I didn't get a notification about your comment 😕

Yeah, you can do that. It's even in one of the snippets:

add = Add.apply # can be used like a regular Python function

Still, I benchmarked object instantiation plus a method call against a plain static-method call, and performance was very similar in both cases.
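The benchmark code itself isn't part of the thread. A comparable micro-benchmark using the standard-library `timeit` module might look like the sketch below; `AddBackward` and `Add` here are toy stand-ins (a float plays the role of ctx), not the real classes, and absolute timings vary by machine and Python version, so no numbers are claimed.

```python
import timeit


class AddBackward:
    """Object-based node: instantiate with ctx, then call a method."""
    def __init__(self, ctx):
        self.ctx = ctx

    def apply(self, grad_output):
        return grad_output * self.ctx


class Add:
    """Function-pointer style: ctx is passed to a static method."""
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx


def via_object():
    return AddBackward(2.0).apply(1.0)


def via_static():
    return Add.backward(2.0, 1.0)


N = 100_000
t_object = timeit.timeit(via_object, number=N)
t_static = timeit.timeit(via_static, number=N)
print(f"object + method: {t_object:.4f}s, static method: {t_static:.4f}s")
```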
