Skip to content

Instantly share code, notes, and snippets.

@oborchers
Last active April 3, 2021 15:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save oborchers/631e796ce52058010d5ef8624f2f0cb4 to your computer and use it in GitHub Desktop.
Save oborchers/631e796ce52058010d5ef8624f2f0cb4 to your computer and use it in GitHub Desktop.
def print_transformers_shape_inference(name_or_path: str) -> dict:
    """Print, and return, the ONNX shape inference transformers performs.

    Builds a CPU (``device=-1``), PyTorch (``framework="pt"``)
    ``FeatureExtractionPipeline`` for ``name_or_path``. It then runs
    ``convert_graph_to_onnx.infer_shapes`` on that pipeline and
    ``convert_graph_to_onnx.ensure_valid_input`` on the resulting sample
    tokens, prints every intermediate result, and returns them.

    Args:
        name_or_path: Model identifier or local path, passed to both
            ``AutoModel.from_pretrained`` and ``AutoTokenizer.from_pretrained``.

    Returns:
        A dict with the keys ``"input_names"``, ``"output_names"``,
        ``"dynamic_axes"``, ``"tokens"``, ``"exemplary_input"`` and
        ``"ordered_input_names"``.

        ``"exemplary_input"`` and ``"ordered_input_names"`` are the two
        halves of the pair returned by ``ensure_valid_input``. Use them
        together (e.g. as ``args`` / ``input_names`` of an ONNX export).
        ``"input_names"`` comes from ``infer_shapes`` and is not guaranteed
        to follow the same order.
    """
    res = {}
    model_pipeline = transformers.FeatureExtractionPipeline(
        model=transformers.AutoModel.from_pretrained(name_or_path),
        tokenizer=transformers.AutoTokenizer.from_pretrained(
            name_or_path, use_fast=True
        ),
        framework="pt",
        device=-1,
    )

    # Shape inference runs a forward pass over sample tokens; no gradients
    # are needed for that, so skip building the autograd graph.
    with torch.no_grad():
        (
            input_names,
            output_names,
            dynamic_axes,
            tokens,
        ) = convert_graph_to_onnx.infer_shapes(model_pipeline, "pt")
        ordered_input_names, model_args = convert_graph_to_onnx.ensure_valid_input(
            model_pipeline.model, tokens, input_names
        )

    res["input_names"] = input_names
    res["output_names"] = output_names
    res["dynamic_axes"] = dynamic_axes
    res["tokens"] = tokens
    res["exemplary_input"] = model_args
    # Keep the names that belong with `model_args`. Without them, callers
    # could only pair the exemplary input with `input_names`, whose order
    # may differ.
    res["ordered_input_names"] = ordered_input_names

    print()
    print(f"Inferred shapes for {name_or_path}")
    print(f"Input names: {input_names}")
    print(f"Output names: {output_names}")
    print(f"Dynamic Axes:\n{json.dumps(dynamic_axes,sort_keys=True, indent=4)}")
    print(f"Tokens:{tokens}")
    print(f"Ordered input names: {ordered_input_names}")
    print(f"Arguments: {model_args}")
    return res
# Run the shape-inference printout for `model_name` (defined outside this
# snippet).
# NOTE(review): despite its name, `model_args` receives the whole result dict
# returned by print_transformers_shape_inference (keys "input_names",
# "output_names", "dynamic_axes", "tokens", "exemplary_input").
# The actual example model arguments are under model_args["exemplary_input"].
model_args = print_transformers_shape_inference(model_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment