@harpone
Last active August 1, 2023 00:59
import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np


# Define custom py_func which takes also a grad op as argument:
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):

    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))

    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)


# Def custom square function using np.square instead of tf.square:
def mysquare(x, name=None):

    with ops.op_scope([x], name, "Mysquare") as name:
        sqr_x = py_func(np.square,
                        [x],
                        [tf.float32],
                        name=name,
                        grad=_MySquareGrad)  # <-- here's the call to the gradient
        return sqr_x[0]


# Actual gradient:
def _MySquareGrad(op, grad):
    x = op.inputs[0]
    return grad * 20 * x  # add a "small" error just to see the difference


with tf.Session() as sess:
    x = tf.constant([1., 2.])
    y = mysquare(x)
    tf.initialize_all_variables().run()

    print(x.eval(), y.eval(), tf.gradients(y, x)[0].eval())
@adler-j commented Feb 25, 2017

This example is apparently old by now; this updated version works:

import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np

# Define custom py_func which takes also a grad op as argument:
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    
    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
    
    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

# Def custom square function using np.square instead of tf.square:
def mysquare(x, name=None):
    
    with ops.name_scope(name, "Mysquare", [x]) as name:
        sqr_x = py_func(np.square,
                        [x],
                        [tf.float32],
                        name=name,
                        grad=_MySquareGrad)  # <-- here's the call to the gradient
        return sqr_x[0]

# Actual gradient:
def _MySquareGrad(op, grad):
    x = op.inputs[0]
    return grad * 20 * x  # add a "small" error just to see the difference:

with tf.Session() as sess:
    x = tf.constant([1., 2.])
    y = mysquare(x)
    tf.global_variables_initializer().run()

    print(x.eval(), y.eval(), tf.gradients(y, x)[0].eval())

@IFLED commented Apr 21, 2017

The py_func function will not work if you call it with stateful=False. To make it work, replace the line
with g.gradient_override_map({"PyFunc": rnd_name}):
with
with g.gradient_override_map({"PyFunc": rnd_name, "PyFuncStateless": rnd_name}):

This is because the op created by tf.py_func is named PyFuncStateless, not PyFunc, when stateful=False.

Hope this helps someone.
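
For example, here is a minimal sketch (my own illustration, assuming the py_func wrapper above has been patched with the extended override map) of a stateless call that still picks up the custom gradient:

# Hypothetical illustration: with "PyFuncStateless" also mapped to rnd_name,
# a py_func created with stateful=False gets the custom gradient too.
def mysquare_stateless(x, name=None):
    with ops.name_scope(name, "MysquareStateless", [x]) as name:
        return py_func(np.square,
                       [x],
                       [tf.float32],
                       stateful=False,  # op type is "PyFuncStateless"
                       name=name,
                       grad=_MySquareGrad)[0]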

@IFLED commented Apr 21, 2017

Example of computing gradient in numpy:

import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np


# Define custom py_func which takes also a grad op as argument:
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):

    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))

    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name, "PyFuncStateless": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)


# Def custom square function using np.square instead of tf.square:
def mysquare(x, name=None):

    with ops.name_scope(name, "Mysquare", [x]) as name:
        sqr_x = py_func(np.square,
                        [x],
                        [tf.float32],
                        name=name,
                        grad=_MySquareGrad)  # <-- here's the call to the gradient
        return sqr_x[0]


# Actual gradient:
def _MySquareGrad(op, grad):
    x = op.inputs[0]
    return grad * 20 * x  # add a "small" error just to see the difference:


def mysquare_new(x, name=None):

    with ops.name_scope(name, "MysquareNew", [x]) as name:
        sqr_x, grad = py_func(square_with_grad,
                              [x],
                              [tf.float32, tf.float32],
                              name=name,
                              grad=_MySquareGradNew)
        return sqr_x


def square_with_grad(tensor):
    # first output: x, second output: gradient of sqr_x with respect to x
    return np.square(tensor), 20 * tensor


def _MySquareGradNew(op, grad_sqr, grad_grad):
    # grad_sqr - gradient of some global function with respect to the first output of the op
    # grad_grad - gradient of some global function with respect to the second output of the op
    # op.outputs[0] - tensor that equals to op.inputs[0] * op.inputs[0]
    # op.outputs[1] - tensor that equals to 20 * op.inputs[0]
    return grad_sqr * op.outputs[1]


with tf.Session() as sess:
    x = tf.constant([1., 2.])
    y = mysquare(x)
    y_new = mysquare(x)
    tf.global_variables_initializer().run()

    print(x.eval(), y.eval(), y_new.eval(), tf.gradients(y, x)[0].eval(), tf.gradients(y_new, x)[0].eval())

@shubham0704

Thanks a lot @harpone

@JasZhanAva

This really helps, thanks!

@nayansinghal

I want a behavior in which during the forward pass, it will do quantization and for the backward pass, it will behave as an identity and pass the gradients as it is without considering the gradient of that layer.
Is there any method in Keras which can solve this issue?

@chaow94 commented Nov 22, 2017

@nayansinghal I also want to do this. Do you have any idea how to "pass the gradients as it is without considering the gradient of that layer"?

@chaow94 commented Nov 22, 2017

@nayansinghal I think tf.stop_gradient() may work. I will try it.
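
For reference, here is a minimal sketch (my own illustration, not from the thread) of the usual tf.stop_gradient trick for an identity backward pass, assuming tf.round as the quantizer:

import tensorflow as tf

def quantize_with_identity_grad(x):
    # Forward pass: the quantized values round(x).
    # Backward pass: identity, because tf.stop_gradient blocks the
    # gradient of (round(x) - x), leaving d(output)/dx = 1.
    return x + tf.stop_gradient(tf.round(x) - x)

The forward value is exactly round(x), while the gradient flows only through the plain x term.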

@we-taper commented Feb 5, 2018

@IFLED I think you mean:

y_new = mysquare_new(x)

in the end.

Also, an alternative way to calculate the gradient using numpy is to create a new py_func. I have tried the following code (which could be improved; it is just for demonstrating the idea):

    # Actual gradient:
    # Replace @harpone's _MySquareGrad function with this:
    def _MySquareGrad(op, grad):
        x = op.inputs[0]
        tmp_grad_name = 'tmp_grad_name' + str(np.random.randint(low=0, high=1e+8))
        grad_x = tf.py_func(func=_grad, inp=[x], Tout=[tf.float32], stateful=True, name=tmp_grad_name)[0]
        return grad * grad_x

    def _grad(x: np.ndarray) -> np.ndarray:
        # numpy gradient of np.square
        return 2 * x

@kristijanbartol commented May 3, 2018

A complete minimalistic example with actual gradient updates could also be useful: https://gist.github.com/kristijanbartol/1b7b7c5d431415284217bbf63ca25c66

import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np
import time

ZERO_TOL = 1e-8
LOSS_TOL = 1e-3
SAMPLES = 100
EPOCHS = 100000

train_input = np.random.rand(SAMPLES)
train_label = 3 * train_input


class MyException(Exception):
    pass


def _my_linear_grad(op, grad):
    # second value is not used - it can be multiplied by zero with no side effects
    return grad * op.inputs[1], grad * 0.


def _my_linear(a, x):
    return (a * x).astype(np.float32)


learning_rate = 1e-3
beta1 = 0.9999

x = tf.placeholder(dtype=tf.float32, shape=(), name='x')
y = tf.placeholder(dtype=tf.float32, shape=(), name='y')

a = tf.get_variable('a', dtype=tf.float32, initializer=1.)
tf_a = tf.get_variable('tf_a', dtype=tf.float32, initializer=1.)

with ops.op_scope([a, x], name="MyLinear") as name:
    # custom gradient op name shouldn't conflict with any other TF op name
    unique_name = 'PyFuncGrad@Unique'
    # using tf.RegisterGradient to set _my_linear_grad function in backward pass for gradient op named rnd_name
    tf.RegisterGradient(unique_name)(_my_linear_grad)

    g = tf.get_default_graph()

    # context manager used to override gradients for nodes created in its block
    with g.gradient_override_map({"PyFunc": unique_name}):
        # my_linear is used for forward pass - my_linear and my_linear_grad are wrapped inside a single TF node
        p = tf.py_func(_my_linear, [a, x], [tf.float32], stateful=True, name=name)

tf_p = tf_a * x

loss = tf.reduce_mean(tf.square(p - y))
tf_loss = tf.reduce_mean(tf.square(tf_p - y))

train_vars = [var for var in tf.trainable_variables()]
optim = tf.train.AdamOptimizer(learning_rate, beta1)

# compute_gradients returns a list so I can just concatenate them to calculate tf_loss, too
grads_and_vars = optim.compute_gradients(loss, var_list=train_vars)
grads_and_vars += optim.compute_gradients(tf_loss, var_list=train_vars)
train_op = optim.apply_gradients(grads_and_vars)

tf.summary.scalar('loss', loss)

with tf.Session() as sess:
    train_writer = tf.summary.FileWriter('board', sess.graph)
    merge = tf.summary.merge_all()

    sess.run(tf.global_variables_initializer())

    try:
        for epoch in range(EPOCHS):
            overall_loss = 0.
            # update using each sample separately
            for i in range(SAMPLES):
                result = sess.run([loss, tf_loss, a, tf_a, merge, train_op], feed_dict={
                    x: train_input[i],
                    y: train_label[i]
                })

                if np.abs(result[0] - result[1]) > ZERO_TOL:
                    print('Invalid update!\nExpected: {}, Actual: {}'.format(result[1], result[0]))
                    raise MyException

                print('epoch: {}, iter: {}, loss: {}\na: {}\n'.format(epoch, i, result[0], result[2]))
                overall_loss += result[0]

            overall_loss /= float(SAMPLES)
            print('overall_loss: {}'.format(overall_loss))
            #time.sleep(2.0)

            # [NOTE] this moment will be delayed a bit as it has to "wait" for the epoch to finish
            if overall_loss < LOSS_TOL:
                print('Found parameter!\n---------------\n')
                break

    except MyException:
        pass

@nicola-calonaci commented Sep 4, 2018

Thank you very much, guys. All this helped me a lot. Two more questions:

  1. could you say why the examples by @IFLED and @adler-j give me the error 'cannot create weak reference to "numpy.ufunc" object', while @kristijanbartol's works fine?
  2. is it possible to deal with an n-dimensional gradient? In that case, how does the 'second value is not used - it can be multiplied by zero with no side effects' trick behave?

Possible solution to 2: use a py_func for the custom gradient?
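
Regarding question 2, here is a minimal sketch (my own illustration, not from the thread): for an elementwise op, the usual pattern of multiplying the incoming grad by the local gradient already handles n-dimensional tensors, and the local gradient can itself be computed in numpy through a second tf.py_func, similar to @we-taper's approach above:

def _np_square_grad(x):
    # local elementwise derivative of np.square, computed in numpy
    return (2 * x).astype(np.float32)

def _MySquareGradND(op, grad):
    x = op.inputs[0]
    local_grad = tf.py_func(_np_square_grad, [x], [tf.float32])[0]
    local_grad.set_shape(x.get_shape())
    return grad * local_grad  # elementwise product works for any rank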

@samlobel commented Feb 2, 2019

@Nirzi: The "cannot create weak reference" error comes because we're passing np.square as the argument to the custom py_func, which is then passed to tf.py_func. The problem can be solved by instead defining a new function

def numpy_square(x):
    return np.square(x)

and passing that named function to py_func, instead of the "weak reference" to a numpy function. To be honest I'm not super clear on why, but I think it has to do with how Python does object lookups.
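
In other words, a minimal sketch of the change, assuming the mysquare definition from @adler-j's version above:

def numpy_square(x):
    # plain Python function, so tf.py_func does not have to hold a
    # weak reference to the numpy ufunc itself
    return np.square(x)

def mysquare(x, name=None):
    with ops.name_scope(name, "Mysquare", [x]) as name:
        return py_func(numpy_square,  # named function instead of np.square
                       [x],
                       [tf.float32],
                       name=name,
                       grad=_MySquareGrad)[0]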

@Musawar71

Hello guys, I cannot print inside the grad function. How can I make sure that my calculation is going fine?
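
One option, a minimal sketch (my own suggestion, using the TF 1.x tf.Print op): a plain print inside the grad function only runs once, at graph-construction time, so to inspect runtime values you can attach a tf.Print node (or wrap the check in another tf.py_func):

def _MySquareGrad(op, grad):
    x = op.inputs[0]
    # tf.Print logs the listed tensors every time the gradient op runs
    x = tf.Print(x, [x, grad], message='grad inputs: ', summarize=10)
    return grad * 2 * x  # true gradient of np.square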

@hlx-hub commented Jan 21, 2020

Hi, can you show me how to use this to train networks such as a CNN? I am a beginner in this area, so when I try to train with this method there are always a lot of mistakes. Thanks.
