This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Generalization of https://stackoverflow.com/a/32175958 | |
// With IndexSequence from https://stackoverflow.com/a/49672613 | |
#include <array> | |
template <std::size_t ...> | |
struct IndexSequence { }; | |
template <std::size_t N, std::size_t ... Next> | |
struct IndexSequenceHelper : public IndexSequenceHelper<N-1U, N-1U, Next...> { }; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate ocl; | |
extern crate rspirv; | |
extern crate spirv_headers as spirv; | |
use ocl::{Platform, Device, Context, Queue, Buffer, Program, Kernel, Event, EventList}; | |
use rspirv::binary::Disassemble; | |
pub fn find_platform() -> Option<Platform> { | |
let platform_name = "Experimental OpenCL 2.1 CPU Only Platform"; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
println!("#######################################"); | |
println!("Validating"); | |
num_correct = 0; | |
for epoch in 0..val_images.len() { | |
// Upload training data | |
input.write(graph, &val_images[epoch]); | |
// Run the graph | |
graph.forward(); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// We apply a gradient of -0.001 to the loss function | |
let loss_d_cpu = Array::new(vec![batch_size, 10], -0.001); | |
loss_d.write(graph, &loss_d_cpu); | |
let mut loss_out_cpu = Array::new(vec![batch_size, 10], 0.0); | |
let mut l2_out_cpu = Array::new(vec![batch_size, 10], 0.0); | |
let mut l2_out_d_cpu = Array::new(vec![batch_size, 10], 0.0); | |
let mut predictions = Array::new(vec![batch_size], 0usize); | |
let mut num_correct = 0; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pub fn dense_biased<WI: Initializer, BI: Initializer>(graph: &mut Graph, | |
input: VarIndex, | |
layer_size: usize, | |
w_init: WI, | |
b_init: BI) | |
-> (VarIndex, VarIndex, VarIndex) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Weights for layer 1: [input_size x layer_size] | |
let weights = graph.add_variable(vec![input_size, layer_size], true, w_init); | |
// Use matrix multiplication to do a fully connected layer | |
let mat_mul = graph.add_node(MatMul(input, weights)); | |
let mat_mul_out = mat_mul.get(&graph).outputs[0]; | |
// Biases, one for each neuron in layer | |
let bias = graph.add_variable(vec![1, layer_size], true, b_init); | |
// Add the biases to the matrix multiplication output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pub struct MatMul(pub VarIndex, pub VarIndex); | |
impl OpBuilder for MatMul { | |
type Op = MatMulImpl; | |
fn build(&self, ctx: &ga::Context, v: &VarStore) | |
-> Result<OpDescriptor<MatMulImpl>, String> { | |
let a = &v.get(self.0); | |
let b = &v.get(self.1); | |
if a.shape()[1] != b.shape()[0] { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pub trait Operation : 'static { | |
fn forward(&mut self, &ga::Context, &VarStore, &Node); | |
fn backward(&mut self, &ga::Context, &VarStore, &Node); | |
} | |
pub trait OpBuilder { | |
type Op: Operation; | |
fn build(&self, ctx: &ga::Context, v: &VarStore) | |
-> Result<OpDescriptor<Self::Op>, String>; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
let mat_mul = graph.add_node(MatMul(input, weights)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
let l1_mat_mul = graph.add_node(&ctx, | |
Box::new(MatMul::new(&ctx, (1, 2), (2, 3))), | |
vec![a, l1_w], | |
&[(1, 3)]); |
NewerOlder