Instantly share code, notes, and snippets.

Embed
What would you like to do?
#!/usr/bin/env python
# JAD: Joseph's Automatic Differentiation
from collections import deque
class Graph(object):
    """A minimal reverse-mode automatic-differentiation graph.

    Nodes are stored in parallel lists indexed by an integer node id, assigned
    in creation order (so the lists are always in topological order).  A
    forward pass evaluates every node; a reverse pass accumulates adjoints
    (gradients of one target node with respect to every node).
    """

    def __init__(self):
        self.names = list()        # node id -> human-readable name.
        self.operations = list()   # node id -> callable(inputs) computing the node's value.
        self.derivatives = list()  # node id -> list of callables, one per argument: d(node)/d(arg_i).
        self.node_inputs = list()  # node id -> indices of the nodes it consumes.
        self.shapes = list()       # node id -> recorded shape tuple.
        self.graph_inputs = list() # Indices of the graph's input nodes.
        self.forward = list()      # Cleared on forward pass.
        self.adjoint = list()      # Cleared on reverse pass.

    def get_output(self, input_set, node=-1):
        """Run a forward pass and return the value of `node` (default: the last node).

        input_set maps input-node names to their values.  Creation order is
        topological, so each op may read earlier entries of self.forward.
        """
        self.forward = list()
        for op in self.operations:
            self.forward.append(op(input_set))
        return self.forward[node]

    def get_gradient(self, input_set, node, forward_data=None):
        """Return a list of d(node)/d(n) for every node n, indexed by node id.

        If forward_data is given it is used as the forward pass; otherwise a
        fresh forward pass is run.

        Bug fix vs. the original: the contribution along an edge must be
        adjoint[parent] * d(parent)/d(child) for that specific argument
        position.  The original multiplied by the CHILD's own derivatives
        (summed over all of its arguments), producing wrong gradients — e.g.
        [3, 2, 2, 1] instead of the correct [7, 2, 2, 1] for b = (x+y)*x at
        x=2, y=3.  We instead enumerate paths from `node` down to each
        descendant and sum the per-path products of edge derivatives (the
        chain rule), which also correctly handles a node consumed twice by
        the same parent (e.g. x*x).
        """
        if forward_data is not None:
            self.forward = forward_data
        else:
            self.forward = list()
            for op in self.operations:
                self.forward.append(op(input_set))
        # Initialize all adjoints to 0; the seed path to `node` has weight 1.
        self.adjoint = [0.0] * len(self.forward)
        # Each entry is (target_node, product of edge derivatives along one
        # path from `node` to target_node).
        gradient_stack = deque()
        gradient_stack.append((node, 1.0))
        while gradient_stack:  # While not empty.
            current_node, path_weight = gradient_stack.popleft()
            self.adjoint[current_node] += path_weight
            for arg_index, input_node in enumerate(self.node_inputs[current_node]):
                dop = self.derivatives[current_node][arg_index]
                gradient_stack.append((input_node, path_weight * dop(input_set)))
        return self.adjoint

    def get_shape(self, node):
        """Return the recorded shape of `node`."""
        return self.shapes[node]

    def add_input(self, name, shape):
        """Register a graph input named `name`; returns its node index."""
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: inputs[name])
        # Inputs consume no other nodes, so this entry is never indexed by
        # get_gradient; kept for compatibility with the original layout.
        self.derivatives.append([lambda inputs: 1])
        self.node_inputs.append([])
        self.graph_inputs.append(index)
        self.shapes.append(shape)
        return index

    def add_add(self, name, left, right):
        """Add a node computing forward[left] + forward[right]; returns its index."""
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: self.forward[left] + self.forward[right])
        # d(l + r)/dl = 1 and d(l + r)/dr = 1.
        self.derivatives.append([lambda inputs: 1, lambda inputs: 1])
        self.node_inputs.append([left, right])
        self.shapes.append(self.get_shape(left))
        return index

    def add_multiply(self, name, left, right):
        """Add a node computing forward[left] * forward[right]; returns its index."""
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: self.forward[left] * self.forward[right])
        # d(l * r)/dl = r and d(l * r)/dr = l, read from the forward pass.
        self.derivatives.append([lambda inputs: self.forward[right], lambda inputs: self.forward[left]])
        self.node_inputs.append([left, right])
        self.shapes.append(self.get_shape(left))
        return index
if __name__ == "__main__":
    # Demo: build b = (x + y) * x and differentiate it.
    g = Graph()
    x = g.add_input("x", (1, 1))
    y = g.add_input("y", (1, 1))
    a = g.add_add("a", x, y)        # a = x + y
    b = g.add_multiply("b", a, x)   # b = a * x
    input_map = {'x': 2, 'y': 3}
    print(g.get_output(input_map))  # (2 + 3) * 2 = 10
    # Mathematically: db/dx = 2x + y = 7, db/dy = x = 2, db/da = x = 2,
    # db/db = 1 (the original comment's "3, 2, 2, 1" was incorrect).
    print(g.get_gradient(input_map, b))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment