Created
July 26, 2023 21:24
-
-
Save emchristiansen/8d84b3a36f1333526810e1c99a3a4335 to your computer and use it in GitHub Desktop.
Demonstration of the difficulty in dealing with gradient tapes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Demonstration of the difficulty in dealing with gradient tapes for a network
// with multiple inputs and outputs.
use color_eyre::Result;
use dfdx::shapes::Rank0;
use dfdx::tensor::AsArray;
use dfdx::tensor::Cpu;
use dfdx::tensor::PutTape;
use dfdx::tensor::SplitTape;
use dfdx::tensor::Tape;
use dfdx::tensor::Tensor;
use dfdx::tensor::TensorFrom;
use dfdx::tensor::Trace;
use dfdx::tensor::WithEmptyTape;
use dfdx::tensor_ops::Backward;
use dfdx::tensor_ops::TryStack;
// We want to map from Vecs of deeply-nested structs containing Tensors to | |
// Vecs of deeply-nested structs containing Tensors. | |
// For simplicity we're just using a type alias for Tensor here and instead of | |
// using Vecs we're using 2-tuples. | |
type NestedStruct<T> = Tensor<Rank0, f32, Cpu, T>; | |
// The network, which takes 2 inputs and has 2 outputs. | |
// Note each output is differentiable wrt each input (2 x 2). | |
// We would like to be able to backpropagate from either output. | |
fn forward<T>( | |
x: NestedStruct<T>, | |
y: NestedStruct<T>, | |
) -> ( | |
NestedStruct<T>, | |
NestedStruct<T>, | |
) | |
where | |
T: Tape<f32, Cpu>, | |
{ | |
// The tape isn't Clone, so we have to do a dance to get all the ops into | |
// the tape. | |
let prod_1 = x.with_empty_tape() * y.with_empty_tape(); | |
let prod_2 = x * y * 2.0; | |
// This helper function accumulates the tapes into the first Tensor. | |
gather_tape( | |
prod_1, prod_2, | |
) | |
} | |
fn gather_tape<T>( | |
x: NestedStruct<T>, | |
y: NestedStruct<T>, | |
) -> ( | |
NestedStruct<T>, | |
NestedStruct<T>, | |
) | |
where | |
T: Tape<f32, Cpu>, | |
{ | |
let (x, x_tape) = x.split_tape(); | |
let (y, y_tape) = y.split_tape(); | |
let merged = x_tape.merge(y_tape); | |
( | |
x.put_tape(merged), | |
y.retaped(), | |
) | |
} | |
fn main() -> Result<()> | |
{ | |
// Setup the example. | |
let dev = Cpu::default(); | |
let x = dev.tensor(1.0); | |
let y = dev.tensor(2.0); | |
let out = forward( | |
x.leaky_trace(), | |
y.leaky_trace(), | |
); | |
// Now we can call out.0.backward and get the gradient wrt x and y. | |
let grad = out | |
.0 | |
.backward(); | |
// But what if we want the gradient for out.1? | |
// This naive call doesn't work because out.1 doesn't own the tape. | |
// let grad = out | |
// .1 | |
// .backward(); | |
// So, we have to use our helper function again to shift the tape to out.1. | |
// Except that it doesn't work because we already used out.0. | |
// let (out_1, out_0) = gather_tape( | |
// out.1, out.0, | |
// ); | |
// One might think we could do this: | |
// let grad = vec![out.0, out.1] | |
// .stack() | |
// .backward(); | |
// But, we can't assume we can stack our outputs (they're not generally | |
// raw Tensors). | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment