Skip to content

Instantly share code, notes, and snippets.

@emchristiansen
Created July 27, 2023 15:18
Show Gist options
  • Save emchristiansen/cb490c3679939c2064a99e64992c34f1 to your computer and use it in GitHub Desktop.
Save emchristiansen/cb490c3679939c2064a99e64992c34f1 to your computer and use it in GitHub Desktop.
ArcTape usage example
use std::sync::Arc;
use std::sync::Mutex;
use color_eyre::Result;
use dfdx::shapes::Rank0;
use dfdx::tensor::AsArray;
use dfdx::tensor::Cpu;
use dfdx::tensor::OwnedTape;
use dfdx::tensor::PutTape;
use dfdx::tensor::SplitTape;
use dfdx::tensor::Tape;
use dfdx::tensor::Tensor;
use dfdx::tensor::TensorFrom;
use dfdx::tensor::Trace;
use dfdx::tensor::WithEmptyTape;
use dfdx::tensor_ops::Backward;
use dfdx::tensor_ops::TryStack;
/// A rank-0 `f32` tensor on the `Cpu` device, generic over its tape type `T`.
///
/// NOTE(review): despite the name, this alias is a single tensor rather than a
/// struct with nested fields — presumably a stand-in for a richer type; confirm.
type NestedStruct<T> = Tensor<Rank0, f32, Cpu, T>;
/// Builds two products from the inputs `x` and `y`.
///
/// Returns the pair `(x * y, x * y * 2.0)`.
///
/// `T` is the tape type carried by both tensors. It must be `Clone` because
/// the first product is computed from clones of the inputs, leaving the
/// originals to be consumed by the second product.
fn forward<T>(x: NestedStruct<T>, y: NestedStruct<T>) -> (NestedStruct<T>, NestedStruct<T>)
where
    T: Tape<f32, Cpu> + Clone,
{
    (
        // Evaluated first, on clones, so `x` and `y` are still owned below.
        x.clone() * y.clone(),
        // Evaluated second; this consumes the original tensors.
        x * y * 2.0,
    )
}
/// Example driver: runs `forward` on two scalar tensors whose tapes are clones
/// of one `Arc<Mutex<OwnedTape>>`, then backpropagates from the second output
/// and prints the gradients for `x` and `y` with `dbg!`.
///
/// NOTE(review): since the second output is `x * y * 2.0` with `x = 1.0` and
/// `y = 2.0`, the printed gradients are presumably `4.0` for `x` and `2.0`
/// for `y` — confirm by running.
fn main() -> Result<()> {
    let cpu = Cpu::default();
    let x = cpu.tensor(1.0);
    let y = cpu.tensor(2.0);

    // A single tape instance, handed to both inputs as clones of the same `Arc`.
    let shared_tape = Arc::new(Mutex::new(OwnedTape::<f32, Cpu>::default()));

    // The tape-less `x` and `y` are kept (only clones are traced) so they can
    // be used as gradient lookup keys after the backward pass.
    let x_traced = x.clone().put_tape(Arc::clone(&shared_tape));
    let y_traced = y.clone().put_tape(Arc::clone(&shared_tape));
    let outputs = forward(x_traced, y_traced);

    // Only the second output (`x * y * 2.0`) is differentiated.
    let grad = outputs.1.backward();

    dbg!(grad.get(&x).array());
    dbg!(grad.get(&y).array());

    Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment