@mikecvet
Created October 13, 2023 23:03
/// Runs forward propagation through this layer of the network. Flattens the input, takes its dot product
/// with this layer's weights, adds the bias, and applies softmax to produce the output probabilities.
/// The input and each intermediate result are saved in `ctx`.
pub fn
forward_propagation<'a> (&mut self, input: &'a Array3<f64>, ctx: &mut SoftmaxContext<'a>) -> Array1<f64>
{
  let flattened: Array1<f64> = input.to_owned().into_shape((input.len(),)).unwrap();
  let dot_result = flattened.dot(&self.weights).add(&self.bias);
  let probabilities = softmax(&dot_result);

  ctx.input = Some(input);
  ctx.flattened = Some(flattened);
  ctx.dot_result = Some(dot_result);
  ctx.output = Some(probabilities.clone());

  probabilities
}
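
The `softmax` helper called above is not part of this snippet. As a rough sketch of what such a function might look like, here is a standard-library-only version that works on a plain slice rather than an `Array1<f64>`. The name `softmax_sketch` and its signature are hypothetical, and subtracting the maximum before exponentiating is a common stability trick rather than something the original is known to do.

```rust
/// Hypothetical stand-in for the `softmax` helper used in `forward_propagation`.
/// Takes a plain slice instead of an ndarray `Array1<f64>` so it has no dependencies.
pub fn softmax_sketch(logits: &[f64]) -> Vec<f64> {
    // Shift by the largest value so that exp() cannot overflow for big inputs.
    // This does not change the result, since softmax is invariant to adding a constant.
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

    // Exponentiate each shifted value.
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();

    // Normalize so the outputs sum to 1 and can be read as probabilities.
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

Calling `softmax_sketch(&[1.0, 2.0, 3.0])` yields three values that sum to 1 and keep the ordering of the inputs, which is the property the layer relies on when it treats the result as class probabilities.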