Define the backward pass for a reversible residual block.
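The snippet below is the backward pass of a RevNet-style reversible block whose forward pass is y1 = x1 + f(x2), y2 = x2 + g(y1). Given the block outputs y and the upstream gradients dy, it reconstructs the inputs x by inverting the forward pass and returns gradients with respect to the inputs and to the variables of f and g, so intermediate activations never need to be stored during the forward pass.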
import tensorflow as tf


def backward_grads(self, y, dy, training=True):
  """Compute gradients for one reversible block without stored activations.

  Inverts the forward pass y1 = x1 + f(x2), y2 = x2 + g(y1) to recover the
  inputs, then backpropagates through the recomputed activations.
  """
  dy1, dy2 = dy
  y1, y2 = y

  # Recompute g(y1) under a tape so we can backprop through it.
  with tf.GradientTape() as gtape:
    gtape.watch(y1)
    gy1 = self.g(y1, training=training)
  # Gradients w.r.t. y1 and g's variables, weighted by the upstream dy2.
  grads_combined = gtape.gradient(
      gy1, [y1] + self.g.trainable_variables, output_gradients=dy2)
  dg = grads_combined[1:]
  dx1 = dy1 + grads_combined[0]
  # Invert the second half of the forward pass: x2 = y2 - g(y1).
  x2 = y2 - gy1

  # Recompute f(x2) under a tape so we can backprop through it.
  with tf.GradientTape() as ftape:
    ftape.watch(x2)
    fx2 = self.f(x2, training=training)
  # Gradients w.r.t. x2 and f's variables, weighted by dx1.
  grads_combined = ftape.gradient(
      fx2, [x2] + self.f.trainable_variables, output_gradients=dx1)
  df = grads_combined[1:]
  dx2 = dy2 + grads_combined[0]
  # Invert the first half of the forward pass: x1 = y1 - f(x2).
  x1 = y1 - fx2

  x = x1, x2
  dx = dx1, dx2
  grads = df + dg  # variable gradients of f, then g
  return x, dx, grads
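
For context, a minimal usage sketch (not part of the original gist): the ReversibleBlock class, its Dense f/g layers, and the shapes below are illustrative assumptions, chosen so the forward pass matches the inversion above.

class ReversibleBlock(tf.keras.Model):
  # Toy reversible block: y1 = x1 + f(x2), y2 = x2 + g(y1).
  def __init__(self, units):
    super().__init__()
    self.f = tf.keras.layers.Dense(units)
    self.g = tf.keras.layers.Dense(units)

  def call(self, x, training=True):
    x1, x2 = x
    y1 = x1 + self.f(x2, training=training)
    y2 = x2 + self.g(y1, training=training)
    return y1, y2

# Attach the backward pass defined above as a method of the toy block.
ReversibleBlock.backward_grads = backward_grads

block = ReversibleBlock(units=4)
x = (tf.random.normal([2, 4]), tf.random.normal([2, 4]))
y = block(x)
dy = (tf.ones_like(y[0]), tf.ones_like(y[1]))  # stand-in upstream gradients
x_rec, dx, grads = block.backward_grads(y, dy)

Because the block is invertible, x_rec should match x up to floating-point error; this is what lets a reversible network discard activations during the forward pass and recompute them here.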