Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import numpy as np
import tensorflow as tf
N = 5
# N x N int32 tensor holding the values 0 .. N*N - 1 in row-major order.
M = tf.constant(np.arange(N * N, dtype=np.int32).reshape((N, N)), dtype=tf.int32)
# This script sums every element M[i, j] with i != j (the off-diagonal
# entries) using nested tf.while_loop / tf.cond graph ops rather than
# vectorized ops.
# NOTE(review): while_loop and cond are gradient-capable constructs, but M is
# int32 and TensorFlow does not propagate gradients through integer tensors,
# so this graph is not differentiable as written; switch to a float dtype if
# gradients are actually needed.
# Inner loop, over the columns j of row i.
def body(i, r):
    """Outer-loop body: add the off-diagonal entries of row ``i`` of ``M`` to ``r``.

    Args:
        i: Current row index (the outer ``tf.while_loop`` starts it at 0).
        r: Running sum accumulated over the rows before ``i``.

    Returns:
        ``(i + 1, row_sum)`` where ``row_sum`` is ``r`` plus
        ``sum(M[i, j] for j in range(N) if j != i)``; the structure matches
        the outer loop's ``[i, r]`` loop variables.
    """
    def inner_cond(j, s):
        # Continue while the column index is still inside the matrix.
        return j < N

    def inner_step(j, s):
        # Skip the diagonal element M[i, i] by contributing 0 instead; the
        # Python literal 0 becomes an int32 constant, matching M's dtype as
        # tf.cond requires both branches to agree.
        term = tf.cond(tf.not_equal(i, j), lambda: M[i, j], lambda: 0)
        return [j + 1, s + term]

    # Start at column 0, seeding the inner accumulator with the outer sum r.
    _, row_sum = tf.while_loop(inner_cond, inner_step, [0, r])
    return i + 1, row_sum
# Outer loop, over the rows i of M; loop variables are [row index, running sum].
def c(i, r):
    return i < N


x = tf.while_loop(c, body, [0, 0])
# Use the session as a context manager so its resources are released after
# the run instead of leaking an unclosed Session.
with tf.Session() as sess:
    # x evaluates to [final row index, total]; keep only the total.
    res = sess.run(x)[1]

# Verify against a NumPy reference: total sum minus the diagonal (the trace).
M2 = np.arange(N * N).reshape((N, N))
expected = np.sum(M2) - np.trace(M2)
# Explicit raise rather than `assert` so the check still runs under `python -O`.
if res != expected:
    raise AssertionError(f"off-diagonal sum mismatch: got {res}, expected {expected}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.