@OhadRubin
Created October 15, 2022 12:06
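import tensorflow as tf

def hf_features_to_tf_features(features):
    # Returned keys match the output_types/output_shapes arguments of tf.data.Dataset.from_generator.
    output_types = {}
    output_shapes = {}
    for name, feature in features.items():
        # Sequence features report dtype "list"; the element dtype lives on .feature.
        is_list = feature.dtype == "list"
        output_types[name] = getattr(tf.dtypes, feature.feature.dtype if is_list else feature.dtype)
        # Scalars get shape [], variable-length sequences get [None].
        output_shapes[name] = [None] if is_list else []
    return dict(output_types=output_types, output_shapes=output_shapes)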
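The helper above produces `output_types`/`output_shapes`, which newer TensorFlow releases deprecate in `tf.data.Dataset.from_generator` in favor of a single `output_signature`. The following is a minimal sketch of the same mapping that emits `tf.TensorSpec` objects instead. The name `hf_features_to_tf_signature` is hypothetical, and the `SimpleNamespace` objects are stand-ins that expose only the `dtype`/`feature` attributes the helper reads. They are not the real Hugging Face `datasets` classes.

```python
from types import SimpleNamespace

import tensorflow as tf


def hf_features_to_tf_signature(features):
    """Hypothetical variant: build an output_signature dict of tf.TensorSpec."""
    signature = {}
    for name, feature in features.items():
        # Sequence-style features report dtype "list" and hold the element type in .feature.
        is_list = feature.dtype == "list"
        dtype_name = feature.feature.dtype if is_list else feature.dtype
        # Scalars get shape [], variable-length sequences get [None].
        shape = [None] if is_list else []
        signature[name] = tf.TensorSpec(shape=shape, dtype=getattr(tf.dtypes, dtype_name))
    return signature


# Stand-ins mimicking only the attributes the helper reads (not real datasets classes).
features = {
    "input_ids": SimpleNamespace(dtype="list", feature=SimpleNamespace(dtype="int64")),
    "label": SimpleNamespace(dtype="int64"),
    "text": SimpleNamespace(dtype="string"),
}
rows = [
    {"input_ids": [1, 2, 3], "label": 0, "text": "a"},
    {"input_ids": [4, 5], "label": 1, "text": "bc"},
]

signature = hf_features_to_tf_signature(features)
# from_generator needs a callable that returns a fresh iterator on each call.
dataset = tf.data.Dataset.from_generator(lambda: iter(rows), output_signature=signature)
for example in dataset:
    print(example["input_ids"].numpy(), example["label"].numpy())
```

Passing a lambda that calls `iter(rows)` lets TensorFlow restart the generator on every epoch. With a real `datasets.Dataset`, the rows would come from iterating the dataset itself.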