Export ONNX model: BERT sequence-classification export, correctness tests, dynamic quantization, and fp32-to-fp16 conversion.
import time

import numpy as np
import torch
from transformers import BertForSequenceClassification


class MyBertForSequenceClassification(BertForSequenceClassification):
    # Return the logits tensor directly so torch.onnx.export traces a plain
    # tensor output instead of a ModelOutput dict.
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return outputs.logits
def get_dummy_input(seq_length=512):
    # One dummy example: ascending token ids, full attention, two equal segments.
    input_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0)
    attention_mask = torch.ones(1, seq_length, dtype=torch.long)
    half = seq_length // 2
    token_type_ids = torch.cat(
        [torch.zeros(half, dtype=torch.long), torch.ones(seq_length - half, dtype=torch.long)]
    ).unsqueeze(0)
    return input_ids, attention_mask, token_type_ids
def export_onnx(model, tokenizer, onnx_path, seq_length=512):
    # Note: tokenizer is unused here; it is kept only for call-site symmetry.
    dummy_inputs = get_dummy_input(seq_length)
    model.eval()
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_inputs,
            onnx_path,
            # verbose=True,
            opset_version=16,
            input_names=['input_ids', 'attention_mask', 'token_type_ids'],
            output_names=['output'],
            dynamic_axes={
                'input_ids': {0: 'batch', 1: 'seq_len'},
                'attention_mask': {0: 'batch', 1: 'seq_len'},
                'token_type_ids': {0: 'batch', 1: 'seq_len'},
                'output': {0: 'batch'},
            },
        )
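# The gist uses BertOnnxWrapper and to_numpy below without defining them.
# A minimal sketch of what they might look like, assuming an onnxruntime
# InferenceSession under the hood (the original implementations may differ):
import onnxruntime as ort


class BertOnnxWrapper:
    def __init__(self, onnx_path, num_threads=1):
        opts = ort.SessionOptions()
        opts.intra_op_num_threads = num_threads
        self.session = ort.InferenceSession(onnx_path, sess_options=opts)

    def __call__(self, inputs):
        # inputs: a dict of numpy arrays keyed by the export-time input names.
        return self.session.run(None, dict(inputs))[0]


def to_numpy(tensor):
    # Standard helper from the PyTorch ONNX export tutorial.
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()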
def test_onnx(onnx_path):
    # Relies on module-level tokenizer, text, and torch_outputs (see the demo
    # driver below); BertOnnxWrapper is sketched above.
    print()
    print('---------------------- test onnx: ', onnx_path)
    onnx_model = BertOnnxWrapper(onnx_path, 1)
    inputs = tokenizer(
        text,
        return_tensors='np',
        truncation=True,
        max_length=512,
        padding='longest',
    )
    start_time = time.time()
    onnx_outputs = onnx_model(inputs)
    end_time = time.time()
    print(
        f'text length is {len(text)}, total time is {end_time - start_time:.4f}s, '
        f'per character is {(end_time - start_time) / len(text):.6f}s'
    )
    # print('############## onnx model output: ', onnx_outputs)
    np.testing.assert_allclose(onnx_outputs, to_numpy(torch_outputs), rtol=1e-05, atol=1e-08)
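# Hypothetical wiring of the pieces above; model_name_or_path, text, and the
# output path below are placeholders, not values from the gist:
def run_bert_export_demo():
    global tokenizer, text, torch_outputs  # test_onnx reads these module globals
    from transformers import BertTokenizer
    model_name_or_path = 'bert-base-uncased'  # assumption: any BERT checkpoint
    tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
    model = MyBertForSequenceClassification.from_pretrained(model_name_or_path)
    # eval() before the reference forward pass, so dropout matches the exported graph.
    # The classifier head of a base checkpoint is randomly initialized, but the
    # parity check still holds because both runs share the same weights.
    model.eval()
    text = 'a sample sentence to score'
    with torch.no_grad():
        torch_outputs = model(**tokenizer(text, return_tensors='pt', truncation=True, max_length=512))
    export_onnx(model, tokenizer, 'bert_cls.onnx')
    test_onnx('bert_cls.onnx')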
def export_model(model, onnx_path, quantize=False):
    input_tensor = torch.randn(5, 320000)  # batch of 5, 20 s of 16 kHz audio (16000 * 20)
    model.eval()
    with torch.no_grad():
        torch.onnx.export(
            model,                    # model being run
            input_tensor,             # model input (or a tuple for multiple inputs)
            onnx_path,                # where to save the model (file path or file-like object)
            opset_version=16,
            input_names=['input'],    # the model's input names
            output_names=['output'],  # the model's output names
            dynamic_axes={
                'input': {
                    0: 'batch_size',
                    1: 'sequence_length',
                },  # variable-length axes
                'output': {0: 'batch_size'},
            },
        )
    if quantize:
        from onnxruntime.quantization import QuantType, quantize_dynamic
        # Derive the quantized-model path; the gist referenced an undefined quant_onnx_path.
        quant_onnx_path = onnx_path.replace('.onnx', '.quant.onnx')
        quantize_dynamic(
            onnx_path,
            quant_onnx_path,
            op_types_to_quantize=['MatMul'],
            weight_type=QuantType.QUInt8,
        )
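# OnnxWrapper (used by the tests below) is also undefined in the gist.
# A minimal sketch for the single-input model, again assuming onnxruntime;
# it returns the first output so the .numpy() comparisons below line up:
import onnxruntime as ort


class OnnxWrapper:
    def __init__(self, onnx_path):
        self.session = ort.InferenceSession(onnx_path)

    def __call__(self, input_array):
        return self.session.run(None, {'input': input_array})[0]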
def test_export_model():
    # xxxModel, model_name_or_path, config, and onnx_fp32_path are placeholders
    # the gist leaves undefined; substitute your own model class and paths.
    model = xxxModel.from_pretrained(
        model_name_or_path,
        config=config,
    )
    model.eval()
    onnx_fp32_model = OnnxWrapper(onnx_fp32_path)
    input_tensor = torch.randn(5, 320000)
    with torch.no_grad():
        torch_outs = model(input_tensor)
    onnx_fp32_outs = onnx_fp32_model(input_tensor.numpy())
    np.testing.assert_allclose(
        torch_outs.numpy(), onnx_fp32_outs, rtol=1e-03, atol=1e-05
    )
def convert_onnx_float32_to_float16(fp32_model_path, fp16_model_path):
    import onnx  # needed for load/check; the gist omitted this import
    from onnxmltools.utils.float16_converter import convert_float_to_float16
    from onnxmltools.utils import save_model
    model = onnx.load(fp32_model_path)
    onnx.checker.check_model(model)
    new_onnx_model = convert_float_to_float16(model, keep_io_types=False)
    save_model(new_onnx_model, fp16_model_path)
def test_convert_onnx_float32_to_float16():
    # onnx_fp16_path, xxxModel, model_name_or_path, and config are placeholders.
    input_tensor = torch.randn(5, 320000)
    fp16_model = OnnxWrapper(onnx_fp16_path)
    # keep_io_types=False above means the fp16 model expects float16 inputs.
    fp16_outs = fp16_model(input_tensor.numpy().astype(np.float16))
    print(
        'fp16 outputs: ', fp16_outs, '\t', fp16_outs[0].shape, '\t', fp16_outs[0].dtype
    )
    model = xxxModel.from_pretrained(
        model_name_or_path,
        config=config,
    )
    model.eval()
    # no_grad is required here: without it, .numpy() on a grad-tracking tensor raises.
    with torch.no_grad():
        torch_outs = model(input_tensor)
    np.testing.assert_allclose(torch_outs.numpy(), fp16_outs, rtol=1e-03, atol=1e-05)
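# Hypothetical end-to-end flow for the second model; every name marked xxx or
# *_path above is a placeholder the gist leaves unspecified, so this is only
# a sketch of the intended order of operations:
# model = xxxModel.from_pretrained(model_name_or_path, config=config)
# export_model(model, onnx_fp32_path, quantize=True)
# test_export_model()
# convert_onnx_float32_to_float16(onnx_fp32_path, onnx_fp16_path)
# test_convert_onnx_float32_to_float16()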