
@cympfh
Created March 27, 2019 09:52
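```rust
extern crate tch; // https://github.com/LaurentMazare/tch-rs
use tch::{no_grad, Tensor};

/// Approximates sqrt(v) by gradient descent on L(x) = (x^2 - v)^2,
/// letting tch's autograd compute dL/dx.
fn sqrt(v: f64) -> f64 {
    // Scalar tensor initialized to v/2 and tracked by autograd.
    let mut x = Tensor::from(v / 2.0).set_requires_grad(true);
    let lambda = 0.01; // learning rate (step size)
    for _ in 0..100 {
        let y = &x * &x - v; // residual x^2 - v
        let loss = &y * &y; // squared residual
        x.zero_grad(); // clear the gradient left over from the previous step
        loss.backward(); // fills x.grad() with dL/dx
        println!(
            " y={} x={} dL/dx={}",
            y.double_value(&[]),
            x.double_value(&[]),
            x.grad().double_value(&[])
        );
        // Update x in place without recording the update in the autograd graph.
        no_grad(|| {
            x -= x.grad() * lambda;
        });
    }
    x.double_value(&[])
}

fn main() {
    println!("sqrt 2 = {}", sqrt(2.0));
    println!("sqrt 3 = {}", sqrt(3.0));
    println!("sqrt 5 = {}", sqrt(5.0));
}
```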
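For this loss the gradient can also be written by hand: with L(x) = (x² − v)², the chain rule gives dL/dx = 4x(x² − v), which is the `dL/dx` value the tch program prints at each step. The sketch below is a minimal, dependency-free version of the same loop. It assumes the same start point (v/2), step size (0.01) and iteration count (100) as the tch code, but it replaces autograd with that hand-derived gradient. The function name `sqrt_gd` is hypothetical and is not part of the gist.

```rust
/// Hypothetical helper: the same gradient descent as the tch version, but with
/// the gradient of L(x) = (x^2 - v)^2 written out by hand as dL/dx = 4x(x^2 - v).
fn sqrt_gd(v: f64, lambda: f64, steps: usize) -> f64 {
    let mut x = v / 2.0; // same starting point as the tch version
    for _ in 0..steps {
        let y = x * x - v; // residual x^2 - v
        let grad = 4.0 * x * y; // dL/dx, the quantity autograd computes
        x -= lambda * grad; // plain gradient-descent step
    }
    x
}

fn main() {
    for v in [2.0_f64, 3.0, 5.0] {
        let x = sqrt_gd(v, 0.01, 100);
        println!("sqrt {} = {} (f64::sqrt gives {})", v, x, v.sqrt());
    }
}
```

Near the root each step shrinks the error by roughly a factor of 1 − 0.01 · 4(3v − v) = 1 − 0.08v. That factor is 0.84 for v = 2, so 100 steps get close to `f64::sqrt` but do not reach full double precision.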