Skip to content

Instantly share code, notes, and snippets.

@emchristiansen
Created July 27, 2023 15:19
Show Gist options
  • Save emchristiansen/e3070357de20bfd97fe2e49933829d54 to your computer and use it in GitHub Desktop.
Save emchristiansen/e3070357de20bfd97fe2e49933829d54 to your computer and use it in GitHub Desktop.
ArcTape usage example with Module API
use std::sync::Arc;
use std::sync::Mutex;
use color_eyre::Result;
use dfdx::prelude::BuildOnDevice;
use dfdx::prelude::Linear;
use dfdx::prelude::Module;
use dfdx::prelude::ZeroGrads;
use dfdx::shapes::Rank0;
use dfdx::shapes::Rank2;
use dfdx::tensor::AsArray;
use dfdx::tensor::Cpu;
use dfdx::tensor::OwnedTape;
use dfdx::tensor::PutTape;
use dfdx::tensor::SampleTensor;
use dfdx::tensor::SplitTape;
use dfdx::tensor::Tape;
use dfdx::tensor::Tensor;
use dfdx::tensor::TensorFrom;
use dfdx::tensor::Trace;
use dfdx::tensor::WithEmptyTape;
use dfdx::tensor_ops::Backward;
use dfdx::tensor_ops::SumTo;
use dfdx::tensor_ops::TryStack;
/// Demonstrates recording a `Module::forward` pass on a shared,
/// mutex-guarded tape (`Arc<Mutex<OwnedTape<f32, Cpu>>>`) and reading back a
/// parameter gradient.
///
/// Steps: build a single `Linear<4, 2>` layer on the CPU, sample a normally
/// distributed `3x4` input, seed a tape with the model's allocated gradients,
/// run the forward pass on the taped input, reduce the output with `sum`,
/// call `backward`, and print the gradient of the layer's weight via `dbg!`.
///
/// # Errors
///
/// Never returns `Err` as written. The `Result` return type only allows the
/// example to use `?` if it grows.
fn main() -> Result<()>
{
    let dev = Cpu::default();

    type Model = (Linear<4, 2>,);
    let model = Model::build_on_device(&dev);

    let x: Tensor<Rank2<3, 4>, f32, Cpu> = dev.sample_normal();

    // Seed the tape with gradient storage allocated for the model's
    // parameters. Presumably this is what lets `backward` report a gradient
    // for `model.0.weight`; confirm against the ArcTape implementation.
    let grads = model.alloc_grads();
    let tape = Arc::new(Mutex::new(OwnedTape::<f32, Cpu>::from(grads)));

    // Neither `x` nor `tape` is used after this point, so both are moved into
    // the forward pass instead of being cloned.
    let out = model.forward(x.put_tape(tape));

    // Reduce to a scalar and backpropagate. `backward` returns the collected
    // gradients, which are then indexed by parameter.
    let grads = out.sum().backward();

    dbg!(grads.get(&model.0.weight).array());
    Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment