from ultralytics import YOLO
import torch


def main():
    model = YOLO("yolov8n.pt")
    example_input = torch.ones((1, 3, 640, 640))
    exported_program = torch.export.export(model, (example_input,))
    print(exported_program)


if __name__ == '__main__':
    main()
/// Perform backward pass through the linear layer.
fn backward(&mut self, dL_dy: &Array2<f64>) -> Array2<f64> {
    // Calculate the gradient of the loss with respect to W
    let dL_dW = dL_dy.t().dot(
        &self
            .dy_dW
            .as_ref()
            .expect("Need to call forward() first.")
            .view(),
    );
/// Perform forward pass through the linear layer.
fn forward(&mut self, x: &Array2<f64>) -> Array2<f64> {
    // Store the input x (which equals dy/dW) for later use in the backward pass
    self.dy_dW = Some(x.to_owned());
    self.get_output(x)
}
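The cached input is what the backward pass combines with `dL_dy` to obtain the weight gradient, `dL/dW = dL_dy^T · x`. The following is a minimal sketch of that product, not the author's implementation: it uses nested `Vec<f64>` in place of ndarray's `Array2<f64>` so that it compiles without dependencies, and the function name `weight_gradient` is hypothetical.

```rust
/// Weight gradient of a linear layer: dL/dW = dL/dy^T . x.
///
/// Assumed shapes (hypothetical helper, plain Vecs instead of ndarray):
///   dl_dy: (batch, out_dim), x: (batch, in_dim), result: (out_dim, in_dim).
fn weight_gradient(dl_dy: &[Vec<f64>], x: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let out_dim = dl_dy.first().map_or(0, |row| row.len());
    let in_dim = x.first().map_or(0, |row| row.len());
    let mut dl_dw = vec![vec![0.0; in_dim]; out_dim];

    // Each sample contributes the outer product of its output gradient
    // and its input; the contributions are summed over the batch.
    for (g_row, x_row) in dl_dy.iter().zip(x.iter()) {
        for (o, g) in g_row.iter().enumerate() {
            for (k, x_ik) in x_row.iter().enumerate() {
                dl_dw[o][k] += g * x_ik;
            }
        }
    }
    dl_dw
}
```

The result has the same shape as `W`, which is why the input must be kept from `forward()` until `backward()` runs.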
/// Get the output of the linear layer.
fn get_output(&self, x: &Array2<f64>) -> Array2<f64> {
    // Formula: (W * x^T + b)^T
    (self.W.dot(&x.t()) + self.b.clone()).t().to_owned()
}
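The formula `(W * x^T + b)^T` can also be read sample by sample: each output row is `W` applied to one input row, plus the bias. Below is a small sketch under stated assumptions: plain `Vec<f64>` stands in for `Array2<f64>`, the bias is taken to hold one value per output unit (the source stores `b` as a 2-D array whose exact shape is not shown), and `linear_output` is a hypothetical name.

```rust
/// Row-wise equivalent of (W * x^T + b)^T.
///
/// Assumed shapes (hypothetical helper, plain Vecs instead of ndarray):
///   w: (out_dim, in_dim), b: out_dim values, x: (batch, in_dim),
///   result: (batch, out_dim).
fn linear_output(w: &[Vec<f64>], b: &[f64], x: &[Vec<f64>]) -> Vec<Vec<f64>> {
    x.iter()
        .map(|sample| {
            // One output value per row of W: dot(W[o], sample) + b[o].
            w.iter()
                .zip(b.iter())
                .map(|(w_row, b_o)| {
                    let dot: f64 = w_row
                        .iter()
                        .zip(sample.iter())
                        .map(|(w_ok, x_k)| w_ok * x_k)
                        .sum();
                    dot + b_o
                })
                .collect::<Vec<f64>>()
        })
        .collect()
}
```

With a 3x2 weight matrix and a batch of two 2-element inputs, the result has two rows of three values each, matching the `(batch, out_dim)` layout that the transposes in `get_output` produce.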
/// LinearLayer represents a linear layer in a neural network.
pub struct LinearLayer {
    pub W: Array2<f64>,
    pub b: Array2<f64>,
    // Gradients of the loss with respect to W and b
    pub dL_dW: Option<Array2<f64>>,
    pub dL_db: Option<Array2<f64>>,
    // Gradient of the output with respect to W (the input cached by forward())
    pub dy_dW: Option<Array2<f64>>,
}
/// Perform backward pass through the ReLU activation function.
fn backward(&mut self, dL_dy: &Array2<f64>) -> Array2<f64> {
    // Calculate dL/dx = dL/dy * dy/dx
    dL_dy * &self.dy_dx.clone().expect("Need to call forward() first.")
}
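The product `dL/dy * dy/dx` here is element-wise: the ReLU derivative is a 0/1 mask, so the backward pass passes gradients through where the input was positive and zeroes them elsewhere. A minimal sketch of both steps, using plain `Vec<f64>` instead of ndarray and the hypothetical names `relu_mask` and `relu_backward`:

```rust
/// dy/dx for ReLU: 1.0 where the input is strictly positive, else 0.0.
/// (Hypothetical helper; plain Vecs stand in for Array2<f64>.)
fn relu_mask(x: &[Vec<f64>]) -> Vec<Vec<f64>> {
    x.iter()
        .map(|row| {
            row.iter()
                .map(|&val| if val > 0.0 { 1.0 } else { 0.0 })
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// dL/dx = dL/dy * dy/dx, multiplied element by element.
fn relu_backward(dl_dy: &[Vec<f64>], mask: &[Vec<f64>]) -> Vec<Vec<f64>> {
    dl_dy
        .iter()
        .zip(mask.iter())
        .map(|(g_row, m_row)| {
            g_row
                .iter()
                .zip(m_row.iter())
                .map(|(g, m)| g * m)
                .collect::<Vec<f64>>()
        })
        .collect()
}
```

As in the source, an input of exactly `0.0` gets a derivative of `0.0`, so no gradient flows through it.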
/// Perform forward pass through the ReLU activation function.
fn forward(&mut self, x: &Array2<f64>) -> Array2<f64> {
    // Calculate dy/dx for backpropagation
    let dy_dx = x.mapv(|val| if val > 0.0 { 1.0 } else { 0.0 });
    // Store dy/dx for later use in backward pass
    self.dy_dx = Some(dy_dx);
    // Get the output of the ReLU activation function
    self.get_output(x)
}
/// Get the output of the ReLU activation function.
fn get_output(&self, x: &Array2<f64>) -> Array2<f64> {
    // Apply ReLU function element-wise to input x
    x.mapv(|val| f64::max(val, 0.0))
}
/// ReLU represents the Rectified Linear Unit activation function.
pub struct ReLU {
    pub dy_dx: Option<Array2<f64>>,
}
/// Perform backward pass through the MSE loss function.
pub fn backward(&self) -> Array2<f64> {
    // Return dL/dx obtained from the `self.dL_dx` field
    self.dL_dx.clone().expect("Need to call forward() first.")
}
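The forward pass that fills `self.dL_dx` is not shown here. As an illustration only, the sketch below assumes the common definition of MSE as the mean of squared differences over all elements, which gives `dL/dx = 2 * (x - target) / N`; the source may normalize differently. It uses plain `Vec<f64>` instead of ndarray, and `mse_forward` is a hypothetical name.

```rust
/// Mean squared error and its gradient with respect to the prediction.
///
/// Assumption: the loss is averaged over all N elements, so
///   L = sum((pred - target)^2) / N  and  dL/dpred = 2 * (pred - target) / N.
/// Returns (loss, gradient); the gradient is what a backward() like the one
/// above would hand back. Both inputs are assumed to have the same shape.
fn mse_forward(pred: &[Vec<f64>], target: &[Vec<f64>]) -> (f64, Vec<Vec<f64>>) {
    // Total number of elements across all rows.
    let n = pred.iter().map(|row| row.len()).sum::<usize>() as f64;
    let mut loss = 0.0;
    let mut grad = Vec::with_capacity(pred.len());

    for (p_row, t_row) in pred.iter().zip(target.iter()) {
        let mut g_row = Vec::with_capacity(p_row.len());
        for (p, t) in p_row.iter().zip(t_row.iter()) {
            let diff = p - t;
            loss += diff * diff;
            g_row.push(2.0 * diff / n);
        }
        grad.push(g_row);
    }
    (loss / n, grad)
}
```

Computing the gradient during the forward pass is what allows `backward()` to take no arguments: everything it needs is already cached.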