Speed test for the ONNX backend vs. the sentence_transformers backend of the BCE reranker model.
import torch
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer
import onnxruntime
import numpy as np
from copy import deepcopy
from typing import List
from abc import ABC, abstractmethod
import os
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
from datasets import load_dataset

LOCAL_RERANK_MAX_LENGTH = 512  # token budget for one [query, passage] pair
LOCAL_RERANK_WORKERS = 4       # inference threads used inside get_rerank
LOCAL_RERANK_BATCH = 32        # [query, passage] pairs per padded batch

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def get_time(func):
    def inner(*args, **kwargs):
        s_time = time.time()
        res = func(*args, **kwargs)
        e_time = time.time()
        print('Function {} took {} seconds'.format(func.__name__, e_time - s_time))
        return res
    return inner
class RerankBackend(ABC):
    def __init__(self, model_path, model_name, use_cpu=False):
        self.use_cpu = use_cpu
        self.model_file_path = os.path.join(model_path, f"{model_name}.onnx")
        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.sep_id = self._tokenizer.sep_token_id
        self.overlap_tokens = 80
        self.batch_size = LOCAL_RERANK_BATCH
        self.max_length = LOCAL_RERANK_MAX_LENGTH
        self.return_tensors = None
        self.workers = LOCAL_RERANK_WORKERS

    @abstractmethod
    def inference(self, batch) -> List:
        pass
    def merge_inputs(self, chunk1_raw, chunk2):
        # Append the passage tokens (plus a trailing [SEP]) to an encoded query.
        chunk1 = deepcopy(chunk1_raw)
        chunk1['input_ids'].extend(chunk2['input_ids'])
        chunk1['input_ids'].append(self.sep_id)
        chunk1['attention_mask'].extend(chunk2['attention_mask'])
        chunk1['attention_mask'].append(chunk2['attention_mask'][0])
        if 'token_type_ids' in chunk1:
            # The passage segment (and its trailing [SEP]) belongs to segment 1.
            token_type_ids = [1 for _ in range(len(chunk2['token_type_ids']) + 1)]
            chunk1['token_type_ids'].extend(token_type_ids)
        return chunk1
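    # The merged sequence has the standard cross-encoder layout:
    #   [CLS] query tokens [SEP] passage tokens [SEP]
    # chunk1_raw already carries [CLS] ... [SEP] from encode_plus on the query;
    # the passage is encoded with add_special_tokens=False, so merge_inputs
    # only needs to append its tokens and one closing [SEP].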
    def tokenize_preproc(self,
                         query: str,
                         passages: List[str],
                         ):
        query_inputs = self._tokenizer.encode_plus(query, truncation=False, padding=False)
        max_passage_inputs_length = self.max_length - len(query_inputs['input_ids']) - 1
        assert max_passage_inputs_length > 10
        overlap_tokens = min(self.overlap_tokens, max_passage_inputs_length * 2 // 7)

        # Build [query, passage] pairs
        merge_inputs = []
        merge_inputs_idxs = []
        for pid, passage in enumerate(passages):
            passage_inputs = self._tokenizer.encode_plus(passage, truncation=False, padding=False,
                                                         add_special_tokens=False)
            passage_inputs_length = len(passage_inputs['input_ids'])
            if passage_inputs_length <= max_passage_inputs_length:
                if passage_inputs['attention_mask'] is None or len(passage_inputs['attention_mask']) == 0:
                    continue
                qp_merge_inputs = self.merge_inputs(query_inputs, passage_inputs)
                merge_inputs.append(qp_merge_inputs)
                merge_inputs_idxs.append(pid)
            else:
                # Passage too long: split it into overlapping windows and pair
                # each window with the query; all windows keep the passage id.
                start_id = 0
                while start_id < passage_inputs_length:
                    end_id = start_id + max_passage_inputs_length
                    sub_passage_inputs = {k: v[start_id:end_id] for k, v in passage_inputs.items()}
                    start_id = end_id - overlap_tokens if end_id < passage_inputs_length else end_id
                    qp_merge_inputs = self.merge_inputs(query_inputs, sub_passage_inputs)
                    merge_inputs.append(qp_merge_inputs)
                    merge_inputs_idxs.append(pid)
        return merge_inputs, merge_inputs_idxs
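    # Worked example of the budget arithmetic above: with max_length = 512 and
    # a query that encodes to 20 tokens (including [CLS] and [SEP]), each
    # passage window may use 512 - 20 - 1 = 491 tokens (the -1 reserves the
    # closing [SEP]). The overlap is then min(80, 491 * 2 // 7) = min(80, 140)
    # = 80 tokens, so a 1000-token passage yields windows starting at token
    # offsets 0, 411, and 822.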
    # @get_time
    def get_rerank(self, query: str, passages: List[str]):
        tot_batches, merge_inputs_idxs_sort = self.tokenize_preproc(query, passages)

        tot_scores = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.workers) as executor:
            futures = []
            for k in range(0, len(tot_batches), self.batch_size):
                batch = self._tokenizer.pad(
                    tot_batches[k:k + self.batch_size],
                    padding=True,
                    max_length=None,
                    pad_to_multiple_of=None,
                    return_tensors=self.return_tensors
                )
                future = executor.submit(self.inference, batch)
                futures.append(future)
            # print(f'rerank number: {len(futures)}')
            for future in futures:
                scores = future.result()
                tot_scores.extend(scores)

        merge_tot_scores = [0 for _ in range(len(passages))]
        for pid, score in zip(merge_inputs_idxs_sort, tot_scores):
            merge_tot_scores[pid] = max(merge_tot_scores[pid], score)
        # print("merge_tot_scores:", merge_tot_scores, flush=True)
        return merge_tot_scores
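# A long passage contributes several windows; its final score is the maximum
# over its windows (a passage counts as relevant if any window matches the
# query). Iterating futures in submission order keeps the flattened scores
# aligned with merge_inputs_idxs_sort.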
class RerankOnnxBackend(RerankBackend):
    def __init__(self, model_path: str, model_name: str, use_cpu: bool = False):
        super().__init__(model_path, model_name, use_cpu)
        self.return_tensors = "np"
        # Configure the ONNX Runtime session; prefer GPU execution unless use_cpu is set.
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        if use_cpu:
            providers = ['CPUExecutionProvider']
        else:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.session = onnxruntime.InferenceSession(self.model_file_path, sess_options, providers=providers)

    def inference(self, batch):
        # Prepare the input feed
        inputs = {self.session.get_inputs()[0].name: batch['input_ids'],
                  self.session.get_inputs()[1].name: batch['attention_mask']}
        if 'token_type_ids' in batch:
            inputs[self.session.get_inputs()[2].name] = batch['token_type_ids']

        # Run inference; the output is the logits
        result = self.session.run(None, inputs)  # None means fetch all outputs
        # debug_logger.info(f"rerank result: {result}")

        # Apply the sigmoid to turn logits into scores
        sigmoid_scores = 1 / (1 + np.exp(-np.array(result[0])))
        return sigmoid_scores.reshape(-1).tolist()
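# This backend expects a pre-exported ONNX file at <model_path>/<model_name>.onnx.
# A minimal export sketch, assuming the optimum package is installed (the
# export step is not part of this gist, and the written model.onnx may need
# renaming to match model_name):
#
#   from optimum.onnxruntime import ORTModelForSequenceClassification
#   ORTModelForSequenceClassification.from_pretrained(
#       "maidalun1020/bce-reranker-base_v1", export=True
#   ).save_pretrained("./netease-youdao/Rerank")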
class Reranker:
    def rerank(self, query, passages, top_k=32768):
        raise NotImplementedError


class CrossEncoderReranker(Reranker):
    def __init__(self, model_name):
        self.reranker_model = CrossEncoder(model_name, max_length=512, automodel_args={"torch_dtype": torch.float16})

    def rerank(self, query, passages, top_k=32768):
        score_inputs = [[query, passage] for passage in passages]
        scores = self.reranker_model.predict(score_inputs)
        # result = [{'question': passage, 'score': score} for passage, score in zip(passages, scores)]
        result = [{'idx': idx, 'question': passage, 'score': score}
                  for idx, (passage, score) in enumerate(zip(passages, scores))]
        sorted_result = sorted(result, key=lambda x: x['score'], reverse=True)
        return sorted_result[:top_k]
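# Design note: unlike the ONNX path, CrossEncoder.predict handles tokenization
# and batching internally and truncates each pair to max_length=512, so long
# passages are not scored with a sliding window on this backend.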
class OnnxReranker(Reranker):
    def __init__(self, remote_url_or_path, model_name):
        self.onnx_reranker = RerankOnnxBackend(remote_url_or_path, model_name)

    def rerank(self, query, passages, top_k=32768):
        scores = self.onnx_reranker.get_rerank(query, passages)
        indexed_results = [{'question': passages[idx], 'score': score, 'idx': idx}
                           for idx, score in enumerate(scores)]
        sorted_result = sorted(indexed_results, key=lambda x: x['score'], reverse=True)
        return sorted_result[:top_k]


def reranker_factory(reranker_type, remote_url_or_path='./netease-youdao/Rerank', model_name='rerank'):
    if reranker_type == 'cross_encoder':
        return CrossEncoderReranker(model_name=model_name)
    elif reranker_type == 'onnx':
        return OnnxReranker(remote_url_or_path=remote_url_or_path, model_name=model_name)
    else:
        raise ValueError(f"Unknown reranker type: {reranker_type}")
@get_time
def test_reranker_speed(qas, reranker, max_concurrent=1):
    with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
        batch_size = 32
        futures = []
        for i in range(0, len(qas), batch_size):
            passages = [qa['question'] for qa in qas[i:i + batch_size]]
            query = passages[0]
            future = executor.submit(reranker.rerank, query, passages)
            futures.append(future)
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as e:
                print(f"An error occurred: {e}")
if __name__ == '__main__':
    reranker = reranker_factory('onnx', remote_url_or_path='./netease-youdao/Rerank', model_name='rerank')

    # Basic sanity check
    query = "Where is Munich?"
    passages = ["The sky is blue.", "Munich is in Germany.", "Munich hosted the 1972 Olympic Games."]
    res = reranker.rerank(query, passages)
    print(res)

    # Speed test on 5000 anchor sentences from the all-nli pair split
    dataset = load_dataset("sentence-transformers/all-nli", 'pair')
    print(len(dataset['test']))
    qas = [{'question': sentence} for sentence in dataset['test'][:5000]['anchor']]
    test_reranker_speed(qas, reranker, max_concurrent=1)
    test_reranker_speed(qas, reranker, max_concurrent=2)
    test_reranker_speed(qas, reranker, max_concurrent=4)

    reranker = reranker_factory('cross_encoder', model_name='maidalun1020/bce-reranker-base_v1')
    res = reranker.rerank(query, passages)
    print(res)
    test_reranker_speed(qas, reranker, max_concurrent=1)
    test_reranker_speed(qas, reranker, max_concurrent=2)
    test_reranker_speed(qas, reranker, max_concurrent=4)