Skip to content

Instantly share code, notes, and snippets.

@bartvm
Last active May 19, 2017 21:35
Show Gist options
  • Save bartvm/404f769b98bda7837dafe5fb730d17b0 to your computer and use it in GitHub Desktop.
import ast
import collections
import inspect
import numbers
import textwrap
import numpy
# AST nodes for `_stack.push` / `_stack.pop`. These are spliced into the
# code templates below so that the generated primal can save intermediate
# values and the generated adjoint can restore them.
PUSH = ast.Attribute(value=ast.Name(id='_stack', ctx=ast.Load()),
                     attr='push', ctx=ast.Load())
POP = ast.Attribute(value=ast.Name(id='_stack', ctx=ast.Load()),
                    attr='pop', ctx=ast.Load())
def parse_function(fn):
    """Parse *fn*'s source code and return it as an `ast.Module`.

    The source is dedented first so that functions defined at any
    indentation level (e.g. nested or method definitions) parse cleanly.
    """
    source = inspect.getsource(fn)
    return ast.parse(textwrap.dedent(source))
class NodeReverse(object):
    """Generate a primal and adjoint for a given AST tree.

    Notes
    -----
    In principle, this class simply walks the AST recursively and for each node
    returns a new primal and an adjoint.

    A limited amount of communication happens through the state of the
    class. Assign statements set `current_target` so that the adjoint of the
    right hand side knows what gradient to read. On the other hand, right
    hand side expressions set `current_partials` to tell assignment
    statements what variables the partials were written to.
    """

    def visit(self, node):
        # Dispatch on the node's class name, mirroring ast.NodeVisitor.
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    @staticmethod
    def create_grad(node):
        """Given a variable, create variable name for the gradient.

        WARNING: This returns an invalid node, with the `ctx` attribute
        missing. It is assumed that this attribute is filled in later (e.g.
        by the `replace` function).
        """
        if not isinstance(node, ast.Name):
            raise TypeError
        return ast.Name(id='d' + node.id)

    @staticmethod
    def create_var(id_):
        """Method to create a named variable. Used for temporaries."""
        return ast.Name(id=id_, ctx=ast.Load())

    def visit_FunctionDef(self, node):
        # TODO Change function signatures to receive stack
        # TODO Change adjoint signature to take output of primal and its
        # initial gradient
        pass

    def visit_statements(self, nodes):
        """Generate the adjoint of a series of statements.

        Primal statements keep source order; adjoint statements are
        accumulated in reverse (deque extended on the left), since
        reverse-mode differentiation runs the tape backwards.
        """
        primals, adjoints = [], collections.deque()
        for node in nodes:
            primal, adjoint = self.visit(node)
            primals.extend(primal)
            adjoints.extendleft(adjoint[::-1])
        return primals, list(adjoints)

    def visit_For(self, node):
        # `for ... else` clauses are not supported.
        assert not node.orelse
        primal_body, adjoint_body = self.visit_statements(node.body)

        # The primal counts iterations and pushes the trip count after the
        # loop, so the adjoint knows how many times to replay its body.
        # NOTE Template bodies get `#` comments only: a docstring would be
        # parsed as a statement and leak into the generated code.
        def primal_template(body, iter_, target, push):
            i = 0
            for target in iter_:
                i += 1
                body
            push(i)

        primal = replace(primal_template, body=primal_body, push=PUSH,
                         target=node.target, iter_=node.iter)

        # The adjoint pops the trip count and runs the reversed body that
        # many times.
        def adjoint_template(body, pop):
            i = pop()
            for _ in range(i):
                body

        adjoint = replace(adjoint_template, body=adjoint_body, pop=POP)
        return primal, adjoint

    def visit_BinOp(self, node):
        # The primal of a binary operation is the node itself; the adjoint
        # writes the two partial derivatives into temporaries __dx / __dy,
        # reading the incoming gradient dz of the current assignment target.
        adjoint_templates = {}

        # d(x * y): dx = dz * y, dy = dz * x
        def adjoint_Mult_template(x, y, dx, dy, dz):
            dx = dz * y
            dy = dz * x
        adjoint_templates[ast.Mult] = adjoint_Mult_template

        # d(x + y): both partials are just dz
        def adjoint_Add_template(x, y, dx, dy, dz):
            dx = dz
            dy = dz
        adjoint_templates[ast.Add] = adjoint_Add_template

        # d(x / y): dx = dz / y, dy = -dz * x / y**2
        def adjoint_Div_template(x, y, dx, dy, dz):
            dx = dz / y
            dy = -dz * x / y ** 2
        adjoint_templates[ast.Div] = adjoint_Div_template

        op = type(node.op)
        if op not in adjoint_templates:
            raise ValueError("unknown binary operator")
        # Tell the enclosing Assign statement which temporaries hold the
        # partials for each operand.
        self.current_partials = {
            node.left: self.create_var('__dx'),
            node.right: self.create_var('__dy')
        }
        return node, replace(
            adjoint_templates[op],
            x=node.left, y=node.right,
            dx=self.current_partials[node.left],
            dy=self.current_partials[node.right],
            dz=self.create_grad(self.current_target))

    def visit_Assign(self, node):
        # Only single-target assignments to plain names are supported.
        if len(node.targets) != 1:
            raise ValueError
        if isinstance(node.targets[0], ast.Tuple):
            if not isinstance(node.value, ast.Name):
                raise ValueError("can only unpack variables")
            # TODO Pack the gradients into a tuple
            raise ValueError("no support for tuple assignments")
        if not isinstance(node.targets[0], ast.Name):
            raise ValueError("can only assign to names")
        # Extract the target and store it in the state so that the
        # right hand side templates can use it
        target = node.targets[0]
        self.current_target = target
        primal_rhs, adjoint_rhs = self.visit(node.value)

        # The primal saves the target's old value before overwriting it.
        # NOTE We simplify things here by EAFP. Ideally each variable that is
        # pushed at any point should be set to None at the beginning
        def primal_template(target, primal_rhs, push):
            try:
                push(target)
            except NameError:
                push(None)
            target = primal_rhs

        primal = replace(primal_template, target=target,
                         primal_rhs=primal_rhs, push=PUSH)

        # NOTE For each partial gradient from the rhs we want to accumulate
        # it into the existing gradient; this is the template for that
        # NOTE EAFP approach again; gradients should be initialized beforehand
        def accumulate_template(in_grad, partial_grad):
            try:
                in_grad = add_grad(in_grad, partial_grad)
            except NameError:
                in_grad = partial_grad

        gradient_accumulation = []
        for partial in self.current_partials:
            gradient_accumulation.extend(replace(
                accumulate_template, in_grad=self.create_grad(partial),
                partial_grad=self.current_partials[partial]))

        # The final adjoint restores the input (pop), stores the partials
        # in temporary variables, resets the gradient w.r.t. output,
        # and finally updates the gradients
        def adjoint_template(target, adjoint_rhs, target_grad,
                             gradient_accumulation, pop):
            target = pop()
            adjoint_rhs
            target_grad = 0
            gradient_accumulation

        adjoint = replace(adjoint_template, target=target,
                          adjoint_rhs=adjoint_rhs,
                          gradient_accumulation=gradient_accumulation,
                          target_grad=self.create_grad(target), pop=POP)
        # Reset the state
        self.current_target = None
        self.current_partials = None
        return primal, adjoint

    def generic_visit(self, node):
        # Any node without an explicit visit_* handler is unsupported.
        raise ValueError("unknown node type")
