Last active
May 24, 2018 12:15
-
-
Save Nyrox/4de23b91c8e7c36bbd72132047ef7d94 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate rand; | |
use rand::Rng; | |
// Not referenced anywhere in the visible source: `main` hard-codes its own
// training and test set sizes instead of reading this constant.
// NOTE(review): the trailing `| |` on this line looks like table-extraction
// residue from the page this file was copied from — confirm and strip it.
const TRAINING_SET_SIZE: u64 = 1024; | |
/// Derivative of the logistic activation, expressed in terms of the
/// activation's own output.
///
/// `output` is the value a neuron produced in the forward pass
/// (`1 / (1 + e^-x)`); the slope at that point is `output * (1 - output)`.
fn transfer_derivative(output: f64) -> f64 {
    output * (1.0 - output)
}
#[derive(Clone, Debug)] | |
struct Neuron { | |
weights: Vec<f64>, | |
value: f64, | |
delta: f64, | |
} | |
impl Neuron { | |
pub fn new(weight_count: usize) -> Neuron { | |
Neuron { | |
value: 0.0, | |
delta: 0.0, | |
weights: (0..weight_count).map(|_| { | |
rand::random() | |
}).collect() | |
} | |
} | |
} | |
#[derive(Clone, Debug)] | |
struct Layer { | |
neurons: Vec<Neuron>, | |
id: i64 | |
} | |
impl Layer { | |
pub fn new(id: i64, neuron_count: usize, weight_count: usize) -> Layer { | |
Layer { | |
id, | |
neurons: (0..neuron_count).map(|_| { | |
Neuron::new(weight_count) | |
}).collect() | |
} | |
} | |
pub fn run(&mut self, inputs: &Vec<f64>) -> Vec<f64> { | |
if inputs.len() != self.neurons[0].weights.len() { panic!(format!("Layer({}), {:?}, {:?}", self.id, inputs.len(), self.neurons[0].weights.len())) }; | |
let mut outputs = vec![]; | |
for mut neuron in self.neurons.iter_mut() { | |
let mut result = 0.0; | |
for i in 0..inputs.len() { | |
result += inputs[i] * neuron.weights[i]; | |
} | |
result = 1.0 / (1.0 + (-result).exp()); | |
neuron.value = result; | |
outputs.push(result); | |
} | |
outputs | |
} | |
} | |
#[derive(Debug)] | |
struct Network{ | |
layers: Vec<Layer>, | |
} | |
impl Network { | |
pub fn new(input_count: usize, output_count: usize, (hidden_count, hidden_size): (usize, usize)) -> Network { | |
let mut layers = vec![]; | |
layers.append( | |
&mut (0..hidden_count).map(|i| { | |
Layer::new(i as i64, hidden_size, match i { | |
_ if i == 0 => input_count, | |
_ => hidden_size | |
}) | |
}).collect() | |
); | |
layers.push(Layer::new(-10, output_count, hidden_size)); | |
Network { | |
layers: layers | |
} | |
} | |
pub fn run(&mut self, input_values: &Vec<f64>) -> Vec<f64> { | |
let mut inputs = input_values.clone(); | |
for mut layer in self.layers.iter_mut() { | |
inputs = layer.run(&inputs); | |
} | |
return inputs; | |
} | |
pub fn train(&mut self, inputs: Vec<f64>, expected: Vec<f64>) { | |
// Backpropagate error values | |
for (i, layer) in self.layers.clone().iter().enumerate().rev() { | |
let mut errors = vec![]; | |
if i != self.layers.len() - 1 { | |
for (j, _) in layer.neurons.iter().enumerate() { | |
let mut error = 0.0; | |
for neuron in self.layers[i + 1].neurons.iter() { | |
error += neuron.weights[j] * neuron.delta; | |
} | |
errors.push(error); | |
} | |
} | |
else { | |
for (j, neuron) in layer.neurons.iter().enumerate() { | |
errors.push(expected[j] - neuron.value); | |
} | |
} | |
for (j, neuron) in self.layers[i].neurons.iter_mut().enumerate() { | |
neuron.delta = errors[j] * transfer_derivative(neuron.value); | |
} | |
} | |
// Apply new weights | |
const LEARNING_RATE: f64 = 0.25; | |
for i in 0..self.layers.len() { | |
let mut _inputs = inputs.clone(); | |
if i != 0 { | |
_inputs = vec![]; | |
for neuron in self.layers[i - 1].neurons.iter() { | |
_inputs.push(neuron.value); | |
} | |
} | |
for neuron in self.layers[i].neurons.iter_mut() { | |
for (j, input) in _inputs.iter().enumerate() { | |
neuron.weights[j] += LEARNING_RATE * neuron.delta * input; | |
} | |
} | |
} | |
} | |
} | |
fn main() { | |
::std::env::set_var("RUST_BACKTRACE", "1"); | |
let mut rng = rand::thread_rng(); | |
let mut network = Network::new(2, 2, (3, 4)); | |
let mut generate_set_identity_vector = |size| -> Vec<(Vec<f64>, Vec<f64>)> { | |
(0..size).map(|_| { | |
let x = rng.gen(); | |
let y = rng.gen(); | |
(vec![x, y], vec![x, y]) | |
}).collect() | |
}; | |
let training_set = generate_set_identity_vector(200000); | |
println!("Running training set..."); | |
for entry in training_set { | |
network.run(&entry.0); | |
network.train(entry.0, entry.1); | |
} | |
println!("Finished running training set."); | |
let test_set = generate_set(500); | |
println!("Running test set..."); | |
for entry in test_set { | |
println!("Result: {:?}, Expected: {:?}", network.run(&entry.0), entry.1); | |
} | |
println!("Finished running test set."); | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment