Skip to content

Instantly share code, notes, and snippets.

@bash99
Created August 12, 2024 07:50
Show Gist options
  • Save bash99/f2061e2d268b8df5717a61459958263c to your computer and use it in GitHub Desktop.
Speed test for the ONNX backend vs. the sentence_transformers backend of the BCE reranker model.
import concurrent.futures
import functools
import os
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from typing import List

import numpy as np
import onnxruntime
import torch
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer
# Maximum combined token length for a (query, passage) pair fed to the model.
LOCAL_RERANK_MAX_LENGTH = 512
# Thread-pool size used when submitting inference batches.
LOCAL_RERANK_WORKERS = 4
# Number of (query, passage) pairs per padded inference batch.
LOCAL_RERANK_BATCH = 32
# Silence transformers advisory warnings during benchmarking.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def get_time(func):
    """Decorator that prints the wall-clock duration of each call to *func*.

    The timing report (message text kept verbatim from the original) names
    the wrapped function and shows elapsed seconds; the wrapped function's
    return value is passed through unchanged.
    """
    @functools.wraps(func)  # preserve __name__/__doc__ so the report names the real function
    def inner(*args, **kwargs):
        s_time = time.time()
        res = func(*args, **kwargs)
        e_time = time.time()
        print('函数 {} 执行耗时: {} 秒'.format(func.__name__, e_time - s_time))
        return res
    return inner
class RerankBackend(ABC):
    """Base class for BCE reranker backends.

    Handles tokenization of (query, passage) pairs, splitting of long
    passages into overlapping windows, batching/padding, and per-passage
    score aggregation. The model forward pass itself is delegated to
    subclasses via inference().
    """

    def __init__(self, model_path, model_name, use_cpu=False):
        self.use_cpu = use_cpu
        # Expected ONNX weight file: <model_path>/<model_name>.onnx
        self.model_file_path = os.path.join(model_path, f"{model_name}.onnx")
        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.spe_id = self._tokenizer.sep_token_id  # separator ([SEP]) token id
        self.overlap_tokens = 80  # token overlap between adjacent windows of a long passage
        self.batch_size = LOCAL_RERANK_BATCH
        self.max_length = LOCAL_RERANK_MAX_LENGTH
        self.return_tensors = None  # set by subclasses ("np" for the ONNX backend)
        self.workers = LOCAL_RERANK_WORKERS

    @abstractmethod
    def inference(self, batch) -> List:
        """Run the model on a padded batch and return one score per row."""
        pass

    def merge_inputs(self, chunk1_raw, chunk2):
        """Concatenate a tokenized query (chunk1_raw) with a passage (chunk2).

        Returns a new encoding: query tokens + passage tokens + [SEP], with
        the attention mask extended accordingly and token_type_ids set to 1
        for the passage segment when the tokenizer produces them.
        """
        chunk1 = deepcopy(chunk1_raw)  # never mutate the shared query encoding
        chunk1['input_ids'].extend(chunk2['input_ids'])
        chunk1['input_ids'].append(self.spe_id)
        chunk1['attention_mask'].extend(chunk2['attention_mask'])
        # mask value for the appended [SEP] (mirrors the passage's mask values)
        chunk1['attention_mask'].append(chunk2['attention_mask'][0])
        if 'token_type_ids' in chunk1:
            # passage segment plus trailing [SEP] belong to segment 1
            token_type_ids = [1 for _ in range(len(chunk2['token_type_ids']) + 1)]
            chunk1['token_type_ids'].extend(token_type_ids)
        return chunk1

    def tokenize_preproc(self,
                         query: str,
                         passages: List[str],
                         ):
        """Tokenize *query* against each passage, windowing long passages.

        Returns (merge_inputs, merge_inputs_idxs): the merged encodings and,
        aligned with them, the index of the source passage each came from
        (a long passage contributes several consecutive entries).
        """
        query_inputs = self._tokenizer.encode_plus(query, truncation=False, padding=False)
        # room left for passage tokens: total budget minus query minus trailing [SEP]
        max_passage_inputs_length = self.max_length - len(query_inputs['input_ids']) - 1
        assert max_passage_inputs_length > 10
        # cap the window overlap at 2/7 of the per-passage token budget
        overlap_tokens = min(self.overlap_tokens, max_passage_inputs_length * 2 // 7)

        # build [query, passage] pairs
        merge_inputs = []
        merge_inputs_idxs = []
        for pid, passage in enumerate(passages):
            passage_inputs = self._tokenizer.encode_plus(passage, truncation=False, padding=False,
                                                         add_special_tokens=False)
            passage_inputs_length = len(passage_inputs['input_ids'])

            if passage_inputs_length <= max_passage_inputs_length:
                # skip passages that tokenized to nothing usable
                if passage_inputs['attention_mask'] is None or len(passage_inputs['attention_mask']) == 0:
                    continue
                qp_merge_inputs = self.merge_inputs(query_inputs, passage_inputs)
                merge_inputs.append(qp_merge_inputs)
                merge_inputs_idxs.append(pid)
            else:
                # passage too long: slide a max_passage_inputs_length-token
                # window with overlap_tokens of overlap between windows
                start_id = 0
                while start_id < passage_inputs_length:
                    end_id = start_id + max_passage_inputs_length
                    sub_passage_inputs = {k: v[start_id:end_id] for k, v in passage_inputs.items()}
                    start_id = end_id - overlap_tokens if end_id < passage_inputs_length else end_id

                    qp_merge_inputs = self.merge_inputs(query_inputs, sub_passage_inputs)
                    merge_inputs.append(qp_merge_inputs)
                    merge_inputs_idxs.append(pid)

        return merge_inputs, merge_inputs_idxs

    # @get_time
    def get_rerank(self, query: str, passages: List[str]):
        """Score every passage against *query*; returns one score per passage.

        Pads each batch of merged encodings, runs inference() on a thread
        pool, then aggregates window scores per passage by taking the max.
        """
        tot_batches, merge_inputs_idxs_sort = self.tokenize_preproc(query, passages)

        tot_scores = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.workers) as executor:
            futures = []
            for k in range(0, len(tot_batches), self.batch_size):
                batch = self._tokenizer.pad(
                    tot_batches[k:k + self.batch_size],
                    padding=True,
                    max_length=None,
                    pad_to_multiple_of=None,
                    return_tensors=self.return_tensors
                )
                future = executor.submit(self.inference, batch)
                futures.append(future)
            # print(f'rerank number: {len(futures)}')
            # iterate in submission order so scores stay aligned with merge_inputs_idxs_sort
            for future in futures:
                scores = future.result()
                tot_scores.extend(scores)

        # a windowed passage gets the max of its window scores; passages that
        # produced no pair (empty tokenization) keep score 0
        merge_tot_scores = [0 for _ in range(len(passages))]
        for pid, score in zip(merge_inputs_idxs_sort, tot_scores):
            merge_tot_scores[pid] = max(merge_tot_scores[pid], score)
        # print("merge_tot_scores:", merge_tot_scores, flush=True)
        return merge_tot_scores
class RerankOnnxBackend(RerankBackend):
    """ONNX Runtime implementation of the BCE reranker backend."""

    def __init__(self, model_path: str, model_name: str, use_cpu: bool = False):
        super().__init__(model_path, model_name, use_cpu)
        self.return_tensors = "np"  # tokenizer.pad must emit numpy arrays for ORT
        # Build an ONNX Runtime session with full graph optimization;
        # prefer the CUDA provider unless CPU execution was requested.
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        providers = ['CPUExecutionProvider'] if use_cpu else ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.session = onnxruntime.InferenceSession(self.model_file_path, sess_options, providers=providers)

    def inference(self, batch):
        """Run the session on a padded batch; return sigmoid scores as a flat list."""
        model_inputs = self.session.get_inputs()
        feed = {
            model_inputs[0].name: batch['input_ids'],
            model_inputs[1].name: batch['attention_mask'],
        }
        if 'token_type_ids' in batch:
            feed[model_inputs[2].name] = batch['token_type_ids']
        # run() with None fetches every output; the first output holds the logits
        logits = self.session.run(None, feed)[0]
        # squash logits into (0, 1) with the sigmoid
        scores = 1 / (1 + np.exp(-np.array(logits)))
        return scores.reshape(-1).tolist()
class Reranker:
    """Abstract interface: rank *passages* against *query*, best first."""

    def rerank(self, query, passages, top_k=32768):
        """Subclasses return the top_k passages sorted by descending score."""
        raise NotImplementedError
