@Norod
Created August 10, 2022 07:58
Converting gpt2-large to ONNX with multiple external files and using it later for inference
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Export script: convert gpt2-large to ONNX. The fp32 weights exceed the
# 2 GB protobuf limit, so the export writes them as multiple external data
# files alongside model.onnx.
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.onnx import FeaturesManager, export
from pathlib import Path
import os
model_id = 'gpt2-large'
export_folder = model_id+'-onnx'
print('Loading tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(model_id)
print('Saving tokenizer to ', export_folder)
tokenizer.save_pretrained(export_folder)
print('Loading model...')
model = AutoModelForCausalLM.from_pretrained(model_id)
feature= "causal-lm"
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=feature)
onnx_config = model_onnx_config(model.config)
print("model_kind = {0}\nonx_config = {1}\n".format(model_kind, onnx_config))
onnx_path = Path(export_folder+"/model.onnx")
print('Exporting model to ', onnx_path)
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
print('Done')
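# Hedged sanity check (an addition, not in the original gist): listing the
# export folder confirms that model.onnx and its external data files were
# written. For models over the 2 GB limit, onnx.checker.check_model must be
# given the file path rather than a loaded ModelProto so it can resolve the
# external data.
import onnx
print("Export folder contents:", sorted(os.listdir(export_folder)))
onnx.checker.check_model(str(onnx_path))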
# -------------------------------------------------------------------------
# Inference script: load the exported ONNX model with optimum / onnxruntime
# and generate text with it.
#
# Tested with the following Python package versions:
# optimum 1.2.3.dev0
# transformers 4.21.0.dev0
# tokenizers 0.11.6
# -------------------------------------------------------------------------
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
from optimum.pipelines import pipeline
model_name="./gpt2-large-onnx"
prompt_text = "Hello, my name is"
generated_max_length = 42
print("Loading model...")
model = ORTModelForCausalLM.from_pretrained(model_name, from_transformers=False)  # load the existing ONNX export, do not re-convert
print('Loading Tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
print("Generating text...")
result = text_generator(prompt_text, num_return_sequences=1, batch_size=1, do_sample=True, top_k=40, top_p=0.92, max_length=generated_max_length)
print("result = " + str(result))