@tsbertalan
Last active August 26, 2021 17:38
Demonstrate TensorFlow's `custom_gradient` for a polynomial op.
import tensorflow as tf, numpy as np, matplotlib.pyplot as plt
trainable = []
order = 4
@tf.custom_gradient
def poly(x):
    # Create (or get a handle to the existing) polynomial-coefficients variable
    # that we're supposed to learn.
    with tf.variable_scope('polynomial', reuse=tf.AUTO_REUSE):
        p = tf.get_variable('poly_coeffs', [order+1,], use_resource=True,
                            initializer=tf.initializers.ones())
    if p not in trainable:
        trainable.append(p)

    # Evaluate the polynomial (the naive way; see Numerical Recipes for a better approach).
    #
    # Of course, we're not *really* evaluating it with this Python code; we're only setting up,
    # through operator overloading on the `Tensor` object x, the computational graph
    # that TensorFlow will compile to GPU code which will actually do the evaluation
    # when we use `feed_dict` to pass particular values for x.
    with tf.variable_scope('polyval'):
        # Constant term first.
        poly = p[-1]
        for k in range(order):
            # Then linear, quadratic, ...
            poly += x ** (k+1) * p[-(k+2)]
    def grad_fn(dpoly, variables=None):
        # Evaluate the derivative of the poly output WRT the input x.
        with tf.variable_scope('dydx'):
            # The linear term becomes the constant term.
            grad_xs = p[-2]
            for k in range(order-1):
                # Quadratic, cubic, ... become linear, quadratic, ...
                coefficient = p[-(k+3)]
                raw_exponent = k + 2
                grad_xs += coefficient * raw_exponent * x ** (raw_exponent - 1)
            # Actually, TensorFlow wants not the true derivatives dy/dx, but their action DY*dy/dx,
            # for backpropagation purposes (that is, the row vector DY left-multiplied into the matrix dy/dx).
            # In other applications, this might be easier to compute than the full derivatives, as in a Krylov method
            # (if y is L things and x is M things, dy/dx is L-by-M, but DY*dy/dx is only length M).
            grad_xs = grad_xs * dpoly
            # Alternately, we can let TF do this for us:
            #grad_xs = tf.gradients(poly, x, grad_ys=dpoly)[0]
        # Evaluate the derivative of the polynomial output WRT the learned parameters p.
        with tf.variable_scope('dydp'):
            # If the output y had been a vector per example, this would be a full matrix, as far as I can tell
            # (actually a 3-tensor, where the first index is across the batch).
            grad_vars = []
            # We could let TF do this for us:
            #if variables is not None:
            #    for v in variables:
            #        print('Getting dy/d(%s) ... ' % v)
            #    grad_vars = tf.gradients(poly, variables, grad_ys=dpoly)
            # But, to be explicit, I'll do it manually.
            if variables is not None:
                for v in variables:
                    if v is p:
                        dydp = [
                            x ** (order-k)
                            for k in range(order+1)
                        ]
                        # Weight by the upstream gradient dpoly and reduce_sum
                        # over the batch dimension, since, despite the name,
                        # `custom_gradient` really wants the vector-Jacobian
                        # product here, not dy/dp itself.
                        grad_vars.append(tf.reduce_sum(
                            dpoly * tf.concat(dydp, 1),
                            axis=0,
                        ))
                    else:
                        grad_vars.append(None)
        return grad_xs, grad_vars
    return poly, grad_fn
if __name__ == '__main__':
    sess = tf.InteractiveSession()

    # Select evaluation independent-variable values.
    x_data = np.random.normal(loc=0, scale=1, size=(2000,))
    x_linspace = np.linspace(x_data.min(), x_data.max(), 1000)

    # Make a quartic dependent variable.
    p = .1, 0, 1, 0, 0

    # Use our op.
    x_in = tf.placeholder(tf.float32, shape=[None, 1], name='x')
    y_pred = poly(x_in)

    # Set the parameters to the correct values for our quartic.
    p_tensor = trainable[0]
    sess.run(p_tensor.assign(p))

    #### Verify that the forward pass works correctly.
    fig, ax = plt.subplots()
    ax.plot(x_linspace, np.polyval(p, x_linspace), label='true')
    ax.plot(x_linspace, sess.run(y_pred, feed_dict={x_in: x_linspace.reshape((-1, 1))}), label='net',
            linewidth=10, alpha=.5)
    ax.set_ylabel('$y$')
    ax.legend(); ax.set_xlabel('$x$')
    #### Verify that the gradient WRT the parameters works correctly.
    g = tf.gradients(y_pred, trainable[0])[0]
    print([2**(4-k) for k in range(5)])  # [16, 8, 4, 2, 1]
    print(sess.run(g, feed_dict={x_in: np.array([2]).reshape((-1, 1))}))  # [16. 8. 4. 2. 1.]

    # Evaluate the per-example gradients.
    gv = np.array([sess.run(g, feed_dict={x_in: x.reshape((1, 1))}) for x in x_linspace])
    print(gv.shape)  # (1000, 5)
    fig, ax = plt.subplots()
    k = 2
    ax.plot(x_linspace, x_linspace ** (4-k), label='true')
    ax.plot(x_linspace, gv[:, k], label='net',
            linewidth=10, alpha=.5)
    ax.set_ylabel(r'$\partial y/\partial p_{%d}$' % k)
    ax.legend(); ax.set_xlabel('$x$')

    # Try to use a faster helper function to do the same; doesn't work.
    from tensorflow.python.ops.parallel_for import jacobian, batch_jacobian
    J = jacobian(y_pred, trainable[0])
    print(J)  # Tensor("Reshape_1:0", shape=(?, 1, 5), dtype=float32)
    # This seems good, but ...
    Jv = sess.run(J, feed_dict={x_in: x_linspace.reshape((-1, 1))})
    print(Jv[::100, :, 0])  # This gives all nearly the same value, though different every run.
    #### Verify that the gradient WRT the input works correctly.
    # Manually compute the true derivative coefficients.
    pder = np.array(p)
    for k in range(len(pder)):
        pder[k] *= order - k
    pder = pder[:-1]
    fig, ax = plt.subplots()
    ax.plot(x_linspace, np.polyval(pder, x_linspace), label='true $dy/dx$')
    dydx = tf.gradients(y_pred, x_in)[0]
    ax.plot(
        x_linspace,
        sess.run(dydx, feed_dict={x_in: x_linspace.reshape(-1, 1)}).ravel(),
        linewidth=8, alpha=.5,
        label='$dy/dx$ of current ANN'
    )
    ax.legend()
    plt.show()
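
    # --- Not part of the original gist: a minimal, untested sketch of actually fitting the
    # --- coefficients with the op above. It assumes the same TF 1.x graph/session as __main__,
    # --- and that grad_vars above includes the upstream gradient dpoly. The placeholder name
    # --- y_target and the optimizer settings are illustrative choices, not the author's.
    y_target = tf.placeholder(tf.float32, shape=[None, 1], name='y_target')
    loss = tf.reduce_mean(tf.square(y_pred - y_target))
    train_op = tf.train.AdamOptimizer(1e-2).minimize(loss, var_list=trainable)
    # Re-initialize everything: this resets poly_coeffs to ones and creates the Adam slot variables.
    sess.run(tf.global_variables_initializer())
    y_data = np.polyval(p, x_data).reshape((-1, 1))
    for step in range(2000):
        sess.run(train_op, feed_dict={x_in: x_data.reshape((-1, 1)), y_target: y_data})
    # The learned coefficients should move back toward the true values [0.1, 0, 1, 0, 0].
    print(sess.run(trainable[0]))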
@tushar-dalal

For TensorFlow 2.x, this code needs some alterations. Would you like to update it? I've already adapted my own copy; I could open a pull request.

@louislung

> For TensorFlow 2.x, this code needs some alterations. Would you like to update it? I've already adapted my own copy; I could open a pull request.

I was about to do the same and then saw your comment. Will you share it on GitHub? Thanks!
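
For reference, here is a rough, untested sketch of what those TensorFlow 2.x alterations might look like (this is not the version either commenter produced). It assumes eager execution, a module-level `tf.Variable` in place of `variable_scope`/`get_variable`, and `tf.GradientTape` in place of `Session`/`feed_dict`:

import tensorflow as tf

order = 4
# No variable_scope / get_variable in TF 2.x; create the coefficients once at module level.
p = tf.Variable(tf.ones([order + 1]), name='poly_coeffs')

@tf.custom_gradient
def poly(x):
    # Forward pass: naive polynomial evaluation, constant term first.
    y = p[-1]
    for k in range(order):
        y += x ** (k + 1) * p[-(k + 2)]

    def grad_fn(dy, variables=None):
        # dL/dx: the derivative polynomial, weighted by the upstream gradient dy.
        grad_x = p[-2]
        for k in range(order - 1):
            e = k + 2
            grad_x += p[-(k + 3)] * e * x ** (e - 1)
        grad_x = grad_x * dy
        # dL/dp: vector-Jacobian product, summed over the batch dimension.
        dydp = tf.concat([x ** (order - k) for k in range(order + 1)], axis=1)
        grad_vars = [tf.reduce_sum(dy * dydp, axis=0)]
        return grad_x, grad_vars

    return y, grad_fn

# Eager usage: GradientTape replaces Session/feed_dict.
x = tf.constant([[2.0]])
with tf.GradientTape() as tape:
    y = poly(x)
print(tape.gradient(y, p))  # expected: [16. 8. 4. 2. 1.]

Training would then compute `tape.gradient` of a loss with respect to `p` and apply it with an optimizer such as `tf.keras.optimizers.Adam`, instead of the Session-based flow above.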