class CrossEncoderReranker(Reranker):
    """Reranker backed by a sentence-transformers CrossEncoder in fp16."""

    def __init__(self, model_name):
        self.reranker_model = CrossEncoder(model_name, max_length=512, automodel_args={"torch_dtype": torch.float16})

    def rerank(self, query, passages, top_k=32768):
        """Score every (query, passage) pair and return the top_k results.

        Each result is a dict with 'idx' (original position), 'question'
        (the passage text) and 'score', sorted by descending score.
        """
        pairs = [[query, passage] for passage in passages]
        scores = self.reranker_model.predict(pairs)
        ranked = [
            {'idx': idx, 'question': passage, 'score': score}
            for idx, (passage, score) in enumerate(zip(passages, scores))
        ]
        ranked.sort(key=lambda item: item['score'], reverse=True)
        return ranked[:top_k]
class OnnxReranker(Reranker):
    """Reranker backed by the ONNX Runtime rerank backend."""

    def __init__(self, remote_url_or_path, model_name):
        self.onnx_reranker = RerankOnnxBackend(remote_url_or_path, model_name)

    def rerank(self, query, passages, top_k=32768):
        """Score all passages via the ONNX backend and return the top_k.

        Each result is a dict with 'question' (the passage text), 'score'
        and 'idx' (original position), sorted by descending score.
        """
        scores = self.onnx_reranker.get_rerank(query, passages)
        ranked = [
            {'question': passages[idx], 'score': score, 'idx': idx}
            for idx, score in enumerate(scores)
        ]
        return sorted(ranked, key=lambda item: item['score'], reverse=True)[:top_k]
def reranker_factory(reranker_type, remote_url_or_path='./netease-youdao/Rerank', model_name='rerank'):
    """Build a Reranker of the requested kind.

    reranker_type must be 'cross_encoder' or 'onnx'; any other value
    raises ValueError. remote_url_or_path is only used by the ONNX backend.
    """
    if reranker_type == 'onnx':
        return OnnxReranker(remote_url_or_path=remote_url_or_path, model_name=model_name)
    if reranker_type == 'cross_encoder':
        return CrossEncoderReranker(model_name=model_name)
    raise ValueError("Unknown reranker type")
@get_time
def test_reranker_speed(qas, reranker, max_concurrent=1):
    """Benchmark *reranker* over *qas* using up to *max_concurrent* threads.

    qas is split into batches of 32 questions; the first question of each
    batch serves as the query and the whole batch as the passages. Results
    are discarded (timing comes from @get_time); failures are printed.
    """
    batch_size = 32
    with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
        futures = []
        for start in range(0, len(qas), batch_size):
            batch_passages = [qa['question'] for qa in qas[start:start + batch_size]]
            futures.append(executor.submit(reranker.rerank, batch_passages[0], batch_passages))
        for fut in as_completed(futures):
            try:
                fut.result()
            except Exception as exc:
                print(f"An error occurred: {exc}")
if __name__ == '__main__':
    # Benchmark driver: ONNX backend first, then the CrossEncoder backend.
    reranker = reranker_factory('onnx', remote_url_or_path='./netease-youdao/Rerank', model_name='rerank')

    # basic smoke test
    query = "Where is Munich?"
    passages = ["The sky is blue.", "Munich is in Germany.", 'Munich hosted the 1972 Olympics Game']
    res = reranker.rerank(query, passages)
    print(res)

    # 5000 anchor sentences from the all-nli test split form the benchmark corpus
    dataset = load_dataset("sentence-transformers/all-nli", 'pair')
    print(len(dataset['test']))
    qas = [{'question': sentence} for sentence in dataset['test'][:5000]['anchor']]
    test_reranker_speed(qas, reranker, max_concurrent=1)
    test_reranker_speed(qas, reranker, max_concurrent=2)
    test_reranker_speed(qas, reranker, max_concurrent=4)

    # same corpus through the sentence_transformers CrossEncoder backend
    reranker = reranker_factory('cross_encoder', model_name='maidalun1020/bce-reranker-base_v1')
    res = reranker.rerank(query, passages)
    print(res)
    test_reranker_speed(qas, reranker, max_concurrent=1)
    test_reranker_speed(qas, reranker, max_concurrent=2)
    test_reranker_speed(qas, reranker, max_concurrent=4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment