Graph class in deeplearn-rs
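The `Graph` type below, excerpted from the deeplearn-rs crate, owns the nodes, their operations, and a `VarStore` of GPU matrices. `run` executes a forward pass over the nodes in insertion order, then a backward pass in reverse, accumulating gradients through `OutGrad`.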
// Excerpt from deeplearn-rs (2016-era Rust). Assumes `Node`, `Operation`,
// `VarStore`, `VarIndex`, `NodeIndex`, `OutGrad`, `ClMatrix`, `ClMatrixMode`,
// and `matrix::Context` are defined elsewhere in the crate.
use std::collections::HashMap;

pub struct Graph {
    nodes: Vec<Node>,
    node_ops: Vec<Box<Operation>>,
    pub var_store: VarStore,
    // Maps an output variable to its node and its index within that node
    out_var_map: HashMap<VarIndex, (NodeIndex, usize)>,
    // Gradients on variables that are inputs to the graph - they have no corresponding node
    in_var_grad: HashMap<VarIndex, OutGrad>,
}
impl Graph {
    pub fn new() -> Self {
        Graph {
            nodes: vec![],
            node_ops: vec![],
            var_store: VarStore::new(),
            out_var_map: HashMap::new(),
            in_var_grad: HashMap::new(),
        }
    }
    pub fn add_node(&mut self,
                    ctx: &matrix::Context,
                    op: Box<Operation>,
                    inputs: Vec<VarIndex>,
                    out_shapes: &[(u64, u64)])
                    -> NodeIndex {
        let node_index = NodeIndex(self.nodes.len());

        // Create output variables
        let mut outputs = vec![];
        for (i, &(rows, cols)) in out_shapes.iter().enumerate() {
            let var_index = self.var_store.add(ClMatrix::new(ctx, rows as usize, cols as usize, ClMatrixMode::Mut));
            outputs.push(var_index);
            self.out_var_map.insert(var_index, (node_index, i));
        }

        // Create input gradient variables and set up gradient back flow
        let mut in_grad = vec![];
        for input in &inputs {
            // Create an input gradient variable with the same shape as the input
            let (rows, cols) = (input.get(self).rows(), input.get(self).columns());
            let var_index = self.var_store.add(ClMatrix::new(ctx, rows as usize, cols as usize, ClMatrixMode::Mut));
            in_grad.push(var_index);

            // Set up gradient back flow
            match self.out_var_map.get(input).map(|x| *x) {
                Some((in_node, out_index)) => {
                    // The input is another node's output: feed this node's input
                    // gradient into that node's output gradient
                    self.nodes[in_node.0].out_grad[out_index].fork(ctx, &mut self.var_store, var_index);
                },
                None => {
                    // This input doesn't come from a node's output; it is an
                    // input to the graph
                    self.in_var_grad.get_mut(input).unwrap()
                                    .fork(ctx, &mut self.var_store, var_index);
                },
            }
        }

        // Create the node
        self.nodes.push(Node { inputs: inputs,
                               outputs: outputs,
                               in_grad: in_grad,
                               out_grad: vec![OutGrad::new(); out_shapes.len()] });
        // Add the corresponding node op
        self.node_ops.push(op);

        node_index
    }
    pub fn add_variable(&mut self, ctx: &matrix::Context, shape: (u64, u64)) -> VarIndex {
        let v = self.var_store.add(ClMatrix::new(ctx, shape.0 as usize, shape.1 as usize, ClMatrixMode::Mut));
        self.in_var_grad.insert(v, OutGrad::new());
        v
    }

    pub fn get_input_gradient(&self, v: VarIndex) -> Option<VarIndex> {
        self.in_var_grad.get(&v).and_then(|x| x.try_gradient())
    }
    pub fn add_gradient(&mut self, ctx: &matrix::Context, n: NodeIndex, out_index: usize) -> VarIndex {
        // Look up the shape of the output we are taking the gradient of
        let (rows, cols) = {
            let out = n.get(self).outputs[out_index];
            let out = out.get(self);
            (out.rows(), out.columns())
        };
        let v = self.var_store.add(ClMatrix::new(ctx, rows as usize, cols as usize, ClMatrixMode::Mut));
        self.nodes[n.0].out_grad[out_index].fork(ctx, &mut self.var_store, v);
        v
    }
    pub fn run(&mut self, ctx: &matrix::Context) {
        // Forward pass
        //
        // NOTE: We just execute the nodes in order. We can do this because of the way the
        // graph is built: to add a node, the caller must supply its inputs, so any
        // dependencies must already have been added. Therefore every node comes after
        // its dependencies in the `self.nodes` array.
        for (node, op) in self.nodes.iter_mut().zip(&mut self.node_ops) {
            op.forward(ctx, &mut self.var_store, node);
        }

        // Backward pass, in reverse order
        for (node, op) in self.nodes.iter_mut().rev().zip(self.node_ops.iter_mut().rev()) {
            // Sum the gradients on each output if there are multiple gradients
            for out_grad in &node.out_grad {
                out_grad.maybe_sum(ctx, &mut self.var_store);
            }
            op.backward(ctx, &mut self.var_store, node);
        }
    }
}
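A minimal usage sketch of how the pieces fit together. The `MatMul` op, the `matrix::Context::new()` constructor, and the mechanism for seeding gradient values are assumptions here, not part of the gist; only `new`, `add_variable`, `add_node`, `add_gradient`, `get_input_gradient`, and `run` come from the code above.

// Hypothetical usage; `MatMul` stands in for whatever `Operation`
// implementations deeplearn-rs actually provides.
let ctx = matrix::Context::new(); // assumed constructor
let mut graph = Graph::new();

// Graph inputs: a 2x3 matrix and a 3x1 matrix. Each gets an OutGrad entry
// in `in_var_grad` because no node produces it.
let a = graph.add_variable(&ctx, (2, 3));
let b = graph.add_variable(&ctx, (3, 1));

// One node with a single 2x1 output. `add_node` wires gradient back flow
// from this node to `a` and `b` via `OutGrad::fork`.
let node = graph.add_node(&ctx, Box::new(MatMul), vec![a, b], &[(2, 1)]);

// Request a gradient on the node's first output. The caller would fill
// `out_grad` with seed values (e.g. ones) through the VarStore / ClMatrix
// API, which is not shown in this gist.
let out_grad = graph.add_gradient(&ctx, node, 0);

// Forward pass in insertion order, then backward pass in reverse.
graph.run(&ctx);

// The accumulated gradient on graph input `a`, if one was created.
let a_grad = graph.get_input_gradient(a);

Because `add_node` requires all of a node's inputs to exist before the node is added, `self.nodes` is always in topological order, which is what lets `run` skip an explicit sort.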