class ReplaceTransformer(ast.NodeTransformer):
    """Replace variables with AST nodes"""

    def __init__(self, replacements):
        # Maps placeholder identifier -> AST node (or list of statements).
        self.replacements = replacements

    def visit_Name(self, node):
        if node.id not in self.replacements:
            return node
        replacement = self.replacements[node.id]
        # Use the replacement node in the same context (load/store) as the
        # placeholder it substitutes, when the node supports a context.
        if isinstance(replacement, ast.AST) and 'ctx' in replacement._fields:
            replacement.ctx = node.ctx
        return replacement
def replace(fn, **replacements):
    """Replace placeholders in a Python template (quote).

    One special thing happens: If a replacement node has a ctx attribute, it
    is made to match the ctx attribute of the variable it is replacing.

    Parameters
    ----------
    fn : function
        A function used as a metaprogramming template.
    replacements : dict
        A mapping from the variable names of the function's arguments to (lists
        of) AST nodes that these variables will be replaced with wherever they
        appear in the function body. A replacement can be a list, in which case
        it will be merged into the list of statements containing the node.

    Returns
    -------
    body : list
        A list of statements in the form of AST nodes.
    """
    tree = parse_function(fn).body[0]
    # Every template argument must be supplied, and nothing extra.
    placeholders = set(arg.arg for arg in tree.args.args)
    if set(replacements) != placeholders:
        raise ValueError("too many or few replacements")
    transformed = ReplaceTransformer(replacements).visit(tree)
    return transformed.body
