Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
def torch_gather_nd(
    params,
    indices,
    params_shape,
    indices_shape,
    *,
    params_dtype=None,
    indices_dtype=None,
):
    """Apply ``tf.gather_nd`` to PyTorch inputs through tfpyth (dirty wrapper).

    You will need https://github.com/theRealSuperMario/tfpyth for this.
    TensorFlow is imported and executed behind the scenes, so this is a
    stop-gap rather than a native PyTorch implementation.

    Args:
        params: Tensor to gather from; fed to the wrapped TF function as
            its "params" input.
        indices: Index tensor; fed to the wrapped TF function as its
            "indices" input.
        params_shape: Shape declared for ``params``, forwarded to tfpyth's
            ``input_shapes``.
        indices_shape: Shape declared for ``indices``, forwarded to tfpyth's
            ``input_shapes``.
        params_dtype: TensorFlow dtype declared for ``params``. ``None``
            (default) keeps the original ``tf.float32``.
        indices_dtype: TensorFlow dtype declared for ``indices``. ``None``
            (default) keeps the original ``tf.int32``.

    Returns:
        The result of calling the tfpyth-wrapped ``tf.gather_nd`` with
        ``(params, indices)``. Presumably this is a torch tensor
        (TODO confirm against tfpyth).

    Raises:
        ImportError: if ``tensorflow`` or ``tfpyth`` is not installed.
    """
    import warnings

    # stacklevel=2 attributes the warning to the caller's line, not this wrapper.
    warnings.warn(
        "Implemented using tfpyth, thus tensorflow is called in the back",
        stacklevel=2,
    )

    # Imported lazily so that merely importing this module does not require
    # (or pay the start-up cost of) TensorFlow.
    import tensorflow as tf
    import tfpyth

    if params_dtype is None:
        params_dtype = tf.float32
    if indices_dtype is None:
        indices_dtype = tf.int32

    def _gather_nd(tf_params, tf_indices):
        # Executed by tfpyth on TensorFlow-side inputs built from the shapes
        # and dtypes declared below.
        return tf.gather_nd(tf_params, tf_indices)

    # NOTE(review): the tfpyth wrapper is rebuilt on every call. If tfpyth
    # creates fresh placeholders per wrap, the TF graph grows with each call.
    # Consider caching the wrapper per (shapes, dtypes) if this runs in a loop.
    return tfpyth.wrap_torch_from_tensorflow(
        _gather_nd,
        ["params", "indices"],
        input_shapes=[params_shape, indices_shape],
        input_dtypes=[params_dtype, indices_dtype],
    )(params, indices)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment