Skip to content

Instantly share code, notes, and snippets.

@psycharo-zz
Created February 6, 2018 13:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save psycharo-zz/b91d2a38ab49d67486036565d1550ed4 to your computer and use it in GitHub Desktop.
Save psycharo-zz/b91d2a38ab49d67486036565d1550ed4 to your computer and use it in GitHub Desktop.
TensorFlow implementation of permutohedral lattice filtering
# TensorFlow implementation: splat -> blur -> slice on the permutohedral lattice.
def ph_splat(inputs, offsets, weights, nbs):
    """Splat per-point values onto the vertices of a permutohedral lattice.

    Point ``i`` contributes ``weights[i, j] * inputs[i]`` to lattice vertex
    ``offsets[i, j]`` for every ``j`` in ``0..F``. Contributions that land on
    the same vertex are summed.

    Args:
        inputs: rank-2 ``(N, C)`` array/tensor of values to filter. N and C
            must be known when the graph is built.
        offsets: integer vertex indices. The first N rows and first F+1
            columns are used. An offset of -1 maps to row 0 of the result
            (presumably the "no vertex" marker -- TODO confirm).
        weights: per-point weights with F+1 columns (F = weights.shape[1] - 1).
            Only the first N rows are used.
        nbs: neighbour table. Only its leading size M (the number of lattice
            vertices) is read here, and M may be unknown until run time.

    Returns:
        ``(M + 2, C)`` tensor in which row ``k + 1`` holds vertex ``k``.
        Row 0 collects any offset of -1. The extra trailing row matches the
        ``M + 2`` rows that ``ph_blur`` builds when it pads M neighbour rows
        by one row on each side.
    """
    N, C = inputs.shape
    F = weights.shape[1] - 1
    # TF1 returns ``Dimension`` objects from ``.shape``; unwrap them to int/None.
    # Plain ints (TF2, NumPy) have no ``.value`` attribute and pass through.
    C = getattr(C, "value", C)
    M = getattr(nbs.shape[0], "value", nbs.shape[0])
    if M is None:
        # The vertex count is only known at run time. Read it dynamically
        # instead of failing on ``None + 2``.
        M = tf.shape(nbs)[0]
    # Broadcasting (N, F+1, 1) * (N, 1, C) gives (N, F+1, C): the outer product
    # of each point's weights with its value vector. This is the same result
    # as a batched matmul with an inner dimension of 1.
    weighted_inputs = weights[:N, :, tf.newaxis] * inputs[:N, tf.newaxis, :]
    weighted_inputs = tf.reshape(weighted_inputs, [-1, C])
    # Shift by one so that an offset of -1 addresses row 0.
    idxs = tf.reshape(offsets[:N, :F + 1], [-1, 1]) + 1
    # tf.scatter_nd sums updates whose indices coincide. The order of that
    # summation is not deterministic, so results may differ in the last
    # floating-point bits between runs.
    return tf.scatter_nd(idxs, weighted_inputs, [M + 2, C])
def ph_blur(inputs, values_in, offsets, weights, nbs):
    """Blur splatted lattice values along each lattice direction in turn.

    For each direction ``j`` (the second axis of ``nbs``), every vertex value
    is updated as ``v <- v + 0.5 * (v[n1] + v[n2])``, where ``n1`` and ``n2``
    are the vertex's two neighbour indices for that direction. Directions are
    applied sequentially, so each pass sees the output of the previous one.
    This is equivalent to the explicit loop::

        values = values_in
        for j in range(nbs.shape[1]):
            n1 = tf.gather(values, nbs[:, j, 0] + 1)
            n2 = tf.gather(values, nbs[:, j, 1] + 1)
            values = values + 0.5 * tf.pad(n1 + n2, [[1, 1], [0, 0]])
        return values

    Args:
        inputs, offsets, weights: unused here. They are accepted so that all
            ``ph_*`` stages share one calling convention.
        values_in: lattice values where row ``k + 1`` holds vertex ``k``. It
            must have exactly two more rows than ``nbs`` has vertices, since
            the padded update below is added to it element-wise.
        nbs: rank-3 neighbour index table, presumably ``(M, directions, 2)``
            (TODO confirm). A neighbour index of -1 reads row 0.

    Returns:
        Blurred values with the same shape as ``values_in``.
    """
    def _blur_direction(acc, nbs_j):
        # nbs_j holds the two neighbour indices of every vertex for one
        # direction. The +1 matches the one-row shift of the value layout.
        left = tf.gather(acc, nbs_j[:, 0] + 1)
        right = tf.gather(acc, nbs_j[:, 1] + 1)
        # Pad one row at each end so the update lines up with acc's layout.
        # The padding rows receive a zero update.
        return acc + 0.5 * tf.pad(left + right, [[1, 1], [0, 0]])

    # tf.foldl iterates over the leading axis, so move the direction axis first.
    return tf.foldl(_blur_direction, tf.transpose(nbs, [1, 0, 2]), values_in)
def ph_slice(inputs, values_in, offsets, weights, nbs):
    """Interpolate blurred lattice values back to the N input points.

    Output row ``i`` is
    ``alpha * sum_j weights[i, j] * values_in[offsets[i, j] + 1]``,
    with ``alpha = 1 / (1 + 2**-F)`` and ``F = weights.shape[1] - 1``.

    Args:
        inputs: rank-2 ``(N, C)`` array/tensor. Only its static shape is used.
        values_in: lattice values where row ``k + 1`` holds vertex ``k``
            (the one-row-shifted layout used by ``ph_splat``).
        offsets: integer vertex indices. The first N rows and F+1 columns
            are used.
        weights: per-point weights with F+1 columns. The first N rows are used.
        nbs: unused here. It is accepted so that all ``ph_*`` stages share
            one calling convention.

    Returns:
        ``(N, C)`` tensor of filtered values.
    """
    N, C = inputs.shape
    # int() unwraps TF1 ``Dimension`` objects, which do not support unary
    # minus. This lets ``2.0 ** (-F)`` below work for TF1/TF2 tensors and
    # NumPy arrays alike.
    F = int(weights.shape[1]) - 1
    # NOTE(review): presumably the usual permutohedral normalisation that
    # compensates for the blur kernel's mass -- confirm against the
    # reference implementation.
    alpha = 1.0 / (1.0 + 2.0 ** (-F))
    # Same +1 shift as in splatting: offset k reads row k + 1.
    idxs = tf.reshape(offsets[:N, :F + 1], [-1]) + 1
    w = weights[:N, :, tf.newaxis]  # (N, F+1, 1)
    v = tf.reshape(tf.gather(values_in, idxs), [N, F + 1, C])  # (N, F+1, C)
    return tf.reduce_sum(alpha * w * v, axis=1)
def ph_filter(inputs, offsets, weights, nbs):
    """Filter ``inputs`` with the permutohedral lattice.

    Runs the three stages in order: splat the inputs onto the lattice,
    blur the lattice values, then slice them back to the input points.
    Every stage receives the same ``offsets``, ``weights`` and ``nbs``.
    """
    lattice = (offsets, weights, nbs)
    splatted = ph_splat(inputs, *lattice)
    blurred = ph_blur(inputs, splatted, *lattice)
    return ph_slice(inputs, blurred, *lattice)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment