def torch_gather_nd(params, indices, params_shape, indices_shape):
    """Dirty wrapper around ``tf.gather_nd`` for use with PyTorch tensors.

    You will need ``tfpyth`` (and therefore TensorFlow) for this: the gather
    itself is executed by TensorFlow through a tfpyth bridge.

    Args:
        params: Tensor to gather from. It is fed to a ``tf.float32``
            placeholder, so it presumably has to be float32 — confirm.
        indices: Index tensor. It is fed to a ``tf.int32`` placeholder.
        params_shape: Shape handed to tfpyth for the ``params`` placeholder.
        indices_shape: Shape handed to tfpyth for the ``indices`` placeholder.

    Returns:
        Whatever the tfpyth-wrapped ``tf.gather_nd(params, indices)`` call
        returns (presumably a ``torch.Tensor`` — verify against the installed
        tfpyth version).
    """
    import warnings

    # stacklevel=2 attributes the warning to the caller, not to this line.
    warnings.warn(
        "Implemented using tfpyth, thus tensorflow is called in the back",
        stacklevel=2,
    )

    def func(params, indices):
        return tf.gather_nd(params, indices)

    # tfpyth expects the TensorFlow callable first. The list of argument
    # names then selects which of ``func``'s parameters become placeholders,
    # in the same order as ``input_shapes`` / ``input_dtypes``.
    # NOTE(review): the wrapper is rebuilt on every call; consider caching it
    # if this function is on a hot path.
    out = tfpyth.wrap_torch_from_tensorflow(
        func,
        ["params", "indices"],
        input_shapes=[params_shape, indices_shape],
        input_dtypes=[tf.float32, tf.int32],
    )(params, indices)
    return out
