
@theRealSuperMario
Created April 16, 2020 15:01
import warnings

import tensorflow as tf
import tfpyth


def torch_gather_nd(params, indices, params_shape, indices_shape):
    """Quick-and-dirty wrapper for calling tf.gather_nd from PyTorch.

    Requires tfpyth from https://github.com/theRealSuperMario/tfpyth.
    `params` is fed to TensorFlow as float32 and `indices` as int32.
    """
    warnings.warn("Implemented using tfpyth, so TensorFlow is called in the background")

    def func(params, indices):
        # The TensorFlow graph that tfpyth wraps.
        return tf.gather_nd(params, indices)

    out = tfpyth.wrap_torch_from_tensorflow(
        func,
        ["params", "indices"],
        input_shapes=[params_shape, indices_shape],
        input_dtypes=[tf.float32, tf.int32],
    )(params, indices)
    return out
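A TensorFlow-free alternative is sketched here, assuming only PyTorch is installed. It is not part of the original gist, and `gather_nd` is a hypothetical helper name. It reproduces tf.gather_nd's default behaviour (batch_dims=0) with PyTorch advanced indexing instead of tfpyth: the last axis of `indices` (length M) is split into one index tensor per leading axis of `params`. The result has shape `indices.shape[:-1] + params.shape[M:]`, and gradients flow back to `params` through the indexing.

```python
import torch


def gather_nd(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical pure-PyTorch stand-in for tf.gather_nd (batch_dims=0 only).

    indices has shape [..., M] with 1 <= M <= params.dim(); each length-M row
    selects one element (M == params.dim()) or one slice (M < params.dim()).
    """
    m = indices.shape[-1]
    if not 1 <= m <= params.dim():
        raise ValueError(f"indices.shape[-1]={m} must be in [1, {params.dim()}]")
    # One index tensor per indexed axis, each of shape indices.shape[:-1].
    index_tuple = tuple(indices.long().unbind(dim=-1))
    # Advanced indexing on the first M axes yields
    # indices.shape[:-1] + params.shape[M:].
    # Unlike tf.gather_nd, negative indices wrap around Python-style here,
    # and out-of-range indices raise an error on every device.
    return params[index_tuple]


if __name__ == "__main__":
    params = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    # Element gather: picks params[0, 1] and params[2, 3].
    print(gather_nd(params, torch.tensor([[0, 1], [2, 3]])))
    # Slice gather: picks rows 1 and 0.
    print(gather_nd(params, torch.tensor([[1], [0]])))
```

Because it relies only on PyTorch indexing, this works on CPU or CUDA tensors without a TensorFlow session; the `batch_dims` argument of tf.gather_nd is not handled.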
@y6216886

AttributeError: module 'tfpyth' has no attribute 'wrap_torch_from_tensorflow'
Maybe the project has been updated.
