Created
April 16, 2020 15:01
import warnings

import tensorflow as tf
import tfpyth  # the fork linked in the docstring below


def torch_gather_nd(params, indices, params_shape, indices_shape):
    """Dirty wrapper for tf.gather_nd to use with PyTorch.

    You will need https://github.com/theRealSuperMario/tfpyth for this.
    """
    warnings.warn("Implemented using tfpyth, thus TensorFlow is called in the background")

    def func(params, indices):
        return tf.gather_nd(params, indices)

    # Wrap the TensorFlow function with tfpyth, then call it on the torch inputs.
    out = tfpyth.wrap_torch_from_tensorflow(
        func,
        ["params", "indices"],
        input_shapes=[params_shape, indices_shape],
        input_dtypes=[tf.float32, tf.int32],
    )(params, indices)
    return out
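
For readers unsure what the wrapped tf.gather_nd call computes, here is a minimal pure-Python sketch of its indexing semantics on nested lists. The function name gather_nd_reference is hypothetical and is not part of the gist, tfpyth, or TensorFlow. It is meant only for checking small cases by hand, not as a replacement for the wrapper.

```python
def gather_nd_reference(params, indices):
    """Illustrative reference for gather_nd semantics on nested lists.

    Each innermost list in `indices` is read as a coordinate into the
    leading dimensions of `params`. The outer structure of `indices`
    is preserved in the result.
    """
    # Base case: a flat list of ints is one coordinate; walk into params.
    if all(isinstance(i, int) for i in indices):
        out = params
        for i in indices:
            out = out[i]
        return out
    # Recursive case: apply the lookup to each sub-list of indices.
    return [gather_nd_reference(params, sub) for sub in indices]


params = [[1.0, 2.0], [3.0, 4.0]]
full_coords = gather_nd_reference(params, [[0, 0], [1, 1]])  # picks scalars
row_coords = gather_nd_reference(params, [[1], [0]])         # picks whole rows
```

With full coordinates each index selects a single element, and with shorter coordinates each index selects a slice (here, a row).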
AttributeError: module 'tfpyth' has no attribute 'wrap_torch_from_tensorflow'
Maybe the project has been updated.