Skip to content

Instantly share code, notes, and snippets.

@tedsta
Created January 15, 2016 03:52
Show Gist options
  • Save tedsta/18dacbe58e677f45c89b to your computer and use it in GitHub Desktop.
Save tedsta/18dacbe58e677f45c89b to your computer and use it in GitHub Desktop.
Graph class in deeplearn-rs
pub struct Graph {
nodes: Vec<Node>,
node_ops: Vec<Box<Operation>>,
pub var_store: VarStore,
out_var_map: HashMap<VarIndex, (NodeIndex, usize)>, // Maps output variable to it's node and index within node
// Gradients on variables that are inputs to the graph - they have no corresponding node
in_var_grad: HashMap<VarIndex, OutGrad>,
}
impl Graph {
pub fn new() -> Self {
Graph {
nodes: vec![],
node_ops: vec![],
var_store: VarStore::new(),
out_var_map: HashMap::new(),
in_var_grad: HashMap::new(),
}
}
pub fn add_node(&mut self,
ctx: &matrix::Context,
op: Box<Operation>,
inputs: Vec<VarIndex>,
out_shapes: &[(u64, u64)])
-> NodeIndex {
let node_index = NodeIndex(self.nodes.len());
// Create output variables
let mut outputs = vec![];
for (i, &(rows, cols)) in out_shapes.iter().enumerate() {
let var_index = self.var_store.add(ClMatrix::new(ctx, rows as usize, cols as usize, ClMatrixMode::Mut));
outputs.push(var_index);
self.out_var_map.insert(var_index, (node_index, i));
}
// Create input gradient variables and set up gradient back flow
let mut in_grad = vec![];
for input in &inputs {
// Create input gradient variables
let (rows, cols) = (input.get(self).rows(), input.get(self).columns());
let var_index = self.var_store.add(ClMatrix::new(ctx, rows as usize, cols as usize, ClMatrixMode::Mut));
in_grad.push(var_index);
// Set up gradient back flow
match self.out_var_map.get(input).map(|x| *x) {
Some((in_node, out_index)) => {
self.nodes[in_node.0].out_grad[out_index].fork(ctx, &mut self.var_store, var_index);
},
None => {
// This input doesn't come from a node's output. It is an input to the graph.
self.in_var_grad.get_mut(input).unwrap()
.fork(ctx, &mut self.var_store, var_index);
},
}
}
// Create the node
self.nodes.push(Node { inputs: inputs,
outputs: outputs,
in_grad: in_grad,
out_grad: vec![OutGrad::new(); out_shapes.len()] });
// Add the corresponding node op
self.node_ops.push(op);
node_index
}
pub fn add_variable(&mut self, ctx: &matrix::Context, shape: (u64, u64)) -> VarIndex {
let v = self.var_store.add(ClMatrix::new(ctx, shape.0 as usize, shape.1 as usize, ClMatrixMode::Mut));
self.in_var_grad.insert(v, OutGrad::new());
v
}
pub fn get_input_gradient<'a>(&'a self, v: VarIndex) -> Option<VarIndex> {
self.in_var_grad.get(&v).and_then(|x| x.try_gradient())
}
pub fn add_gradient(&mut self, ctx: &matrix::Context, n: NodeIndex, out_index: usize) -> VarIndex {
let (rows, cols) = {
let grad = n.get(self).outputs[out_index];
let grad = grad.get(self);
(grad.rows(), grad.columns())
};
let v = self.var_store.add(ClMatrix::new(ctx, rows as usize, cols as usize, ClMatrixMode::Mut));
self.nodes[n.0].out_grad[out_index].fork(ctx, &mut self.var_store, v);
v
}
pub fn run(&mut self, ctx: &matrix::Context) {
// Forward pass
//
// NOTE: We just execute the nodes in order. We can do this because of the way the graph is
// built. When a user wants to add a node, he/she must also supply the inputs. This means
// any dependencies must already be added before the node can be added. Therefore, we can
// assert that all dependents come after their dependencies in the `self.nodes` array.
for (node, op) in self.nodes.iter_mut().zip(&mut self.node_ops) {
op.forward(ctx, &mut self.var_store, node);
}
// Backward pass
for (node, op) in self.nodes.iter_mut().rev().zip(self.node_ops.iter_mut().rev()) {
// Sum the gradients on each output if there are multiple gradients
for out_grad in &node.out_grad {
out_grad.maybe_sum(ctx, &mut self.var_store);
}
op.backward(ctx, &mut self.var_store, node);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment