Created
January 16, 2019 01:13
-
-
Save cshjin/cab7c99a466b6ab25459f080ab3b3073 to your computer and use it in GitHub Desktop.
Verify TensorFlow's automatic differentiation (autodiff) against the analytic gradient with respect to a matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pickle | |
import matplotlib.pyplot as plt | |
import scipy.io as sio | |
import tensorflow as tf | |
# Put TensorFlow 1.x into eager mode so ops run immediately and
# tf.GradientTape / Tensor.numpy() can be used without a graph + session.
# NOTE(review): TF 1.x requires this call at program startup, before any ops
# are created; the function does not exist in TensorFlow 2.x (eager is default).
tf.enable_eager_execution()
# Short module-level alias for matrix multiplication; used by _test to keep the
# nested closed-form gradient expression readable.
tmm = tf.matmul
def _test():
    """Check TensorFlow's autodiff gradient against the closed-form gradient.

    For random matrices ``A`` (2x2), ``zeta`` (2x2) and ``W`` (2x3) the scalar

        z = ||A @ zeta @ W||_F ** 2

    has the analytic gradient with respect to ``zeta``

        dz/dzeta = 2 * A^T @ A @ zeta @ W @ W^T.

    Both the gradient recorded by ``tf.GradientTape`` and the closed-form
    expression are printed (autodiff first). They are then compared
    numerically, so the check does not depend on eyeballing two printed
    tensors.

    Returns:
        float: the largest absolute elementwise difference between the
        autodiff gradient and the closed-form gradient.

    Raises:
        AssertionError: if the two gradients disagree beyond float32
        round-off tolerance.
    """
    A = tf.random_uniform((2, 2))
    zeta = tf.random_normal((2, 2))
    W = tf.random_uniform((2, 3))

    with tf.GradientTape() as t:
        # zeta is a plain tensor, not a tf.Variable, so the tape must be told
        # explicitly to record operations that depend on it.
        t.watch(zeta)
        # With the default ord/axis, tf.linalg.norm treats the matrix as a
        # flat vector, i.e. it computes the Frobenius norm.
        z = tf.linalg.norm(tmm(tmm(A, zeta), W)) ** 2
    # Query the tape after leaving the context so the gradient computation
    # itself is not recorded.
    dz_dx = t.gradient(z, zeta)

    # Closed form: d/dX ||A X W||_F^2 = 2 A^T A X W W^T.
    expected = 2 * tmm(tmm(tmm(tmm(tf.transpose(A), A), zeta), W),
                       tf.transpose(W))

    print(dz_dx)
    print(expected)

    autodiff = dz_dx.numpy()
    closed_form = expected.numpy()
    # Tolerances sized for float32 arithmetic on small, O(1)-magnitude values.
    np.testing.assert_allclose(autodiff, closed_form, rtol=1e-4, atol=1e-5)
    return float(np.max(np.abs(autodiff - closed_form)))
if __name__ == "__main__":
    # Script entry point: run the gradient check only when executed directly,
    # not when this module is imported.
    _test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment