Skip to content

Instantly share code, notes, and snippets.

@hotchpotch
Created April 9, 2024 00:40
Show Gist options
  • Save hotchpotch/64fa52d32886fe61cc1d110066afef38 to your computer and use it in GitHub Desktop.
Save hotchpotch/64fa52d32886fe61cc1d110066afef38 to your computer and use it in GitHub Desktop.
ONNX model to float16 precision
"""
This script converts an ONNX model to float16 precision using the onnxruntime transformers package.
It takes an input ONNX model file as a mandatory argument. The output file name is optional; if not provided,
the script derives it by replacing the input file's extension with "_fp16.onnx" (e.g. "model.onnx" -> "model_fp16.onnx").
"""
import argparse
import onnx
from onnxruntime.transformers.float16 import convert_float_to_float16
import os
def main(input_file, output_file=None):
    """Convert the ONNX model at ``input_file`` to float16 and save the result.

    Args:
        input_file: Path to the ONNX model file to convert.
        output_file: Destination path for the converted model. When ``None``,
            the path is derived from ``input_file`` by replacing its extension
            with ``_fp16.onnx``.

    Returns:
        None. When ``input_file`` does not exist or is not a regular file, an
        error message is printed and no conversion is attempted.
    """
    # Check if the input file exists
    if not os.path.exists(input_file):
        print(f"Error: The input file '{input_file}' does not exist.")
        return
    # os.path.exists() is also true for directories; reject those here with a
    # readable message rather than handing the path to onnx.load().
    if not os.path.isfile(input_file):
        print(f"Error: The input path '{input_file}' is not a file.")
        return

    # Generate the output file name from the input file name if not specified
    if output_file is None:
        base_name = os.path.splitext(input_file)[0]  # path without the extension
        output_file = f"{base_name}_fp16.onnx"

    print(f"Loading model from {input_file}...")
    onnx_model = onnx.load(input_file)

    print("Converting model to float16...")
    model_fp16 = convert_float_to_float16(onnx_model, disable_shape_infer=True)

    print(f"Saving converted model to {output_file}...")
    onnx.save(model_fp16, output_file)
    print("Conversion completed successfully.")
if __name__ == "__main__":
    # Command-line entry point: read the -i/--input and -o/--output options
    # and pass them straight to main().
    cli = argparse.ArgumentParser(
        description="Convert an ONNX model to float16.",
    )
    cli.add_argument(
        "-i",
        "--input",
        required=True,
        help="Input ONNX model file.",
    )
    # --output is optional; when omitted its value is None.
    cli.add_argument(
        "-o",
        "--output",
        help="Optional output file for the converted model. If not specified, derives the output file name from the input file name.",
    )
    options = cli.parse_args()
    main(options.input, options.output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment