#!/usr/bin/env python
# JAD: Joseph's Automatic Differentiation
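
# Design note: the graph is a set of parallel lists indexed by node id. For node i,
# names[i], operations[i], derivatives[i], node_inputs[i], and shapes[i] all describe
# that node, and derivatives[i] holds one partial-derivative lambda per entry of
# node_inputs[i]. Because a node can only reference nodes created before it, ascending
# index order is a topological order of the graph; both passes below rely on this.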


class Graph(object):
    def __init__(self):
        self.names = list()
        self.operations = list()
        self.derivatives = list()  # One list per node: the partial derivative of the node with respect to each of its arguments.
        self.node_inputs = list()  # One list per node: the indices of the nodes it takes as arguments.
        self.shapes = list()
        self.graph_inputs = list()  # Indices of the graph's input nodes.
        self.forward = list()  # Rebuilt on each forward pass.
        self.adjoint = list()  # Rebuilt on each reverse pass.

    def get_output(self, input_set, node=-1):
        # Evaluate every node in creation order; a node may only reference earlier
        # nodes, so ascending index order is a valid evaluation order.
        self.forward = list()
        for op in self.operations:
            self.forward.append(op(input_set))
        return self.forward[node]

    def get_gradient(self, input_set, node, forward_data=None):
        if forward_data is not None:
            self.forward = forward_data
        else:
            self.forward = list()
            for op in self.operations:
                self.forward.append(op(input_set))
        # Initialize adjoints to 0 except our target, whose adjoint is 1 (d node/d node = 1).
        self.adjoint = [0.0]*len(self.forward)
        self.adjoint[node] = 1.0
        # Reverse-mode sweep in descending index order (reverse topological order):
        # every node's adjoint is final before it is pushed down to that node's
        # arguments, so shared subexpressions are counted exactly once.
        for n in reversed(range(len(self.forward))):
            for arg_position, input_node in enumerate(self.node_inputs[n]):
                # Chain rule: adjoint[arg] += adjoint[n] * d(node n)/d(arg).
                dop = self.derivatives[n][arg_position]
                self.adjoint[input_node] += self.adjoint[n]*dop(input_set)
        return self.adjoint
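
    # In formula form, for target node t the sweep computes, for every node i,
    #   adjoint[i] = d(out_t)/d(out_i) = sum over each node n consuming node i of
    #                adjoint[n] * d(f_n)/d(out_i),
    # where f_n is node n's operation. Descending index order guarantees adjoint[n]
    # is complete before it is propagated to n's arguments.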

    def get_shape(self, node):
        return self.shapes[node]

    def add_input(self, name, shape):
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: inputs[name])  # Inputs read their value from the input dict by name.
        self.derivatives.append([])  # An input has no argument nodes, so no partial derivatives.
        self.node_inputs.append([])
        self.graph_inputs.append(index)
        self.shapes.append(shape)
        return index

    def add_add(self, name, left, right):
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: self.forward[left] + self.forward[right])
        self.derivatives.append([lambda inputs: 1, lambda inputs: 1])  # d(a+b)/da = 1 and d(a+b)/db = 1.
        self.node_inputs.append([left, right])
        self.shapes.append(self.get_shape(left))
        return index

    def add_multiply(self, name, left, right):
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: self.forward[left] * self.forward[right])
        self.derivatives.append([lambda inputs: self.forward[right], lambda inputs: self.forward[left]])  # d(a*b)/da = b and d(a*b)/db = a.
        self.node_inputs.append([left, right])
        self.shapes.append(self.get_shape(left))
        return index
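
    # A sketch of how further ops slot in, following the same pattern: one forward
    # lambda plus one derivative lambda per argument. add_square is illustrative and
    # not part of the original gist.
    def add_square(self, name, operand):
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: self.forward[operand] * self.forward[operand])
        self.derivatives.append([lambda inputs: 2 * self.forward[operand]])  # d(x^2)/dx = 2x.
        self.node_inputs.append([operand])
        self.shapes.append(self.get_shape(operand))
        return index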
if __name__=="__main__": | |
g = Graph() | |
x = g.add_input("x", (1, 1)) | |
y = g.add_input("y", (1, 1)) | |
a = g.add_add("a", x, y) | |
b = g.add_multiply("b", a, x) | |
input_map = {'x': 2, 'y': 3} | |
print(g.get_output(input_map)) # 10 | |
print(g.get_gradient(input_map, b)) # 3, 2, 2, 1. |
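
    # A sketch of reusing a cached forward pass: get_gradient accepts forward_data so
    # repeated gradient queries on the same inputs can skip re-evaluation. The name
    # forward_values is illustrative; it is just a copy of the values the previous
    # pass left in g.forward.
    forward_values = list(g.forward)
    print(g.get_gradient(input_map, a, forward_data=forward_values))  # [1.0, 1.0, 1.0, 0.0]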