def add_grad(left, right):
    """Recursively add the gradient of e.g. tuples.

    Parameters
    ----------
    left : None, number, numpy.ndarray, or tuple
        The accumulated gradient so far; `None` means "no gradient yet".
    right : number, numpy.ndarray, or tuple
        The partial gradient to fold in; must not be `None`.

    Returns
    -------
    The combined gradient, with the same structure as the inputs.

    Raises
    ------
    TypeError
        If the two gradients have incompatible or unsupported types.
    """
    # If the gradient is undefined, then we simply return the rhs
    # NOTE This is more efficient than initializing empty gradients and
    # adding to them, since we could be adding to large matrix of zeros then
    if left is None:
        return right
    assert right is not None
    # Accept any mix of scalars and arrays: int + float and array + scalar
    # gradients arise naturally and add fine with `+`.
    if isinstance(left, (numpy.ndarray, numbers.Number)) and \
            isinstance(right, (numpy.ndarray, numbers.Number)):
        return left + right
    if type(left) != type(right):
        raise TypeError("incompatible gradients")
    if isinstance(left, tuple):
        # Recurse elementwise. The previous `lelem + relem` CONCATENATED
        # nested tuples instead of adding them, and could not handle
        # None (undefined) elements.
        return tuple(add_grad(lelem, relem)
                     for lelem, relem in zip(left, right))
    raise TypeError("unknown gradient type")
def f(x):
    # Example primal: square x into a local (no return). Only '#' comments
    # here: a docstring would become a statement in the parsed AST and be
    # rejected by NodeReverse.generic_visit.
    y = x * x
def g(x):
    # Example with a loop; this is the function differentiated by the
    # __main__ demo. '#' comments only — a docstring would be parsed as a
    # statement and rejected by NodeReverse.generic_visit.
    for i in range(10):
        y = x * x
if __name__ == "__main__":
    # Demo: differentiate g()'s body and pretty-print the generated code.
    stmts = parse_function(g).body[0].body
    primal, adjoint = NodeReverse().visit_statements(stmts)
    import astor
    for label, code in (("PRIMAL", primal), ("ADJOINT", adjoint)):
        print(label)
        print(astor.to_source(ast.Module(body=code)))
# PRIMAL
# i = 0
# for i in range(10):
#     i += 1
#     try:
#         _stack.push(y)
#     except NameError:
#         _stack.push(None)
#     y = x * x
# _stack.push(i)
# ADJOINT
# i = _stack.pop()
# for _ in range(i):
#     y = _stack.pop()
#     __dx = dy * x
#     __dy = dy * x
#     dy = 0
#     try:
#         dx = add_grad(dx, __dx)
#     except NameError:
#         dx = __dx
#     try:
#         dx = add_grad(dx, __dy)
#     except NameError:
#         dx = __dy
@alexbw
Copy link

alexbw commented May 19, 2017

def adjoint_Mult_template(x,y,z):
    d[x] = d[z] * y
    d[y] = d[z] * x

I would also be curious to see how you specialize Call(): whereas BinOp has stereotyped args, Call grads will have variable args. So you'll have to break apart the args and kwargs coming from the primal Call node and feed them in somehow (maybe a place for *args and **kwargs in the quote).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment