Experiments with Reverse Mode Auto-Differentiation

Automatic differentiation (autodiff) is a technique for computing the gradients of arbitrary computer programs. Symbolic differentiation can produce an exponential blow-up in the size of the expressions it generates. Numerical differentiation only estimates the gradient, and it has to re-run the target program many times (at least once per input variable). Reverse-mode autodiff, by contrast, computes the gradient of a scalar output with respect to every input, to floating-point precision, from a single forward pass followed by a single backward pass.

Reverse-mode automatic differentiation, especially in its imperative form, has recently gained popularity thanks to projects like TF Eager, PyTorch, and HIPS Autograd. These libraries exploit the operator overloading available in many languages to build data structures that incrementally record the operations needed to compute gradients.

JavaScript lacks operator overloading (a + b cannot be made to record itself), so defining special numeric data structures loses much of its natural appeal. Rather than thinking about data structures, we can think about functions, how they compose, and how composition affects their gradients.

Ideally, we could build an auto-differentiation engine that is itself completely oblivious to numbers and arithmetic, or to any data structure at all. Each atomic primitive function would compute its value (mapping its arguments to its output) together with its gradient: a set of functions, one per argument, each mapping the gradient of the output to the gradient of that argument.

This opens the door to generalizing differentiation to computer programs beyond the classical numerical data types, possibly even into the world of strings and objects.

// Here we define primitive operations with known gradients.
// User-defined functions built on top of these primitives
// can have their gradients derived automatically.

const sum = x => x.reduce((a, b) => a + b, 0) // helper
const Add = (a, b) => [ a + b, [ g => sum(g), g => sum(g) ]]
const Subtract = (a, b) => [ a - b, [ g => sum(g), g => -sum(g) ]]
const Multiply = (a, b) => [ a * b, [ g => b * sum(g), g => a * sum(g) ]]
const Divide = (a, b) => [ a / b, [ g => sum(g)/b, g => -a * sum(g)/(b*b) ]]
// d(a**b)/da = b * a**(b-1);  d(a**b)/db = a**b * ln(a), defined for a > 0
const Power = (a, b) => [a**b, [ g => sum(g) * b * a ** (b - 1), g => sum(g) * (a**b) * Math.log(a) ]]
const UnaryNeg = (a) => [-a, [ g => -sum(g) ]]
const UnaryPlus = (a) => [+a, [ g => +sum(g) ]]
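A cheap way to gain confidence in a table like this is to compare each gradient function against a central-difference estimate of the same derivative. The sketch below does that for every primitive; the helper `gradientErrors`, the step size `h`, and the test point (1.5, 2.5) are illustrative choices, not part of the original. Power is written with the factor b in its first gradient, as the power rule requires; a version missing that factor would show an error of order 1 instead of a tiny one.

```javascript
const sum = x => x.reduce((a, b) => a + b, 0)
const Add = (a, b) => [ a + b, [ g => sum(g), g => sum(g) ]]
const Subtract = (a, b) => [ a - b, [ g => sum(g), g => -sum(g) ]]
const Multiply = (a, b) => [ a * b, [ g => b * sum(g), g => a * sum(g) ]]
const Divide = (a, b) => [ a / b, [ g => sum(g)/b, g => -a * sum(g)/(b*b) ]]
const Power = (a, b) => [ a**b, [ g => sum(g) * b * a ** (b - 1), g => sum(g) * (a**b) * Math.log(a) ]]
const UnaryNeg = (a) => [ -a, [ g => -sum(g) ]]
const UnaryPlus = (a) => [ +a, [ g => +sum(g) ]]

// Hypothetical helper: for each argument of a primitive, return the absolute
// difference between its gradient function (called with a single upstream
// gradient of 1) and a central-difference estimate of the same derivative.
function gradientErrors(prim, args, h = 1e-6) {
  const [, grads] = prim(...args)
  return grads.map((grad, i) => {
    const plus = args.slice()
    const minus = args.slice()
    plus[i] += h
    minus[i] -= h
    const numeric = (prim(...plus)[0] - prim(...minus)[0]) / (2 * h)
    return Math.abs(grad([1]) - numeric)
  })
}

const primitives = { Add, Subtract, Multiply, Divide, Power, UnaryNeg, UnaryPlus }
for (const [name, prim] of Object.entries(primitives)) {
  // a > 0 keeps Math.log(a) in Power's second gradient well defined
  const args = prim.length === 2 ? [1.5, 2.5] : [1.5]
  console.log(name, gradientErrors(prim, args)) // every entry should be tiny
}
```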

Let's look at the function Add(a, b), which takes two numbers and returns their sum.

Let's call our gradient-aware version Add$(a, b): a function with the same arguments (same number, same types). Rather than returning just the value, it returns a tuple whose first element is the return value of Add(a, b) and whose second element is a list of gradient functions (length 2, one for each argument of Add). In the primitive table above, the constant named Add is already this gradient-aware version.

Add(num a, num b) -> num
Add$(num a, num b) -> [ num, [ grad, grad ] ]
grad([num]) -> num

These gradient functions take an array of numbers (we'll see why it is an array later on) and return a number. In effect, each one maps the gradient with respect to the function's output back to the gradient with respect to one of its arguments.
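To make the contract concrete, here is the Add primitive from the table called directly. The upstream lists passed to the gradient functions are made-up examples standing for an output that was used once or twice.

```javascript
const sum = x => x.reduce((a, b) => a + b, 0)
const Add = (a, b) => [ a + b, [ g => sum(g), g => sum(g) ]]

const [value, [gradA, gradB]] = Add(2, 3)
console.log(value)           // 5
console.log(gradA([1]))      // 1: output used once, upstream gradient 1
console.log(gradB([1, 1]))   // 2: output used twice, contributions summed
console.log(gradA([0.5, 2])) // 2.5
```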

For instance, let's look at the function Sum(vec a) -> num, which (as its signature suggests) maps a list of numbers to a single number. It has one gradient function (for its single argument), which maps a list of numbers (because the output is a number) to a vector (because the input is a vector).

Sum(vec a) -> num
Sum$(vec a) -> [ num, [grad] ]
grad([num]) -> vec
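The article gives only the signature, so here is one way Sum$ could be written in the same [value, [grad]] shape; the name `Sum$` follows the article's notation but the body is our own sketch. Since the derivative of a sum with respect to each element is 1, every element receives the total upstream gradient.

```javascript
const sum = x => x.reduce((a, b) => a + b, 0)

// Hypothetical gradient-aware Sum$: d(sum(a))/d(a[k]) = 1 for every k,
// so the gradient vector repeats the summed upstream gradient.
const Sum$ = a => [ sum(a), [ g => a.map(() => sum(g)) ] ]

const [total, [gradVec]] = Sum$([1, 2, 3])
console.log(total)           // 6
console.log(gradVec([1]))    // [1, 1, 1]
console.log(gradVec([2, 1])) // [3, 3, 3]
```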

So why do gradient functions take a list of values of the function's output type? Computing the gradient means running the program backwards, and when we invert a function we have to account for the fact that any given value may be used multiple times. The grad function therefore receives the gradient contributions from every place its result was used and decides how to combine them.

For ordinary mathematical derivatives, we simply sum all the incoming contributions (one for each time the output was used); this is the multivariable chain rule.
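Here is the fan-out rule worked by hand for z = x * x + x at x = 3, using the primitives from the table. The backward pass is written out step by step (no tape yet) so the three contributions reaching x are visible; the example function is our own choice.

```javascript
const sum = x => x.reduce((a, b) => a + b, 0)
const Add = (a, b) => [ a + b, [ g => sum(g), g => sum(g) ]]
const Multiply = (a, b) => [ a * b, [ g => b * sum(g), g => a * sum(g) ]]

// Forward pass, one primitive at a time.
const x = 3
const [y, [dyFromX1, dyFromX2]] = Multiply(x, x) // y = 9
const [z, [dzFromY, dzFromX]] = Add(y, x)        // z = 12

// Backward pass, seeded with [1] at the output.
const gz = [1]
const gy = [dzFromY(gz)]                         // y is used once
// x is used three times (twice in Multiply, once in Add),
// so its incoming list has three entries.
const gx = [dyFromX1(gy), dyFromX2(gy), dzFromX(gz)]
console.log(z, gx, sum(gx)) // z = 12, gx = [3, 3, 1], sum = 7 = 2x + 1
```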

These definitions follow almost verbatim from calculus.

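For Multiply, consider f(x, y) = x * y:
df/dx = y
df/dy = x
By the chain rule, if g is the gradient flowing into the output, the gradient passed back to x is y * g and the gradient passed back to y is x * g. Since g arrives as a list of contributions, Multiply$(x, y)'s gradient functions are:
[ g => y * sum(g), g => x * sum(g) ]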
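The same recipe yields the other binary entries in the primitive table. For Divide and Power the partial derivatives are as follows, and each gradient function multiplies the relevant one by sum(g):

```latex
\frac{\partial}{\partial a}\,\frac{a}{b} = \frac{1}{b}, \qquad
\frac{\partial}{\partial b}\,\frac{a}{b} = -\frac{a}{b^{2}}

\frac{\partial}{\partial a}\,a^{b} = b\,a^{b-1}, \qquad
\frac{\partial}{\partial b}\,a^{b} = a^{b}\ln a
```

The second Power derivative involves ln a, so it is only defined for a > 0; in JavaScript, Math.log returns NaN for negative arguments and -Infinity at 0, and the gradient inherits that.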

The autodiff algorithm needs to be able to trace the sequence of operations taken to compute the output of a function (and find a path back to the function arguments).

To do that, we transform the program so that this path is recorded at runtime.

We'll be looking at the following function, which sets c to its first argument n and then multiplies c by n three times inside a for loop. (Despite its name, it therefore returns n**4, not n**3.)

function cube(n){
    let c = n;
    for(let i = 0; i < 3; i++){
        c *= n;
    }
    return c
}

First, we apply a set of simple transformations to shrink the set of syntactic features we have to handle. For instance, i++ is replaced with (i = i + 1) - 1, since i++ evaluates to the value of i before it is incremented. Compound assignments like c *= n are rewritten as c = c * n. This gives us a function that looks like this:

function cube(n){
    let c = n;
    for(let i = 0; i < 3; (i = i + 1) - 1){
        c = c * n;
    }
    return c
}
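A quick check that the two rewrites preserve both the value and the side effect. This assumes ordinary numeric counters: for a string such as "5", i++ coerces to a number while i + 1 concatenates, so the rewrite is only equivalent for plain numbers.

```javascript
// i++ versus its desugared form, for a plain number i.
let i = 5
const a = i++             // a is 5, i becomes 6
let j = 5
const b = (j = j + 1) - 1 // b is 5, j becomes 6
console.log(a === b, i === j) // true true

// c *= n versus c = c * n.
const n = 3
let c = 2
c *= n
let d = 2
d = d * n
console.log(c === d) // true
```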

The next step replaces all the differentiable infix operations with function calls: every BinaryExpression such as x + y is replaced with the call Add(x, y). Besides binary infix expressions, unary operations such as negation are transformed, as are property lookups and array indexing.

function cube(n){
    let c = n;
    for(let i = 0; i < 3; (i = Add(i, 1)) - 1){
        c = Multiply(c, n);
    }
    return c
}
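The transformer itself is not shown, so the following is only a hypothetical sketch of this step. It operates on ESTree-shaped nodes (the AST format produced by parsers such as Acorn) built by hand, and handles just identifiers, literals, and binary and unary expressions; property lookups and indexing are left out. The operator tables and `toCalls` are our own names.

```javascript
// Hypothetical operator-to-primitive tables (names match the primitive table).
const BINARY = new Map([['+', 'Add'], ['-', 'Subtract'], ['*', 'Multiply'], ['/', 'Divide'], ['**', 'Power']])
const UNARY = new Map([['-', 'UnaryNeg'], ['+', 'UnaryPlus']])

// Rewrite an ESTree-shaped expression into source text in which every
// differentiable operator has become a call to the matching primitive.
function toCalls(node) {
  switch (node.type) {
    case 'Identifier':
      return node.name
    case 'Literal':
      return JSON.stringify(node.value)
    case 'BinaryExpression': {
      const left = toCalls(node.left)
      const right = toCalls(node.right)
      return BINARY.has(node.operator)
        ? `${BINARY.get(node.operator)}(${left}, ${right})`
        : `(${left} ${node.operator} ${right})` // comparisons etc. stay infix
    }
    case 'UnaryExpression': {
      const arg = toCalls(node.argument)
      return UNARY.has(node.operator)
        ? `${UNARY.get(node.operator)}(${arg})`
        : `(${node.operator} ${arg})`
    }
    default:
      throw new Error(`unsupported node type: ${node.type}`)
  }
}

// -(a + b) * c, written out by hand as an ESTree-style tree.
const id = name => ({ type: 'Identifier', name })
const ast = {
  type: 'BinaryExpression', operator: '*',
  left: {
    type: 'UnaryExpression', operator: '-', prefix: true,
    argument: { type: 'BinaryExpression', operator: '+', left: id('a'), right: id('b') },
  },
  right: id('c'),
}
console.log(toCalls(ast)) // Multiply(UnaryNeg(Add(a, b)), c)

// The loop condition i < 3 is not differentiable and is left alone.
const cmp = { type: 'BinaryExpression', operator: '<', left: id('i'), right: { type: 'Literal', value: 3 } }
console.log(toCalls(cmp)) // (i < 3)
```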

The next stage of the transformation pipeline is the most interesting. It wraps values in small containers that record their provenance, and it inserts Unwrap calls wherever the illusion that these are plain values would otherwise break (for example, in the loop comparison Unwrap(i) < 3). What's left is a program that behaves identically to the original, except that every differentiable operation it performs is recorded.

Looking a bit more closely, there are two functions: Call and Unwrap. Unwrap is simpler, so we'll start with it. Given a wrapped value, it returns the original value it holds; given anything that is not wrapped, it returns that value unchanged.

Call takes as its first argument the function that was meant to be called (its gradient-aware version, such as Add or Multiply). That function is called with the unwrapped arguments, its [value, grads] result is wrapped in a new tape node, and the node records the tape index of each original argument (or -1 for arguments that are plain constants).

const Call = (fn, ...args) => new Wrapped(...fn(...args.map(Unwrap)), args.map(k => k instanceof Wrapped ? k.index : -1));
const Unwrap = (x) => x instanceof Wrapped ? x.value : x

function cube(n){
    let c = n;
    for(let i = 0; Unwrap(i) < 3; Unwrap(i = Call(Add, i, 1)) - 1){
        c = Call(Multiply, c, n);
    }
    return c
}
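A minimal runnable sketch of this machinery, recording the tape for the hand-transformed expression x * x + x at x = 3 (our own example). Wrapped follows the class in the full listing at the end, but here the tape `Nodes` is a module-level array for simplicity.

```javascript
const sum = x => x.reduce((a, b) => a + b, 0)
const Add = (a, b) => [ a + b, [ g => sum(g), g => sum(g) ]]
const Multiply = (a, b) => [ a * b, [ g => b * sum(g), g => a * sum(g) ]]

const Nodes = [] // the tape: every Wrapped value appends itself here
class Wrapped {
  constructor(value, grad = [], args = []) {
    this.value = value
    this.grad = grad
    this.args = args
    this.index = Nodes.push(this) - 1
  }
}
const Unwrap = x => x instanceof Wrapped ? x.value : x
const Call = (fn, ...args) =>
  new Wrapped(...fn(...args.map(Unwrap)), args.map(k => k instanceof Wrapped ? k.index : -1))

// Transformed version of: x => x * x + x
const f = x => Call(Add, Call(Multiply, x, x), x)
const out = f(new Wrapped(3))
console.log(Unwrap(out)) // 12
for (const node of Nodes) console.log(node.index, node.value, node.args)
```

The tape ends up with three nodes: node 0 holds the input 3 with no arguments, node 1 holds 9 with arguments [0, 0], and node 2 holds 12 with arguments [1, 0].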

While the transformed function runs, it assembles a list (called a tape, or Wengert list) of the operations applied to different variables, or nodes. To compute the gradient of the output with respect to some variable, we step through this sequence of operations in reverse.

We create an empty array called Signals, which collects all the available gradient information. The gradient function receives an incoming gradient (when differentiating a scalar output, this is the single-element list [1]), which is stored in Signals at slot result.index.

Iterating through the tape in reverse, we skip any node without a Signals entry: such nodes belong to branches of the computation DAG that do not contribute to the final output, so they can be pruned. For each remaining node, we visit each of its arguments, skipping constants (index -1). If the argument's Signals slot is empty, we allocate an array for it. Then we call the node's grad function for that argument position with the node's own incoming signals and push the result into the argument's slot.

Once that's done, slot 0 of Signals (the input argument, which is the first value wrapped and therefore sits at tape index 0) holds a list of contributions that sum to the gradient of the function with respect to that argument.

return [ result.value, [(g = [1]) => {
    let Signals = [];
    Signals[result.index] = g;
    for(let i = result.index; i >= 0; i--){
        if(!Signals[i]) continue;
        let node = Nodes[i];
        for(let j = 0; j < node.args.length; j++){
            let arg = node.args[j];
            if(arg < 0) continue;
            if(!Signals[arg]) Signals[arg] = [];
            Signals[arg].push(node.grad[j](Signals[node.index]))
        }
    }
    return Signals[0]
}]]
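The full listing below hard-codes cube. As a sketch of how the same pieces could be reused, here is a hypothetical `gradAware` helper (not from the original) that takes any already-transformed single-argument program, written as `(Call, Unwrap) => (x => ...)`, and returns a function with the same [value, [grad]] shape as the primitives. The backward pass is the one described above.

```javascript
const sum = x => x.reduce((a, b) => a + b, 0)
const Add = (a, b) => [ a + b, [ g => sum(g), g => sum(g) ]]
const Multiply = (a, b) => [ a * b, [ g => b * sum(g), g => a * sum(g) ]]

// Hypothetical helper: builds a fresh tape for every evaluation.
function gradAware(program) {
  return x0 => {
    const Nodes = []
    class Wrapped {
      constructor(value, grad = [], args = []) {
        this.value = value
        this.grad = grad
        this.args = args
        this.index = Nodes.push(this) - 1
      }
    }
    const Unwrap = x => x instanceof Wrapped ? x.value : x
    const Call = (fn, ...args) =>
      new Wrapped(...fn(...args.map(Unwrap)), args.map(k => k instanceof Wrapped ? k.index : -1))

    const input = new Wrapped(x0) // the input is always tape node 0
    const result = program(Call, Unwrap)(input)
    // A plain (unwrapped) result does not depend on the input at all.
    if (!(result instanceof Wrapped)) return [result, [() => []]]

    const backward = (g = [1]) => {
      const Signals = []
      Signals[result.index] = g
      for (let i = result.index; i >= 0; i--) {
        if (!Signals[i]) continue // node does not reach the output
        const node = Nodes[i]
        node.args.forEach((arg, j) => {
          if (arg < 0) return // constant argument: nothing to propagate
          if (!Signals[arg]) Signals[arg] = []
          Signals[arg].push(node.grad[j](Signals[i]))
        })
      }
      return Signals[0] || [] // contributions for the input
    }
    return [result.value, [backward]]
  }
}

// x => x * x + x, whose derivative is 2x + 1.
const f = gradAware((Call, Unwrap) => x => Call(Add, Call(Multiply, x, x), x))
const [value, [gradX]] = f(3)
console.log(value, sum(gradX([1]))) // 12 7

// The transformed cube from the walkthrough (which computes n**4).
const cube = gradAware((Call, Unwrap) => n => {
  let c = n
  for (let i = 0; Unwrap(i) < 3; Unwrap(i = Call(Add, i, 1)) - 1) {
    c = Call(Multiply, c, n)
  }
  return c
})
const [v, [gn]] = cube(3)
console.log(v, sum(gn([1]))) // 81 108
```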

Here's the whole thing:

const sum = x => x.reduce((a, b) => a + b, 0) // helper
const Add = (a, b) => [ a + b, [ g => sum(g), g => sum(g) ]]
const Subtract = (a, b) => [ a - b, [ g => sum(g), g => -sum(g) ]]
const Multiply = (a, b) => [ a * b, [ g => b * sum(g), g => a * sum(g) ]]
const Divide = (a, b) => [ a / b, [ g => sum(g)/b, g => -a * sum(g)/(b*b) ]]
const Power = (a, b) => [a**b, [ g => sum(g) * b * a ** (b - 1), g => sum(g) * (a**b) * Math.log(a) ]]
const UnaryNeg = (a) => [-a, [ g => -sum(g) ]]
const UnaryPlus = (a) => [+a, [ g => +sum(g) ]]

function cube(n_){
    let Nodes = []
    class Wrapped {
        constructor(value, grad=[], args=[]){
            this.value = value
            this.grad = grad;
            this.args = args
            this.index = Nodes.push(this) - 1;
        }
    }

    const Call = (fn, ...args) => new Wrapped(...fn(...args.map(Unwrap)), args.map(k => k instanceof Wrapped ? k.index : -1));
    const Unwrap = (x) => x instanceof Wrapped ? x.value : x

    let result = (function cube(n){
        let c = n;
        for(let i = 0; Unwrap(i) < 3; Unwrap(i = Call(Add, i, 1)) - 1){
            c = Call(Multiply, c, n);
        }
        return c
    })(new Wrapped(n_))
    
    return [ result.value, [(g = [1]) => {
        let Signals = [];
        Signals[result.index] = g;
        for(let i = result.index; i >= 0; i--){
            if(!Signals[i]) continue;
            let node = Nodes[i];
            for(let j = 0; j < node.args.length; j++){
                let arg = node.args[j];
                if(arg < 0) continue;
                if(!Signals[arg]) Signals[arg] = [];
                Signals[arg].push(node.grad[j](Signals[node.index]))
            }
        }
        return Signals[0]
    }]]
}

// evaluate cube at 3 (the function computes n**4, so value is 81)
let [ value, [gradn] ] = cube(3)
// backpropagate a unit gradient from the output; dn holds one
// contribution for each place n was used
let dn = gradn([1]) 
// sum the contributions to get d(output)/dn = 4 * 3**3
console.log(sum(dn)) // 108
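As an independent sanity check on that result, here is a sketch that differentiates the original, untransformed function numerically with a central difference. The step size h is an arbitrary choice; the estimate should agree with the autodiff gradient up to a small rounding error.

```javascript
// The untransformed function from the walkthrough (it returns n**4).
function cube(n) {
  let c = n
  for (let i = 0; i < 3; i++) {
    c *= n
  }
  return c
}

// Central-difference estimate of d(cube)/dn at n = 3.
const h = 1e-5
const estimate = (cube(3 + h) - cube(3 - h)) / (2 * h)
console.log(cube(3), estimate) // 81, approximately 108 (= 4 * 3**3)
```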