/// Performs the backward pass through the linear layer, returning the
/// gradient of the loss with respect to the layer's input.
fn backward(&mut self, dL_dy: &Array2<f64>) -> Array2<f64> {
    // Gradient of the loss w.r.t. W: dL_dW = dL_dy^T · dy_dW, where dy_dW is
    // the input cached by forward() (for a linear layer, ∂y/∂W is the input).
    let dL_dW = dL_dy.t().dot(
        &self
            .dy_dW
            .as_ref()
            .expect("Need to call forward() first.")
            .view(),
    );
    // Gradient of the loss w.r.t. b: sum dL_dy over the batch dimension.
    let dL_db = dL_dy.t().dot(&Array2::ones((dL_dy.shape()[0], 1)));
    // Store the gradients for later use (e.g. the parameter update step).
    self.dL_dW = Some(dL_dW);
    self.dL_db = Some(dL_db.to_owned());
    // Gradient of the loss w.r.t. the input: dL_dx = dL_dy · W
    dL_dy.dot(&self.W)
}
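
// ---------------------------------------------------------------------------
// Context (a sketch, not the author's full implementation): a minimal layer
// that backward() above appears to assume. The field names `W`, `b`, `dy_dW`,
// `dL_dW`, and `dL_db` are taken from the snippet itself; the `Linear` struct
// name, the constructor, and the forward() pass (y = x · W^T + b^T, caching
// the input in `dy_dW`, since for a linear layer ∂y/∂W is just the input x)
// are assumptions made here so the example compiles standalone against the
// `ndarray` crate. backward() from above is repeated inside the impl block
// (comments elided) so the sketch runs on its own.
// ---------------------------------------------------------------------------

use ndarray::Array2;

#[allow(non_snake_case)]
struct Linear {
    W: Array2<f64>,             // weights, shape (out_features, in_features)
    b: Array2<f64>,             // bias, shape (out_features, 1)
    dy_dW: Option<Array2<f64>>, // input cached by forward(), shape (batch, in_features)
    dL_dW: Option<Array2<f64>>, // gradient w.r.t. W, shape (out_features, in_features)
    dL_db: Option<Array2<f64>>, // gradient w.r.t. b, shape (out_features, 1)
}

#[allow(non_snake_case)]
impl Linear {
    fn new(in_features: usize, out_features: usize) -> Self {
        Self {
            // Deterministic toy initialization to avoid extra dependencies.
            W: Array2::from_shape_fn((out_features, in_features), |(i, j)| {
                0.01 * (i as f64 - j as f64)
            }),
            b: Array2::zeros((out_features, 1)),
            dy_dW: None,
            dL_dW: None,
            dL_db: None,
        }
    }

    /// Forward pass: y = x · W^T + b^T. Caches x so backward() can form dL_dW.
    fn forward(&mut self, x: &Array2<f64>) -> Array2<f64> {
        self.dy_dW = Some(x.to_owned());
        x.dot(&self.W.t()) + &self.b.t()
    }

    /// backward() from the top of this gist, repeated so the sketch compiles.
    fn backward(&mut self, dL_dy: &Array2<f64>) -> Array2<f64> {
        let dL_dW = dL_dy.t().dot(
            &self
                .dy_dW
                .as_ref()
                .expect("Need to call forward() first.")
                .view(),
        );
        let dL_db = dL_dy.t().dot(&Array2::ones((dL_dy.shape()[0], 1)));
        self.dL_dW = Some(dL_dW);
        self.dL_db = Some(dL_db.to_owned());
        dL_dy.dot(&self.W)
    }
}

// Usage: shapes follow the (batch, features) row convention.
#[allow(non_snake_case)]
fn main() {
    let mut layer = Linear::new(3, 2);
    let x = Array2::from_shape_fn((4, 3), |(i, j)| (i + j) as f64); // (4, 3)
    let y = layer.forward(&x); // (4, 2)
    let dL_dy = Array2::ones(y.raw_dim()); // stand-in upstream gradient, (4, 2)
    let dL_dx = layer.backward(&dL_dy); // gradient w.r.t. the input, (4, 3)
    println!("dL_dx shape: {:?}", dL_dx.shape()); // [4, 3]
}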