// Demonstration of the difficulty in dealing with gradient tapes for a network
// with multiple inputs and outputs.
use color_eyre::Result;
use dfdx::shapes::Rank0;
use dfdx::tensor::AsArray;
use dfdx::tensor::Cpu;
use dfdx::tensor::PutTape;
use dfdx::tensor::SplitTape;
use dfdx::tensor::Tape;
use dfdx::tensor::Tensor;
use dfdx::tensor::TensorFrom;
use dfdx::tensor::Trace;
use dfdx::tensor::WithEmptyTape;
use dfdx::tensor_ops::Backward;
use dfdx::tensor_ops::TryStack;

// We want to map from Vecs of deeply-nested structs containing Tensors to
// Vecs of deeply-nested structs containing Tensors.
// For simplicity we're just using a type alias for Tensor here, and instead
// of using Vecs we're using 2-tuples.
type NestedStruct<T> = Tensor<Rank0, f32, Cpu, T>;

// The network, which takes 2 inputs and has 2 outputs.
// Note each output is differentiable wrt each input (2 x 2).
// We would like to be able to backpropagate from either output.
fn forward<T>(
    x: NestedStruct<T>,
    y: NestedStruct<T>,
) -> (
    NestedStruct<T>,
    NestedStruct<T>,
)
where
    T: Tape<f32, Cpu>,
{
    // The tape isn't Clone, so we have to do a dance to get all the ops into
    // the tape.
    let prod_1 = x.with_empty_tape() * y.with_empty_tape();
    let prod_2 = x * y * 2.0;
    // This helper function accumulates the tapes into the first Tensor.
    gather_tape(prod_1, prod_2)
}

// Accumulates the ops from both tapes onto the first tensor's tape; the
// second tensor is returned with a fresh, empty tape.
fn gather_tape<T>(
    x: NestedStruct<T>,
    y: NestedStruct<T>,
) -> (
    NestedStruct<T>,
    NestedStruct<T>,
)
where
    T: Tape<f32, Cpu>,
{
    let (x, x_tape) = x.split_tape();
    let (y, y_tape) = y.split_tape();
    let merged = x_tape.merge(y_tape);
    (x.put_tape(merged), y.retaped())
}

fn main() -> Result<()> {
    // Set up the example.
    let dev = Cpu::default();
    let x = dev.tensor(1.0);
    let y = dev.tensor(2.0);
    let out = forward(x.leaky_trace(), y.leaky_trace());
    // Now we can call out.0.backward() and get the gradient wrt x and y.
    let grad = out.0.backward();
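    // As a sketch of reading the values back out (this assumes dfdx's
    // `Gradients::get`, which looks up the gradient recorded for a given
    // tensor): since out.0 == x * y, we expect d(out.0)/dx == y == 2.0 and
    // d(out.0)/dy == x == 1.0.
    println!("d(out.0)/dx = {}", grad.get(&x).array());
    println!("d(out.0)/dy = {}", grad.get(&y).array());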
    // But what if we want the gradient for out.1?
    // This naive call doesn't work, because out.1 doesn't own the tape:
    // let grad = out.1.backward();
    // So we would have to use our helper function again to shift the tape to
    // out.1. Except that doesn't work either, because out.0 was already
    // consumed above:
    // let (out_1, out_0) = gather_tape(out.1, out.0);
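    // For what it's worth, a sketch of an ordering that would work is to
    // shift the tape onto out.1 *before* consuming out.0, e.g.:
    // let out = forward(x.leaky_trace(), y.leaky_trace());
    // let (out_1, _out_0) = gather_tape(out.1, out.0);
    // let grad_1 = out_1.backward();
    // But that forces us to commit up front to which output we backpropagate
    // from.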
    // One might think we could do something like this (with a sum() so that
    // backward() is called on a scalar):
    // let grad = vec![out.0, out.1]
    //     .stack()
    //     .sum()
    //     .backward();
    // But we can't assume we can stack our outputs (they're not generally
    // raw Tensors).
    Ok(())
}