Skip to content

Instantly share code, notes, and snippets.

@emchristiansen
Created July 26, 2023 21:47
Show Gist options
  • Save emchristiansen/db80f5e85c791f6bb5bba5b78b750cd9 to your computer and use it in GitHub Desktop.
Save emchristiansen/db80f5e85c791f6bb5bba5b78b750cd9 to your computer and use it in GitHub Desktop.
Proposed gradient tape design pattern
use color_eyre::Result;
use dfdx::shapes::Rank0;
use dfdx::tensor::AsArray;
use dfdx::tensor::Cpu;
use dfdx::tensor::NoneTape;
use dfdx::tensor::OwnedTape;
use dfdx::tensor::PutTape;
use dfdx::tensor::SplitTape;
use dfdx::tensor::Tape;
use dfdx::tensor::Tensor;
use dfdx::tensor::TensorFrom;
use dfdx::tensor::Trace;
use dfdx::tensor::WithEmptyTape;
use dfdx::tensor_ops::Backward;
use dfdx::tensor_ops::TryStack;
/// Shorthand for the tensor type used throughout this example: `Rank0` shape,
/// `f32` elements, stored on the `Cpu` device, with `NoneTape` in the tape slot
/// (i.e. the tensor itself does not carry a gradient tape).
type Tensor_ = Tensor<Rank0, f32, Cpu, NoneTape>;
/// An alternative design for forward functions.
///
/// The status quo attaches the tape to the inputs and outputs, so the author
/// has to remember which input carries the tape and which output should end up
/// with it. Here the tape is instead an explicit argument and an explicit
/// return value, and every tensor crossing the function boundary is tape-free.
///
/// # Arguments
/// * `tape` - the tape to thread through both multiplications.
/// * `x`, `y` - tape-free input tensors.
///
/// # Returns
/// `(tape, x * y, x * y * 2.0)` - the tape after both products, followed by
/// the two tape-free products.
fn forward<T>(tape: T, x: Tensor_, y: Tensor_) -> (T, Tensor_, Tensor_)
where
    T: Tape<f32, Cpu>,
{
    // First product: attach the tape to a copy of `x`, multiply, then detach
    // the tape again so it can be reused below. The clones keep `x` and `y`
    // available for the second product.
    let x_with_tape = x.clone().put_tape(tape);
    let (prod_1, tape) = (x_with_tape * y.clone()).split_tape();

    // Second product: same pattern, this time consuming `x` and `y`.
    let x_with_tape = x.put_tape(tape);
    let (prod_2, tape) = (x_with_tape * y * 2.0).split_tape();

    (tape, prod_1, prod_2)
}
fn main() -> Result<()>
{
let dev = Cpu::default();
let x = dev.tensor(1.0);
let y = dev.tensor(2.0);
let (tape, out0, out1) = forward(
OwnedTape::default(),
x.clone(),
y.clone(),
);
// Now we can decide at the call-site which gradient we want.
let grad = out1
.put_tape(tape)
.backward();
dbg!(&grad
.get(&x)
.array());
dbg!(&grad
.get(&y)
.array());
Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment