Skip to content

Instantly share code, notes, and snippets.

@Seven-Streams
Last active May 5, 2026 08:58
Show Gist options
  • Select an option

  • Save Seven-Streams/75e16fd33cd29f92024f86b167f1b98e to your computer and use it in GitHub Desktop.

Select an option

Save Seven-Streams/75e16fd33cd29f92024f86b167f1b98e to your computer and use it in GitHub Desktop.
Accuracy Evaluation of Qwen3.6 for XGrammar-2

Evaluate tool-calling accuracy and efficiency on SGLang with Structural Tag

The evaluation script is adapted from the BFCL AST checker and is based on https://github.com/Irfnfnkemed/eval_tool_call. It uses the Structural Tag API to test tool-calling accuracy and efficiency against an SGLang OpenAI-compatible server. Please put the BFCL JSON files in data/dataset.

Test the accuracy

You can run bash script.sh directly to test the accuracy, or run the following commands manually:

First, launch the server:

python -m sglang.launch_server --model-path Qwen/Qwen3.6-27B --host 127.0.0.1 --port 8000

Then generate the raw data, with or without the structural tag (add --use-stag to enable it):

cd ./tool_call_eval
python accuracy.py --model Qwen/Qwen3.6-27B \
--tokenizer Qwen/Qwen3.6-27B \
--dataset BFCL_v3_simple --dataset-path ./data/dataset --num-gpus 1 \
--num-requests 400 --num-warmup-requests 1 --request-rate inf \
--host 127.0.0.1 --port 8000 \
--api-endpoint sglang --output ./data/accuracy_raw \
--temperature 0.001 --top-p 0.9 \
[--use-stag]

The raw data will be written to the ./data/accuracy_raw directory. Finally, process the raw data:

python check.py --dataset ALL --model ALL --dataset-path ./data/dataset \
--output-root ./data/accuracy_raw --final-root ./data/accuracy_summary
"""Tool-calling accuracy benchmark entrypoint."""
import argparse
import functools
import json
import logging
import os
import random
import re
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from transformers import AutoTokenizer # pylint: disable=import-error
from api_endpoint import SUPPORTED_BACKENDS, create_api_endpoint
from dataset import SUPPORTED_DATASET, Dataset, create_dataset
from request_processor import (
MetricAnalyzer,
RequestProcessor,
create_pipelines,
)
from request_record import (
RequestRecord,
convert_reports_to_df,
generate_metrics_summary,
pretty_print_report,
)
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# SGLang's tool-call parser is optional: if the import fails, FunctionCallParser
# is None and Qwen3.6 XML parsing uses only the regex-based fallback below.
try:
    from sglang.srt.function_call.function_call_parser import FunctionCallParser
except ImportError:  # pragma: no cover - fallback for older sglang
    FunctionCallParser = None
def _parse_xml_parameter_value(raw_value: str) -> Any:
raw_value = raw_value.strip()
if not raw_value:
return ""
try:
return json.loads(raw_value)
except json.JSONDecodeError:
return raw_value
def _parse_abc_calls_with_qwen_xml(
    stringified_calls: str, tools: Optional[List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
    """Parse Qwen XML-style tool calls into ``{"function": {"name", "arguments"}}`` dicts.

    When SGLang's ``FunctionCallParser`` is importable and ``tools`` is given, its
    ``qwen_xml`` parser is tried first. If it raises, or yields no usable call,
    a regex parser handles the layout
    ``<tool_call><function=NAME><parameter=KEY>VALUE</parameter>...</function></tool_call>``.

    Args:
        stringified_calls: Raw model output text.
        tools: Tool definitions passed to SGLang's parser (may be ``None``).

    Returns:
        One dict per parsed call; an empty list when nothing matches.
    """
    if FunctionCallParser is not None and tools is not None:
        try:
            parser = FunctionCallParser(tools=tools, tool_call_parser="qwen_xml")
            _, calls = parser.parse_non_stream(stringified_calls)
            parsed = []
            for call in calls:
                name = getattr(call, "name", None)
                if name is None and isinstance(call, dict):
                    name = call.get("name")
                arguments = getattr(call, "parameters", None)
                if arguments is None and isinstance(call, dict):
                    arguments = call.get("parameters") or call.get("arguments")
                if isinstance(arguments, str):
                    # NOTE(review): SGLang's ToolCallItem reports ``parameters`` as
                    # a JSON string; decode it so the checker sees a dict. An
                    # undecodable string is kept as-is (previous behavior).
                    try:
                        arguments = json.loads(arguments)
                    except json.JSONDecodeError:
                        pass
                if name is None or arguments is None:
                    continue
                parsed.append({"function": {"name": name, "arguments": arguments}})
            if parsed:
                return parsed
        except Exception:  # pylint: disable=broad-exception-caught
            # Best effort: any parser failure falls through to the regex parser.
            logger.debug(
                "qwen_xml FunctionCallParser failed; using regex fallback",
                exc_info=True,
            )
    pattern = re.compile(
        r"<tool_call>\s*<function=([^>\n]+)>\s*(.*?)</function>\s*</tool_call>",
        re.DOTALL,
    )
    parameter_pattern = re.compile(
        r"<parameter=([^>\n]+)>\s*(.*?)\s*</parameter>", re.DOTALL
    )
    parsed = []
    for match in pattern.finditer(stringified_calls):
        func_name = match.group(1).strip()
        func_body = match.group(2)
        arguments = {}
        for parameter_match in parameter_pattern.finditer(func_body):
            key = parameter_match.group(1).strip()
            arguments[key] = _parse_xml_parameter_value(parameter_match.group(2))
        parsed.append({"function": {"name": func_name, "arguments": arguments}})
    return parsed
def _parse_num_concurrent_requests(num_str: Optional[str]) -> Optional[List[int]]:
if num_str is None:
return None
numbers = num_str.split(",")
if any(not number.isdigit() for number in numbers):
raise ValueError(f"Unrecognized num_concurrent_requests list: {numbers}")
return list(int(number) for number in numbers)
def _parse_request_rate(request_rate_str: Optional[str]) -> Optional[List[np.float32]]:
if request_rate_str is None:
return None
request_rates = request_rate_str.split(",")
results = []
for rate_str in request_rates:
request_rate = float(rate_str)
if request_rate <= 0:
raise ValueError(f"Invalid request rate {request_rate}")
results.append(np.float32(request_rate))
return results
def convert_calls_to_json(
    stringified_calls: str, model: str, tools: Optional[List[Dict[str, Any]]] = None
) -> List[Dict[str, Any]]:
    """Convert stringified tool calls to a list of ``{"function": {...}}`` dicts.

    The call format is selected by substring match on ``model``, in this order:

    * ``"Qwen3.6"`` - Qwen XML tool calls (``_parse_abc_calls_with_qwen_xml``).
    * ``"Llama-3"`` - bare ``{"name": ..., "parameters": {...}}`` JSON objects.
    * ``"Qwen2"`` - ``<tool_call>`` followed by a newline and
      ``{"name": ..., "arguments": {...}}``.

    Objects that fail to decode, or that lack the required keys, are skipped.
    Any other model name yields an empty list.
    """
    if "Qwen3.6" in model:
        return _parse_abc_calls_with_qwen_xml(stringified_calls, tools)
    if "Llama-3" in model:
        return _scan_json_calls(stringified_calls, '{"name":', 0, "parameters")
    if "Qwen2" in model:
        prefix = "<tool_call>\n"
        return _scan_json_calls(
            stringified_calls, prefix + '{"name":', len(prefix), "arguments"
        )
    return []


def _scan_json_calls(
    text: str, marker: str, json_offset: int, args_key: str
) -> List[Dict[str, Any]]:
    """Decode the JSON object at ``json_offset`` past each ``marker`` into a call dict."""
    decoder = json.JSONDecoder()
    calls = []
    start = 0
    while True:
        index = text.find(marker, start)
        if index == -1:
            return calls
        try:
            result, end_index = decoder.raw_decode(text, index + json_offset)
        except (ValueError, RecursionError):
            # Malformed JSON: retry the search just past this marker.
            start = index + 1
            continue
        start = end_index
        if isinstance(result, dict) and "name" in result and args_key in result:
            calls.append(
                {"function": {"name": result["name"], "arguments": result[args_key]}}
            )
def run_pipeline(
    pipeline: RequestProcessor,
    dataset: Dataset,
    tokenizer: AutoTokenizer,
    args: argparse.Namespace,
) -> Tuple[Dict[str, Any], List[RequestRecord]]:
    """Execute one pipeline over a freshly generated workload.

    ``random`` and ``numpy`` are seeded with ``args.seed`` before the workload is
    generated. Returns ``(report, records)``, where ``records`` is the list of
    pipeline outputs placed at the index given by each record's ``request_id``.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    workload = dataset.generate_request_records(
        args.input_len,
        args.output_len,
        args.input_len_std,
        args.output_len_std,
    )
    finished = pipeline(workload)
    expected_total = (
        args.num_requests * args.num_gpus
        if args.per_gpu_workload
        else args.num_requests
    )
    assert len(finished) == expected_total
    by_id: List[RequestRecord] = [None] * expected_total
    for record in finished:
        rid = record.request_id
        assert rid is not None
        # Each id must appear exactly once.
        assert by_id[rid] is None
        by_id[rid] = record
    analyzed = MetricAnalyzer(tokenizer)(finished)
    summary = generate_metrics_summary(analyzed, expected_total, args.num_gpus)
    return summary, by_id
def main(args: argparse.Namespace):
    """Run every configured pipeline and dump the parsed tool calls to JSON.

    Results go to
    ``{args.output}/{model basename}/{args.dataset}/{use_stag|no_stag}/result.json``.
    Each entry holds the request id, the raw model output, and the tool calls
    parsed from it by ``convert_calls_to_json``.

    Raises:
        ValueError: if ``args.num_requests`` is not positive.
    """
    if args.num_requests <= 0:
        raise ValueError("Number of requests to benchmark must be positive.")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    dataset = create_dataset(args, tokenizer)
    dataset.require_fake_warmup = True
    f_create_api_endpoint = functools.partial(create_api_endpoint, args)
    pipelines = create_pipelines(args, f_create_api_endpoint, dataset)

    model_part_name = args.model.split("/")[-1]
    stag_dir = "use_stag" if args.use_stag else "no_stag"
    output_dir = f"{args.output}/{model_part_name}/{args.dataset}/{stag_dir}"
    os.makedirs(output_dir, exist_ok=True)
    result_path = os.path.join(output_dir, "result.json")

    # Records accumulate across pipelines, and the file is rewritten after each
    # pipeline, so results already collected stay on disk if a later one fails.
    store_record = []
    for pipeline in pipelines:
        _, request_records = run_pipeline(pipeline, dataset, tokenizer, args)
        for request in request_records:
            store_record.append(
                {
                    "id": request.request_id,
                    "output": request.output_str,
                    "call": convert_calls_to_json(
                        request.output_str,
                        args.model,
                        dataset.gorilla_data[request.request_id]["tool"],
                    ),
                }
            )
        with open(result_path, "w", encoding="utf-8") as f:
            json.dump(store_record, f, indent=4)
if __name__ == "__main__":
parser = argparse.ArgumentParser("SGLang tool-calling accuracy benchmark")
parser.add_argument(
"--dataset",
type=str,
choices=SUPPORTED_DATASET,
help=f"The benchmark dataset kind. Supporting {SUPPORTED_DATASET}",
)
parser.add_argument(
"--dataset-path",
type=str,
help="The dataset file path.",
)
parser.add_argument(
"--api-endpoint",
type=str,
choices=SUPPORTED_BACKENDS,
default="sglang",
help="The API endpoint API for benchmarking.",
)
parser.add_argument(
"--model",
type=str,
required=True,
help="The name of the model.",
)
parser.add_argument(
"--tokenizer",
type=str,
required=True,
help="The path of the tokenizer directory.",
)
parser.add_argument(
"--num-gpus",
type=int,
required=True,
help="The number of GPUs used by the server. "
"We need this to better analyze the throughput per GPU.",
)
parser.add_argument(
"--num-requests",
type=int,
required=True,
help="The number of requests for benchmark.",
)
parser.add_argument(
"--num-warmup-requests",
type=int,
help="The number of requests for warmup. "
"It is optional when fixing the number of concurrent requests, and is required otherwise.",
)
parser.add_argument(
"--per-gpu-workload",
default=False,
action="store_true",
help='When set to True, the specified "num_concurrent_requests"/"request_rate" '
"denote the workload **per GPU**, which means that the real values of "
'"num_concurrent_requests"/"request_rate" used in benchmark'
'will be multiplied by "num_gpus".',
)
parser.add_argument(
"--num-concurrent-requests",
type=_parse_num_concurrent_requests,
help="The number(s) of concurrent requests to benchmark. "
'It can be either one integer or a list of integer separated by commas(","). '
"When specified, for each integer, the benchmark keeps these many consistent "
"number of concurrently running requests.",
)
parser.add_argument(
"--request-rate",
type=_parse_request_rate,
help="The request rate(s) denoting the number of new requests each second. "
'It can be either one float number (or "inf") or a list of numbers separated '
'by commas(","). '
"When specified, the benchmark sends these many new requests each second. "
'If it is "inf", all requests will be sent together at once.',
)
parser.add_argument(
"--replay-timestamp-scale",
type=float,
help="The timestamp scale when replaying the timestamps in a dataset. "
'The dataset replay mode is enabled when neither "--num-concurrent-requests" and '
'"--request-rate" is specified. '
"The scale is 1 by default in the replay mode.",
)
parser.add_argument(
"--input-len",
type=int,
help="The benchmark request average input length. Default to None, "
"which means the request input length depends on the dataset being used.",
)
parser.add_argument(
"--input-len-std",
type=float,
default=0,
help="The benchmark request input length standard deviation. Default to 0.",
)
parser.add_argument(
"--output-len",
type=int,
help="The benchmark request average output length. Default to None, "
"which means the request output length depends on the dataset being used.",
)
parser.add_argument(
"--output-len-std",
type=float,
default=0,
help="The benchmark request output length standard deviation. Default to 0.",
)
parser.add_argument(
"--stream",
action="store_true",
default=False,
help="Whether to benchmark stream responses. "
"When not enabled, metrics such as time-to-first-token (TTFT) will not be available. "
"Default to False.",
)
parser.add_argument(
"--include-server-metrics",
action="store_true",
help="Whether to include server-side request metrics when the endpoint provides them.",
)
parser.add_argument(
"--host",
type=str,
required=True,
help="The host address of the backend API.",
)
parser.add_argument(
"--port",
type=int,
required=True,
help="The port of the backend API.",
)
parser.add_argument(
"--timeout",
type=float,
default=3 * 60 * 60,
help="The timeout limit of each request.",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="The random number seed. Default to 0.",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="The temperature value for logit adjustment. Default to 1.",
)
parser.add_argument(
"--top-p",
type=float,
default=1.0,
help="The top-p value for sampling. Default to 1.",
)
parser.add_argument(
"--ignore-eos",
default=False,
action="store_true",
help='Whether to set the "ignore_eos" field.',
)
parser.add_argument(
"--apply-chat-template",
default=False,
action="store_true",
help="Whether to apply chat template to the request input text. "
'It is not supported when "--input-len" is specified.',
)
parser.add_argument(
"--num-process-workers",
type=int,
help="The number of parallel process workers to send the requests.",
)
parser.add_argument(
"--disable-tqdm",
action="store_true",
help="Whether to disable showing progress bar with tqdm during benchmarking.",
)
parser.add_argument(
"--max-schedule-gap",
type=float,
default=0.5,
help="The maximum allowed delay between the scheduled time in seconds.",
)
parser.add_argument(
"--cuda-profile",
default=False,
action="store_true",
help="Whether to enable CUDA profiling on the SGLang server debug endpoint.",
)
parser.add_argument(
"--debug-dump",
default=False,
action="store_true",
help="Whether to dump all request record raw data to file.",
)
parser.add_argument(
"--multi-round",
default=False,
action="store_true",
help="Whether to chat like multi round conversion with history log each request. "
"Only enabled when benchmarked with fixed concurrent request mode."
"The --num-concurrent-requests should be provided when enabling this option.",
)
parser.add_argument(
"--output",
"-o",
type=str,
default="sglang_accuracy",
help="The path of the output file where to dump the benchmark results.",
)
parser.add_argument(
"--use-stag",
action="store_true",
help="Whether to set structural tag.",
)
parser.add_argument(
"--use-jf",
action="store_true",
help="Whether to use jump-forward-decoding.",
)
main(parser.parse_args())
"""Benchmark backends."""
import argparse
import json
import os
import time
import traceback
from typing import Literal, Optional
from typing_extensions import Self
from request_record import Metrics, RequestRecord, ServerMetrics
import logging
logger = logging.getLogger(__name__)
class APIEndPoint:
    """Base class for benchmark backends that send requests to an API endpoint.

    Instances are used as async context managers and awaited with a
    ``RequestRecord``; concrete subclasses implement ``__call__`` to send the
    request and fill in the record's output and metrics.
    """

    def __init__(self, include_server_metrics: bool = False) -> None:
        # Whether subclasses should ask for and record server-side metrics.
        self.include_server_metrics = include_server_metrics

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type, exc_value, tb) -> None:
        # Nothing to release in the base class.
        return None

    async def __call__(self, request: RequestRecord) -> RequestRecord:
        raise NotImplementedError()
class OpenAIChatEndPoint(APIEndPoint):
    """The backend of sending HTTP requests in OpenAI API through "v1/chat/completions".

    The endpoint must be entered as an async context manager (``async with``),
    which opens the ``aiohttp`` session used by ``__call__``. A bearer token is
    sent when ``SGLANG_API_KEY`` (or, failing that, ``OPENAI_API_KEY``) is set.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        host: str,
        port: int,
        api_type: Literal["sglang"],
        timeout: Optional[float] = None,
        include_server_metrics: bool = False,
    ) -> None:
        super().__init__(include_server_metrics=include_server_metrics)
        import aiohttp  # pylint: disable=import-outside-toplevel,import-error

        self.timeout = timeout
        self.client: aiohttp.ClientSession = None
        self.url = f"http://{host}:{port}/v1/chat/completions"
        self.headers = {"Content-Type": "application/json"}
        api_key = os.getenv("SGLANG_API_KEY") or os.getenv("OPENAI_API_KEY")
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"
        self.api_type: str = api_type

    async def __aenter__(self) -> Self:
        import aiohttp  # pylint: disable=import-outside-toplevel,import-error

        self.client = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(self.timeout))
        return self

    async def __aexit__(self, exc_type, exc_value, tb) -> None:
        if self.client is not None:
            await self.client.close()

    def _build_payload(self, request_record: RequestRecord) -> dict:
        """Turn the record's chat-completion request into the JSON request body."""
        payload = request_record.chat_cmpl.model_dump(exclude_none=True)
        payload.setdefault("stream", False)
        if self.timeout is not None and "timeout" not in payload:
            payload["timeout"] = self.timeout
        if self.include_server_metrics:
            stream_options = payload.get("stream_options") or {}
            stream_options["include_usage"] = True
            payload["stream_options"] = stream_options
        debug_config = payload.pop("debug_config", None)
        if debug_config is not None and debug_config.get("ignore_eos"):
            payload["ignore_eos"] = True
        response_format = payload.get("response_format")
        if (
            response_format is not None
            and response_format.get("type") == "structural_tag"
            and "tags" in response_format
        ):
            # The request model carries "tags" with string schemas; the server
            # payload uses "structures" with decoded schema objects.
            response_format["structures"] = response_format.pop("tags")
            for tag in response_format["structures"]:
                if isinstance(tag.get("schema"), str):
                    tag["schema"] = json.loads(tag["schema"])
        return payload

    def _parse_server_metrics(self, usage: Optional[dict]) -> Optional[ServerMetrics]:
        """Build ``ServerMetrics`` from a response ``usage`` block, if possible.

        Returns ``None`` when server metrics are disabled, ``usage`` is missing,
        or ``usage`` has no ``"extra"`` section with the detailed timings.
        """
        if not self.include_server_metrics or not usage:
            return None
        extra = usage.get("extra")
        if not extra:
            # NOTE(review): not every OpenAI-compatible server emits "extra";
            # skip metrics instead of failing the whole request on a KeyError.
            return None
        return ServerMetrics(
            input_tokens=extra["prompt_tokens"],
            prefill_tokens=extra["prefill_tokens"],
            output_tokens=extra["completion_tokens"],
            end_to_end_latency_s=extra["end_to_end_latency_s"],
            prefill_tokens_per_s=extra["prefill_tokens_per_s"],
            inter_token_latency_s=extra["inter_token_latency_s"],
            time_per_output_token_s=1 / extra["decode_tokens_per_s"],
            time_to_first_token_s=extra["ttft_s"],
        )

    async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too-many-locals
        self, request_record: RequestRecord
    ) -> RequestRecord:
        """Send ``request_record`` and fill in its output text, error and metrics.

        Any exception raised while sending the request or reading the response
        marks the record unsuccessful and stores the traceback in ``error_msg``.
        An empty generated text is also reported as a failure.
        """
        payload = self._build_payload(request_record)

        generated_text = ""
        first_chunk_output_str = ""
        time_to_first_token_s = None
        server_metrics = None
        success = True
        error_msg = None
        start_time = time.monotonic()
        try:
            async with self.client.post(
                self.url, json=payload, headers=self.headers
            ) as response:
                if response.status != 200:
                    # Explicit check (not ``assert``) so it survives ``python -O``.
                    raise RuntimeError(
                        f"HTTP {response.status}: {await response.text()}"
                    )
                if payload["stream"]:
                    async for line in response.content:
                        line = line.strip()
                        # SSE payload lines start with "data:"; blank lines and
                        # other SSE fields/comments are skipped.
                        if not line.startswith(b"data:"):
                            continue
                        raw_data = line[len(b"data:") :].strip()
                        if raw_data == b"[DONE]":
                            continue
                        data = json.loads(raw_data)
                        # Usage can arrive on a chunk whose "choices" is empty, so
                        # read it before skipping choice-less chunks.
                        chunk_metrics = self._parse_server_metrics(data.get("usage"))
                        if chunk_metrics is not None:
                            server_metrics = chunk_metrics
                        if not data.get("choices"):
                            continue
                        content = data["choices"][0]["delta"].get("content")
                        if content is None:
                            continue
                        if not time_to_first_token_s:
                            time_to_first_token_s = time.monotonic() - start_time
                            first_chunk_output_str = content
                        generated_text += content
                else:
                    data = await response.json()
                    # The chat schema allows a null "content"; normalize to "".
                    generated_text = data["choices"][0]["message"].get("content") or ""
                    server_metrics = self._parse_server_metrics(data.get("usage"))
        except Exception:  # pylint: disable=broad-except
            success = False
            error_msg = (
                "API endpoint errored when sending request: " + traceback.format_exc()
            )
            logger.info(error_msg)

        finish_time = time.monotonic()
        if success and len(generated_text) == 0:
            success = False
            error_msg = "Empty generated text."
        request_record.output_str = generated_text
        request_record.first_chunk_output_str = first_chunk_output_str
        request_record.metrics = Metrics(
            success=success,
            start_time=start_time,
            finish_time=finish_time,
            end_to_end_latency_s=finish_time - start_time,
            input_tokens=request_record.metrics.input_tokens,
            time_to_first_token_s=time_to_first_token_s,
            server_metrics=server_metrics,
            exec_feature=request_record.metrics.exec_feature,
        )
        request_record.error_msg = error_msg
        return request_record
SUPPORTED_BACKENDS = ["sglang"]
def create_api_endpoint(args: argparse.Namespace) -> APIEndPoint:
"""Create an API endpoint instance with regard to the specified endpoint kind."""
if args.api_endpoint == "sglang":
return OpenAIChatEndPoint(
args.host,
args.port,
args.api_endpoint,
args.timeout,
args.include_server_metrics,
)
raise ValueError(f'Unrecognized endpoint "{args.api_endpoint}"')
"""Tool-calling accuracy result checker."""
import argparse
import json
import os
import re
from typing import Dict, Any, Tuple, List, Optional
# BFCL v3 dataset names accepted by the checker. "ALL" is also accepted as a
# choice (e.g. "--dataset ALL"); presumably it expands to every entry — confirm
# in the CLI handling.
SUPPORTED_DATASET = [
    "BFCL_v3_simple",
    "BFCL_v3_multiple",
    "BFCL_v3_parallel",
    "BFCL_v3_live_simple",
    "BFCL_v3_live_multiple",
    "BFCL_v3_live_parallel",
    "ALL",
]
# Model names accepted by the checker. These presumably match the
# <model basename> directory written by accuracy.py — TODO confirm. "ALL" is
# also accepted as a choice.
SUPPORTED_MODEL = [
    "Llama-3.2-1B-Instruct-q0f16-MLC",
    "Llama-3.2-3B-Instruct-q0f16-MLC",
    "Llama-3.1-8B-Instruct-q0f16-MLC",
    "Llama-3.1-70B-Instruct-q0f16-MLC",
    "Qwen2.5-72B-Instruct-q0f16-MLC",
    "Qwen3.6-27B",
    "Qwen3.6-35B-A3B",
    "ALL",
]
from enum import IntEnum
try:
from sglang.srt.function_call.function_call_parser import FunctionCallParser
except ImportError: # pragma: no cover - fallback for older sglang
FunctionCallParser = None
class Err_type(IntEnum):
    """Category of a checker failure.

    Ordering matters: check_list merges categories with ``min``, so a lower value
    takes precedence over a higher one, and ``NONE`` (the largest) means "no error".
    """

    CALL_NUMBER_ERROR = 0
    FUNC_SELECT_ERROR = 1
    FUNC_NAME_ERROR = 2
    PARA_KEY_ERROR = 3
    TYPE_ERROR = 4
    ENUM_ERROR = 5
    PARA_VALUE_ERROR = 6
    NONE = 7
class Error:
    """An error message paired with its ``Err_type`` category.

    ``Error()`` (empty message, ``Err_type.NONE``) is the "no error" value that
    the checkers return alongside ``True``.
    """

    def __init__(self, message: str = "", err_type: Err_type = Err_type.NONE):
        self.message = message
        self.error_type = err_type

    def __str__(self) -> str:
        # check_list formats errors with f"... {err}"; without __str__ that
        # rendered as "<Error object at 0x...>" instead of the message.
        return self.message

    def __repr__(self) -> str:
        return f"Error({self.message!r}, {self.error_type!r})"
def _parse_xml_parameter_value(raw_value: str) -> Any:
raw_value = raw_value.strip()
if not raw_value:
return ""
try:
return json.loads(raw_value)
except json.JSONDecodeError:
return raw_value
def _parse_Qwen_calls_with_qwen_xml(
    output: str, tools: Optional[List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
    """Parse Qwen XML-style tool calls into ``{"function": {"name", "arguments"}}`` dicts.

    When SGLang's ``FunctionCallParser`` is importable and ``tools`` is given, its
    ``qwen_xml`` parser is tried first. If it raises, or yields no usable call,
    a regex parser handles the layout
    ``<tool_call><function=NAME><parameter=KEY>VALUE</parameter>...</function></tool_call>``.
    """
    if FunctionCallParser is not None and tools is not None:
        try:
            parser = FunctionCallParser(tools=tools, tool_call_parser="qwen_xml")
            _, calls = parser.parse_non_stream(output)
            parsed = []
            for call in calls:
                name = getattr(call, "name", None)
                if name is None and isinstance(call, dict):
                    name = call.get("name")
                arguments = getattr(call, "parameters", None)
                if arguments is None and isinstance(call, dict):
                    arguments = call.get("parameters") or call.get("arguments")
                if isinstance(arguments, str):
                    # NOTE(review): SGLang's ToolCallItem reports ``parameters`` as
                    # a JSON string; decode it so check_simple sees a dict. An
                    # undecodable string is kept as-is (previous behavior).
                    try:
                        arguments = json.loads(arguments)
                    except json.JSONDecodeError:
                        pass
                if name is None or arguments is None:
                    continue
                parsed.append({"function": {"name": name, "arguments": arguments}})
            if parsed:
                return parsed
        except Exception:  # pylint: disable=broad-exception-caught
            # Best effort: any parser failure falls through to the regex parser.
            pass
    pattern = re.compile(
        r"<tool_call>\s*<function=([^>\n]+)>\s*(.*?)</function>\s*</tool_call>",
        re.DOTALL,
    )
    parameter_pattern = re.compile(
        r"<parameter=([^>\n]+)>\s*(.*?)\s*</parameter>", re.DOTALL
    )
    parsed = []
    for match in pattern.finditer(output):
        func_name = match.group(1).strip()
        func_body = match.group(2)
        arguments = {}
        for parameter_match in parameter_pattern.finditer(func_body):
            key = parameter_match.group(1).strip()
            arguments[key] = _parse_xml_parameter_value(parameter_match.group(2))
        parsed.append({"function": {"name": func_name, "arguments": arguments}})
    return parsed
def _parse_calls_for_model(
model: str, output: str, tools: Optional[List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
if "Qwen3.6" in model:
return _parse_Qwen_calls_with_qwen_xml(output, tools)
parsed_calls = []
start = 0
while True:
index = output.find('{"name":', start)
if index == -1:
break
try:
decoder = json.JSONDecoder()
result, end_index = decoder.raw_decode(output, index)
except json.JSONDecodeError:
start = index + 1
continue
start = end_index + 1
if "Llama-3" in model:
if "name" not in result or "parameters" not in result:
continue
parsed_calls.append(
{
"function": {
"name": result["name"],
"arguments": result["parameters"],
}
}
)
elif "Qwen2" in model:
if "name" not in result or "arguments" not in result:
continue
parsed_calls.append(
{
"function": {
"name": result["name"],
"arguments": result["arguments"],
}
}
)
return parsed_calls
def valid_data_point(tools: List[Dict], expected: List[Dict]) -> bool:
    """Return whether a BFCL sample is usable for evaluation.

    A sample is rejected when:
    * an expected call names a function not defined in ``tools``;
    * an ``array``-typed schema has an ``enum`` entry that is not a list; or
    * an ``integer``-typed schema has an ``enum`` entry that is not an ``int``.
    """
    known_names = {tool["function"]["name"] for tool in tools}
    if any(call["name"] not in known_names for call in expected):
        return False

    def _schemas_of_type(node, target_type):
        # Pre-order walk yielding dict values whose "type" equals target_type.
        # Only values reached through a dict key are candidates; list items are
        # descended into but never yielded themselves.
        if isinstance(node, dict):
            for child in node.values():
                if isinstance(child, dict) and child.get("type") == target_type:
                    yield child
                if isinstance(child, (dict, list)):
                    yield from _schemas_of_type(child, target_type)
        elif isinstance(node, list):
            for child in node:
                yield from _schemas_of_type(child, target_type)

    def _enum_entries_ok(target_type, entry_type):
        for tool in tools:
            for schema in _schemas_of_type(tool, target_type):
                if "enum" in schema and not all(
                    isinstance(entry, entry_type) for entry in schema["enum"]
                ):
                    return False
        return True

    return _enum_entries_ok("array", list) and _enum_entries_ok("integer", int)
# Modified by https://github.com/ShishirPatil/gorilla/blob/main/berkeley-function-call-leaderboard/bfcl/eval_checker/ast_eval/ast_checker.py
def check_simple(
    gorilla, tool_call: Dict[str, Any], tool: Dict[str, Any], ideal: Dict[str, Any]
) -> Tuple[bool, Error]:
    """Check one parsed call against its tool schema and the ground-truth call.

    Args:
        gorilla: Passed through to the value checkers; not otherwise used here.
        tool_call: ``{"function": {"name": str, "arguments": ...}}`` from the
            call parsers.
        tool: Tool definition ``{"function": {"name", "parameters"}}``.
        ideal: Ground truth ``{"name": str, "arguments": {arg: [allowed values]}}``.

    Returns:
        ``(True, Error())`` on success, otherwise ``(False, err)`` describing the
        first problem found.
    """
    if ideal["name"] != tool_call["function"]["name"]:
        return False, Error("wrong function name.", Err_type.FUNC_NAME_ERROR)
    func = tool["function"]
    arguments = tool_call["function"]["arguments"]
    if not isinstance(arguments, dict):
        return False, Error("wrong format", Err_type.PARA_KEY_ERROR)
    # JSON Schema allows "parameters", "required" and "properties" to be omitted;
    # treat them as empty rather than crashing the whole run with a KeyError.
    parameters = func.get("parameters") or {}
    for arg in parameters.get("required", []):
        if arg not in arguments:
            return False, Error(f"missing arg: {arg}", Err_type.PARA_KEY_ERROR)
    properties = parameters.get("properties", {})
    for arg, real_arg in arguments.items():
        ideal_arg: Optional[List] = ideal["arguments"].get(arg)
        if arg not in properties:
            return False, Error(f"unknown arg: {arg}", Err_type.PARA_KEY_ERROR)
        info_arg = properties[arg]
        arg_type = info_arg.get("type")
        if arg_type == "integer":
            acc, err = check_integer(gorilla, real_arg, ideal_arg)
        elif arg_type == "number":
            acc, err = check_number(gorilla, real_arg, ideal_arg)
        elif arg_type == "boolean":
            acc, err = check_boolean(gorilla, real_arg, ideal_arg)
        elif arg_type == "array":
            acc, err = check_list(gorilla, real_arg, ideal_arg, info_arg["items"])
        elif arg_type == "dict":
            acc, err = check_dict(gorilla, real_arg, ideal_arg, info_arg["properties"])
        else:
            # "string" (and untyped/unknown) slots are deliberately not checked:
            # XML tool calls decode parameter values with json.loads, so a string
            # slot may arrive as a non-str value.
            continue
        if not acc:
            return False, err
    return True, Error()
def check_simple_schema(
    gorilla, tool_call: Dict[str, Any], tool: Dict[str, Any]
) -> Tuple[bool, Error]:
    """Check one parsed call against its tool schema only (no ground truth).

    The function name must match ``tool``, every required argument must be
    present, unknown arguments are rejected, and integer/number/boolean/array/
    dict slots must hold values of the right type. String slots are not checked.

    Returns:
        ``(True, Error())`` on success, otherwise ``(False, err)`` describing the
        first problem found.
    """
    func = tool["function"]
    if func["name"] != tool_call["function"]["name"]:
        return False, Error("wrong function name.", Err_type.FUNC_NAME_ERROR)
    arguments = tool_call["function"]["arguments"]
    if not isinstance(arguments, dict):
        return False, Error("wrong format", Err_type.PARA_KEY_ERROR)
    # JSON Schema allows "parameters", "required" and "properties" to be omitted;
    # treat them as empty rather than crashing the whole run with a KeyError.
    parameters = func.get("parameters") or {}
    for arg in parameters.get("required", []):
        if arg not in arguments:
            return False, Error(f"missing arg: {arg}", Err_type.PARA_KEY_ERROR)
    properties = parameters.get("properties", {})
    for arg, real_arg in arguments.items():
        if arg not in properties:
            return False, Error(f"unknown arg: {arg}", Err_type.PARA_KEY_ERROR)
        info_arg = properties[arg]
        arg_type = info_arg.get("type")
        if arg_type == "integer":
            acc, err = check_integer(gorilla, real_arg, None)
        elif arg_type == "number":
            acc, err = check_number(gorilla, real_arg, None)
        elif arg_type == "boolean":
            acc, err = check_boolean(gorilla, real_arg, None)
        elif arg_type == "array":
            acc, err = check_list(gorilla, real_arg, None, info_arg["items"])
        elif arg_type == "dict":
            acc, err = check_dict(gorilla, real_arg, None, info_arg["properties"])
        else:
            # String and untyped/unknown slots are not validated.
            continue
        if not acc:
            return False, err
    return True, Error()
def check_integer(
    gorilla, real_arg: Any, ideal_arg: Optional[List[Any]]
) -> Tuple[bool, Error]:
    """Validate an integer argument.

    ``real_arg`` must be exactly of type ``int`` (so ``bool`` is rejected). When
    ``ideal_arg`` is given, ``real_arg`` must also equal one of its options.
    """
    if type(real_arg) is not int:
        return False, Error(f"wrong type {real_arg}: not int", Err_type.TYPE_ERROR)
    if ideal_arg is None or any(real_arg == option for option in ideal_arg):
        return True, Error()
    return False, Error(
        f"value not match: {real_arg}, ideal-opt: {ideal_arg}",
        Err_type.PARA_VALUE_ERROR,
    )
def check_number(
    gorilla, real_arg: Any, ideal_arg: Optional[List[Any]]
) -> Tuple[bool, Error]:
    """Validate a numeric argument.

    ``real_arg`` must be exactly of type ``float`` or ``int`` (``bool`` is
    rejected). When ``ideal_arg`` is given, ``real_arg`` must also equal one of
    its options.
    """
    if type(real_arg) not in (float, int):
        return False, Error(f"wrong type {real_arg}: not number", Err_type.TYPE_ERROR)
    if ideal_arg is not None and not any(real_arg == option for option in ideal_arg):
        return False, Error(
            f"value not match: {real_arg}, ideal-opt: {ideal_arg}",
            Err_type.PARA_VALUE_ERROR,
        )
    return True, Error()
def check_string(
    gorilla, real_arg: Any, ideal_arg: Optional[List[Any]], enum: Optional[List[str]]
) -> Tuple[bool, Error]:
    """Validate a string argument against ground-truth options or an enum.

    Before comparison, both sides are normalized: spaces and ``, . / - _ * ^`` are
    removed, the text is lower-cased, and ``'`` becomes ``"``. When ``ideal_arg``
    is None, ``enum`` (if given) supplies the candidates and a mismatch is
    reported as ``ENUM_ERROR``.
    """

    def standardize_string(string: Any) -> str:
        # Non-strings map to a sentinel so they never match a real candidate.
        if not isinstance(string, str):
            return "-----Error------"
        stripped = re.sub(r"[ \,\.\/\-\_\*\^]", "", string)
        return stripped.lower().replace("'", '"')

    if type(real_arg) is not str:
        return False, Error(f"wrong type {real_arg}: not string", Err_type.TYPE_ERROR)
    # The message is built from the raw (un-normalized) value.
    err = Error(
        f"value not match: {real_arg}, ideal-opt: {ideal_arg}",
        Err_type.PARA_VALUE_ERROR,
    )
    if ideal_arg is not None:
        candidates = ideal_arg
    elif enum is not None:
        err.error_type = Err_type.ENUM_ERROR
        candidates = enum
    else:
        return True, Error()
    normalized = standardize_string(real_arg)
    if any(normalized == standardize_string(option) for option in candidates):
        return True, Error()
    return False, err
def check_boolean(
    gorilla, real_arg: bool, ideal_arg: Optional[List[bool]]
) -> Tuple[bool, Error]:
    """Validate a boolean argument.

    ``real_arg`` must be exactly of type ``bool``. When ``ideal_arg`` is given,
    ``real_arg`` must also equal one of its options.
    """
    if type(real_arg) is not bool:
        return False, Error(f"wrong type {real_arg}: not bool", Err_type.TYPE_ERROR)
    matched = ideal_arg is None or any(real_arg == option for option in ideal_arg)
    if matched:
        return True, Error()
    return False, Error(
        f"value not match: {real_arg}, ideal-opt: {ideal_arg}",
        Err_type.PARA_VALUE_ERROR,
    )
def check_list(
    gorilla, real_arg: List, ideal_arg: Optional[List[List]], item: Dict[str, Any]
) -> Tuple[bool, Error]:
    """Validate a list-valued argument element by element.

    Args:
        gorilla: Forwarded unchanged to the per-element checkers.
        real_arg: The list produced by the model; must be exactly a ``list``.
        ideal_arg: ``None`` to check element types only; otherwise a list of
            acceptable candidate lists, tried in order.
        item: Schema of the elements.  ``item["type"]`` selects the element
            checker; ``item["items"]`` / ``item["properties"]`` are read for
            nested ``"array"`` / ``"dict"`` elements.

    Returns:
        ``(True, Error())`` when the types check out (``ideal_arg is None``)
        or some candidate matches.  Otherwise ``(False, Error)``; the message
        concatenates every candidate's failure, and the type is the
        numerically smallest ``Err_type`` seen.
    """
    if type(real_arg) != list:
        return False, Error(f"wrong type of {real_arg}: not list.", Err_type.TYPE_ERROR)
    item_type = item["type"]
    if ideal_arg is None:
        # Type-only validation.  Element types not listed below (including
        # "string") are not checked at all.
        if item_type == "integer":
            for i, integer in enumerate(real_arg):
                acc, err = check_integer(gorilla, integer, None)
                if not acc:
                    return False, err
        elif item_type == "number":
            for i, integer in enumerate(real_arg):
                acc, err = check_number(gorilla, integer, None)
                if not acc:
                    return False, err
        elif item_type == "boolean":
            for i, boolean in enumerate(real_arg):
                acc, err = check_boolean(gorilla, boolean, None)
                if not acc:
                    return False, err
        elif item_type == "string":
            pass
        elif item_type == "array":
            for i, array in enumerate(real_arg):
                acc, err = check_list(gorilla, array, None, item["items"])
                if not acc:
                    return False, err
        elif item_type == "dict":
            # NOTE(review): the dataset module's GORILLA_TO_OPENAPI maps "dict"
            # to "object"; if tool schemas were normalized with it, this branch
            # never fires. Confirm which spelling the cached data uses.
            for i, dictionary in enumerate(real_arg):
                acc, err = check_dict(gorilla, dictionary, None, item["properties"])
                if not acc:
                    return False, err
        return True, Error()
    else:
        # Accumulated failure text across candidates, and the smallest
        # Err_type seen (NONE, the largest value, means nothing failed yet).
        final_err = ""
        err_type = Err_type.NONE
        for j, ideal in enumerate(ideal_arg):
            # A length mismatch rules this candidate out immediately.
            if len(ideal) != len(real_arg):
                final_err += f"[ideal {j}] wrong length of {real_arg}."
                err_type = min(err_type, Err_type.PARA_VALUE_ERROR)
                continue
            match = True
            # Each element is compared against the element at the same
            # position in the candidate (wrapped as a one-item option list).
            # NOTE(review): `{err}` below formats the Error object itself, not
            # err.message; unless Error defines __str__, the default object
            # repr ends up in final_err.
            if item_type == "integer":
                for i, integer in enumerate(real_arg):
                    acc, err = check_integer(gorilla, integer, [ideal[i]])
                    if not acc:
                        match = False
                        final_err += f"[ideal {j}] {err}"
                        err_type = min(err_type, err.error_type)
                        break
            elif item_type == "number":
                for i, integer in enumerate(real_arg):
                    acc, err = check_number(gorilla, integer, [ideal[i]])
                    if not acc:
                        match = False
                        final_err += f"[ideal {j}] {err}"
                        err_type = min(err_type, err.error_type)
                        break
            elif item_type == "boolean":
                for i, boolean in enumerate(real_arg):
                    acc, err = check_boolean(gorilla, boolean, [ideal[i]])
                    if not acc:
                        match = False
                        final_err += f"[ideal {j}] {err}"
                        err_type = min(err_type, err.error_type)
                        break
            elif item_type == "string":
                # NOTE(review): string elements are never compared, so any
                # same-length list of strings matches this candidate. Confirm
                # this leniency is intended.
                pass
            elif item_type == "array":
                for i, array in enumerate(real_arg):
                    acc, err = check_list(gorilla, array, [ideal[i]], item["items"])
                    if not acc:
                        match = False
                        final_err += f"[ideal {j}] {err}"
                        err_type = min(err_type, err.error_type)
                        break
            elif item_type == "dict":
                for i, dictionary in enumerate(real_arg):
                    acc, err = check_dict(
                        gorilla, dictionary, [ideal[i]], item["properties"]
                    )
                    if not acc:
                        match = False
                        final_err += f"[ideal {j}] {err}"
                        err_type = min(err_type, err.error_type)
                        break
            # The first fully matching candidate wins.
            if match:
                return True, Error()
        # No candidate matched.  If ideal_arg was empty, err_type is still NONE
        # and this returns True. NOTE(review): confirm an empty candidate list
        # should count as a pass.
        return err_type == Err_type.NONE, Error(final_err, err_type)
def check_dict(
    gorilla,
    real_arg: Dict[str, Any],
    ideal_arg: Optional[Dict[str, Any]],
    properties: Dict[str, Any],
) -> Tuple[bool, Error]:
    """Validate a dict-valued argument key by key against ``properties``.

    Every key listed in ``properties`` must be present in ``real_arg``; the
    schema's ``required`` list is not consulted.  Keys in ``real_arg`` that
    are not in ``properties`` are ignored.

    Args:
        gorilla: Forwarded unchanged to the per-value checkers.
        real_arg: The dict produced by the model; must be exactly a ``dict``.
        ideal_arg: ``None`` to check value types only.  Otherwise an iterable
            of candidate dicts.  Each candidate is indexed by key, and the
            value is wrapped as a one-item option list for the scalar
            checkers.  (The ``Optional[Dict]`` annotation does not match this
            iteration; callers in SOURCE pass a list.)
        properties: Mapping of key -> schema.  ``["type"]`` selects the
            checker; ``["items"]`` / ``["properties"]`` are read for nested
            ``"array"`` / ``"dict"`` values.

    Returns:
        ``(True, Error())`` on success, else ``(False, Error)``.  The message
        concatenates each candidate's failure, and the type is the
        numerically smallest ``Err_type`` seen.
    """
    if type(real_arg) != dict:
        return False, Error(f"wrong type of {real_arg}: not dict.", Err_type.TYPE_ERROR)
    if ideal_arg is None:
        # Type-only validation; "string" values and unknown types are not checked.
        for key in properties.keys():
            if key not in real_arg:
                return False, Error(f"missing key: {key}.", Err_type.PARA_KEY_ERROR)
            item_type = properties[key]["type"]
            if item_type == "integer":
                acc, err = check_integer(gorilla, real_arg[key], None)
                if not acc:
                    return False, err
            elif item_type == "number":
                acc, err = check_number(gorilla, real_arg[key], None)
                if not acc:
                    return False, err
            elif item_type == "boolean":
                acc, err = check_boolean(gorilla, real_arg[key], None)
                if not acc:
                    return False, err
            elif item_type == "string":
                pass
            elif item_type == "array":
                acc, err = check_list(
                    gorilla, real_arg[key], None, properties[key]["items"]
                )
                if not acc:
                    return False, err
            elif item_type == "dict":
                acc, err = check_dict(
                    gorilla, real_arg[key], None, properties[key]["properties"]
                )
                if not acc:
                    return False, err
        return True, Error()
    else:
        # Accumulated failure text and smallest Err_type across candidates.
        final_err = ""
        err_type = Err_type.NONE
        for i, ideal in enumerate(ideal_arg):
            match = True
            # NOTE(review): ideal[key] raises KeyError if a candidate omits a
            # key that appears in `properties` (e.g. an optional field).
            for key in properties.keys():
                if key not in real_arg:
                    match = False
                    final_err += f"[ideal {i}] missing key: {key}."
                    err_type = min(err_type, Err_type.PARA_KEY_ERROR)
                    break
                item_type = properties[key]["type"]
                if item_type == "integer":
                    acc, err = check_integer(gorilla, real_arg[key], [ideal[key]])
                    if not acc:
                        match = False
                        final_err += f"[ideal {i}] {err}"
                        err_type = min(err_type, err.error_type)
                        break
                elif item_type == "number":
                    acc, err = check_number(gorilla, real_arg[key], [ideal[key]])
                    if not acc:
                        match = False
                        final_err += f"[ideal {i}] {err}"
                        err_type = min(err_type, err.error_type)
                        break
                elif item_type == "boolean":
                    acc, err = check_boolean(gorilla, real_arg[key], [ideal[key]])
                    if not acc:
                        match = False
                        final_err += f"[ideal {i}] {err}"
                        err_type = min(err_type, err.error_type)
                        break
                elif item_type == "string":
                    # NOTE(review): string values are never compared against
                    # the candidate. Confirm this leniency is intended.
                    pass
                elif item_type == "array":
                    acc, err = check_list(
                        gorilla, real_arg[key], [ideal[key]], properties[key]["items"]
                    )
                    if not acc:
                        match = False
                        final_err += f"[ideal {i}] {err}"
                        err_type = min(err_type, err.error_type)
                        break
                elif item_type == "dict":
                    acc, err = check_dict(
                        gorilla,
                        real_arg[key],
                        [ideal[key]],
                        properties[key]["properties"],
                    )
                    if not acc:
                        match = False
                        final_err += f"[ideal {i}] {err}"
                        err_type = min(err_type, err.error_type)
                        break
            # The first fully matching candidate wins.
            if match:
                return True, Error()
        # No candidate matched (an empty ideal_arg leaves err_type at NONE and
        # therefore returns True).
        return err_type == Err_type.NONE, Error(final_err, err_type)
def _record_verdict(
    result: Dict[str, Any], acc: bool, err: "Error", err_types: List[int]
) -> None:
    """Write one entry's verdict into ``result`` and tally failures by type."""
    if acc:
        result["success"] = True
        result["err_type"] = None
        result["err_msg"] = None
        return
    result["success"] = False
    result["err_type"] = Err_type(err.error_type).name
    result["err_msg"] = err.message
    err_types[err.error_type] += 1


def _judge_simple_entry(
    gorilla: Dict, calls: List, info: Dict
) -> Tuple[bool, "Error"]:
    """Judge a single-tool entry: exactly one call, checked against tool 0."""
    if len(calls) != 1:
        return False, Error("wrong calling numbers.", Err_type.CALL_NUMBER_ERROR)
    return check_simple(gorilla, calls[0], info["tool"][0], info["ideal_call"][0])


def _judge_multiple_entry(
    gorilla: Dict, calls: List, info: Dict
) -> Tuple[bool, "Error"]:
    """Judge a multi-tool entry: exactly one call, and it must pick the right tool."""
    if len(calls) != 1:
        return False, Error("wrong calling numbers.", Err_type.CALL_NUMBER_ERROR)
    called_name = calls[0]["function"]["name"]
    expected_name = info["ideal_call"][0]["name"]
    expected_tool = None
    match_tool = None
    for tool in info["tool"]:
        tool_name = tool["function"]["name"]
        if tool_name == expected_name:
            expected_tool = tool
        if called_name == tool_name:
            match_tool = tool
    # The model called a real tool, but not the expected one.  A call to an
    # unknown name (match_tool is None) is left to check_simple to report.
    if match_tool is not None and expected_tool != match_tool:
        # NOTE(review): the Err_type enum defined in the dataset module has no
        # FUNC_SELECT_ERROR member; confirm the Err_type in scope here does.
        return False, Error("wrong function selection.", Err_type.FUNC_SELECT_ERROR)
    return check_simple(gorilla, calls[0], expected_tool, info["ideal_call"][0])


def _judge_parallel_entry(
    gorilla: Dict, calls: List, info: Dict
) -> Tuple[bool, "Error"]:
    """Judge a parallel entry: every ideal call must match a distinct real call.

    On failure, the reported error is the one with the largest Err_type value
    among the attempts for the first unmatched ideal call.
    """
    if len(calls) != len(info["ideal_call"]):
        return False, Error("wrong calling numbers.", Err_type.CALL_NUMBER_ERROR)
    matched = set()  # indices of real calls already consumed by an ideal call
    acc, err = True, Error()
    for ideal in info["ideal_call"]:
        expected_tool = next(
            (tool for tool in info["tool"] if tool["function"]["name"] == ideal["name"]),
            None,
        )
        err = Error("", err_type=Err_type.CALL_NUMBER_ERROR)
        acc = False
        for index, single_call in enumerate(calls):
            if index in matched:
                continue
            tmp_acc, tmp_err = check_simple(gorilla, single_call, expected_tool, ideal)
            if err.error_type < tmp_err.error_type:
                err = tmp_err
            if tmp_acc:
                acc = True
                matched.add(index)
                break
        if not acc:
            break
    return acc, err


def check_acc(
    model: str,
    dataset: str,
    gorilla: Dict,
    summary: Dict,
    totol_summary: Dict,
    use_stag: bool,
):
    """Check the accuracy of the generated requests.

    Every entry of ``summary`` with ``valid_datapoint`` set is judged.  The
    verdict (``success`` / ``err_type`` / ``err_msg``) is written into
    ``summary[id]["use_stag" | "no_stag"]``.  Per-``Err_type`` failure rates
    and ``CORRECT_CALL`` are stored in
    ``totol_summary[model][dataset]["use_stag" | "no_stag"]``.  When there are
    no valid entries the rates are undefined and are stored as ``None``
    instead of dividing by zero.

    Raises:
        ValueError: ``dataset`` is not one of the supported BFCL v3 categories.
    """
    if dataset in ("BFCL_v3_simple", "BFCL_v3_live_simple"):
        judge = _judge_simple_entry
    elif dataset in ("BFCL_v3_multiple", "BFCL_v3_live_multiple"):
        judge = _judge_multiple_entry
    elif dataset in ("BFCL_v3_parallel", "BFCL_v3_live_parallel"):
        judge = _judge_parallel_entry
    else:
        raise ValueError(f"check_acc does not support dataset {dataset}")

    if model not in totol_summary:
        totol_summary[model] = {}
    if dataset not in totol_summary[model]:
        totol_summary[model][dataset] = {"use_stag": {}, "no_stag": {}}

    stag_cate = "use_stag" if use_stag else "no_stag"
    # One counter per error category; the last enum member (NONE) is excluded.
    num_categories = len(Err_type) - 1
    err_types = [0] * num_categories
    valid_data_point = 0
    for item in summary:
        if not item["valid_datapoint"]:
            continue
        valid_data_point += 1
        entry_id = item["id"]
        info = gorilla[entry_id]
        calls = item[stag_cate]["call"]
        if len(calls) == 0:
            acc, err = False, Error("missing calling.", Err_type.CALL_NUMBER_ERROR)
        else:
            acc, err = judge(gorilla, calls, info)
        _record_verdict(summary[entry_id][stag_cate], acc, err, err_types)

    rates = totol_summary[model][dataset][stag_cate]
    if valid_data_point == 0:
        for i in range(num_categories):
            rates[Err_type(i).name] = None
        rates["CORRECT_CALL"] = None
        return
    total_acc = 1
    for i in range(num_categories):
        rate = err_types[i] / valid_data_point
        rates[Err_type(i).name] = rate
        total_acc -= rate
    rates["CORRECT_CALL"] = total_acc
def get_correct_schema_rate(
    model: str,
    dataset: str,
    gorilla: Dict,
    summary: Dict,
    totol_summary: Dict,
    use_stag: bool,
) -> Optional[float]:
    """Get the correct schema rate of the generated requests.

    Each valid entry's raw output is parsed into calls with
    ``_parse_calls_for_model``.  A call counts as schema-correct when
    ``check_simple_schema`` accepts it against at least one of the entry's
    tools, or rejects it only with ``PARA_VALUE_ERROR`` (the shape is right,
    but a value is not).

    The rate is stored at
    ``totol_summary[model][dataset]["correct_schema_rate"][stag]`` and
    returned.  It is ``None`` when no calls were parsed at all, since the
    rate is undefined there (previously a ZeroDivisionError).
    ``totol_summary[model][dataset]`` must already exist.
    """
    stag_cate = "use_stag" if use_stag else "no_stag"
    call_number = 0
    correct_schema_number = 0
    for entry in summary:
        if not entry["valid_datapoint"]:
            continue
        output = entry[stag_cate]["output"]
        parsed_calls = _parse_calls_for_model(model, output, entry["tools"])
        for call in parsed_calls:
            call_number += 1
            for tool in entry["tools"]:
                acc, err = check_simple_schema(gorilla, call, tool)
                if acc or err.error_type == Err_type.PARA_VALUE_ERROR:
                    correct_schema_number += 1
                    break
    rate = correct_schema_number / call_number if call_number else None
    dataset_summary = totol_summary[model][dataset]
    dataset_summary.setdefault("correct_schema_rate", {})[stag_cate] = rate
    return rate
def main(args: argparse.Namespace):
    """Main benchmark entrance.

    For every requested (model, dataset) pair that has both
    ``{output_root}/{model}/{dataset}/use_stag/result.json`` and
    ``.../no_stag/result.json``, this function:

    * joins the raw results with the dataset file;
    * scores them with ``check_acc`` and ``get_correct_schema_rate``;
    * writes ``{final_root}/{model}/{dataset}/summary.json``.

    It then writes the aggregate ``{final_root}/summary.json``.  Pairs with
    missing results are skipped silently.
    """
    # "ALL" expands to every supported entry except the last one (the "ALL"
    # sentinel itself, as assumed by the original pop(-1)).  Slicing keeps the
    # module-level lists intact instead of mutating them.
    datasets = SUPPORTED_DATASET[:-1] if args.dataset == "ALL" else [args.dataset]
    models = SUPPORTED_MODEL[:-1] if args.model == "ALL" else [args.model]

    def stag_view(record: Dict[str, Any]) -> Dict[str, Any]:
        # Raw output plus parsed calls; results without "call" had none.
        return {"output": record["output"], "call": record.get("call", [])}

    total_summary = {}
    for model in models:
        for dataset in datasets:
            result_dir = f"{args.output_root}/{model}/{dataset}"
            use_stag_path = f"{result_dir}/use_stag/result.json"
            no_stag_path = f"{result_dir}/no_stag/result.json"
            if not os.path.exists(use_stag_path) or not os.path.exists(no_stag_path):
                continue
            with open(use_stag_path, mode="r", encoding="utf-8") as file:
                use_stag_result = json.load(file)
            with open(no_stag_path, mode="r", encoding="utf-8") as file:
                no_stag_result = json.load(file)
            with open(
                f"{args.dataset_path}/{dataset}.json", mode="r", encoding="utf-8"
            ) as file:
                gorilla = json.load(file)
            print(f"Begin checking {model} on {dataset}...")
            summary = []
            for i in range(len(use_stag_result)):
                # Results and dataset entries are expected to be aligned by id.
                assert i == use_stag_result[i]["id"]
                assert i == no_stag_result[i]["id"]
                assert i == gorilla[i]["id"]
                if not valid_data_point(gorilla[i]["tool"], gorilla[i]["ideal_call"]):
                    summary.append(
                        {
                            "id": i,
                            "valid_datapoint": False,
                            "no_stag": None,
                            "use_stag": None,
                            "input": None,
                            "tools": None,
                            "expected": None,
                        }
                    )
                    continue
                summary.append(
                    {
                        "id": i,
                        "valid_datapoint": True,
                        "no_stag": stag_view(no_stag_result[i]),
                        "use_stag": stag_view(use_stag_result[i]),
                        "input": gorilla[i]["question"],
                        "tools": gorilla[i]["tool"],
                        "expected": gorilla[i]["ideal_call"],
                    }
                )
            check_acc(model, dataset, gorilla, summary, total_summary, False)
            check_acc(model, dataset, gorilla, summary, total_summary, True)
            get_correct_schema_rate(
                model, dataset, gorilla, summary, total_summary, False
            )
            get_correct_schema_rate(
                model, dataset, gorilla, summary, total_summary, True
            )
            pair_dir = f"{args.final_root}/{model}/{dataset}"
            os.makedirs(pair_dir, exist_ok=True)
            with open(f"{pair_dir}/summary.json", "w", encoding="utf-8") as f:
                json.dump(summary, f, indent=4)
    os.makedirs(args.final_root, exist_ok=True)
    with open(f"{args.final_root}/summary.json", "w", encoding="utf-8") as f:
        json.dump(total_summary, f, indent=4)
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Tool-calling accuracy result checker")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=SUPPORTED_DATASET,
        help=f"The benchmark dataset kind. Supporting {SUPPORTED_DATASET}",
    )
    parser.add_argument(
        "--model",
        type=str,
        help=f'The benchmark model kind, or "ALL" (supported defaults: {SUPPORTED_MODEL}).',
    )
    # The three filesystem roots are all mandatory string options.
    for flag, description in (
        ("--dataset-path", "The dataset file path."),
        ("--output-root", "The root of the raw output file."),
        ("--final-root", "The root of the summary file."),
    ):
        parser.add_argument(flag, type=str, required=True, help=description)
    args = parser.parse_args()
    main(args)
"""Benchmark dataset classes."""
import argparse
import json
import os
import requests
import random
from datetime import datetime
import re
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd # pylint: disable=import-error
from datasets import load_dataset # pylint: disable=import-error
from transformers import AutoTokenizer # pylint: disable=import-error
from openai_protocol import (
ChatCompletionMessage,
ChatCompletionRequest,
DebugConfig,
)
from request_record import GroupedRequestRecord, Metrics, RequestRecord
from xgrammar import get_model_structural_tag
class Dataset:  # pylint: disable=too-few-public-methods
    """Base class for benchmark datasets.

    Subclasses implement :meth:`generate_request_records`.  The class-level
    flags below describe properties of a dataset that the benchmark driver
    can consult.
    """

    # Truncation limit of 100k.
    truncate_length = 100_000
    # Set for datasets whose requests share a common prefix: they need fake
    # warmup requests so the common prefix is not prefilled into the engine.
    require_fake_warmup: bool = False
    # Set when the dataset already carries timestamps, so the benchmark can
    # replay requests according to them.
    timestamp_available: bool = False

    def generate_request_records(
        self,
        input_len: Optional[int],
        output_len: Optional[int],
        input_len_std: float = 0.0,
        output_len_std: float = 0.0,
    ) -> List[RequestRecord]:
        """Return the raw, unprocessed request records (subclasses must override)."""
        raise NotImplementedError()
# Maps type names found in Gorilla/BFCL function specs (including Java-style
# names such as ArrayList or HashMap) to JSON-Schema / OpenAPI type names.
GORILLA_TO_OPENAPI = dict(
    integer="integer",
    number="number",
    float="number",
    string="string",
    boolean="boolean",
    bool="boolean",
    array="array",
    list="array",
    dict="object",
    object="object",
    tuple="array",
    any="string",
    byte="integer",
    short="integer",
    long="integer",
    double="number",
    char="string",
    ArrayList="array",
    Array="array",
    HashMap="object",
    Hashtable="object",
    Queue="array",
    Stack="array",
    Any="string",
    String="string",
    Bigint="integer",
)
from enum import IntEnum
# Error categories reported by the checkers.  Values are ordered integers;
# NONE (the largest) denotes "no error" and is the default on Error.
Err_type = IntEnum(
    "Err_type",
    [
        ("FORMAT_ERROR", 0),
        ("CALL_NUMBER_ERROR", 1),
        ("FUNC_NAME_ERROR", 2),
        ("PARA_KEY_ERROR", 3),
        ("TYPE_ERROR", 4),
        ("ENUM_ERROR", 5),
        ("PARA_VALUE_ERROR", 6),
        ("NONE", 7),
    ],
)
class Error:
    """Outcome of a single check: a message plus an :class:`Err_type`.

    Attributes:
        message: Human-readable description of the failure ("" by default).
        error_type: Category of the failure; the default ``Err_type.NONE``
            means no error.
    """

    def __init__(self, message: str = "", err_type: Err_type = Err_type.NONE):
        self.message = message
        self.error_type = err_type

    def __str__(self) -> str:
        # Error objects get interpolated directly into f-strings when failure
        # text from several candidates is concatenated (the accuracy checker
        # does f"[ideal {j}] {err}").  Without this, the default
        # "<... Error object at 0x...>" repr would be embedded instead.
        return self.message

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.message!r}, {self.error_type!r})"
class GorillaDataset(Dataset):  # pylint: disable=too-few-public-methods
    """The dataset class for Gorilla dataset.

    Reference: https://github.com/ShishirPatil/gorilla

    Each entry of ``gorilla_data`` is a dict with:

    * ``id``;
    * ``question`` (a list of ``{"role", "content"}`` messages);
    * ``tool`` (a list of ``{"type": "function", "function": ...}``);
    * ``ideal_call`` (a list of ``{"name", "arguments"}``);
    * ``source``.

    ``num_tokens`` is added when a tokenizer is supplied.
    """

    # Raw-file location of the BFCL data, used when no local copy exists.
    _BFCL_BASE_URL = "https://raw.githubusercontent.com/ShishirPatil/gorilla/main/berkeley-function-call-leaderboard/data"
    # Seconds to wait for each download before giving up.
    _DOWNLOAD_TIMEOUT_S = 60

    def __init__(
        self,
        dataset: str,
        dataset_path: str,
        tokenizer: AutoTokenizer,
        use_stag: bool,
        api_endpoint: str,
        model: str,
    ) -> None:
        """Load ``{dataset_path}/{dataset}.json``.

        When the local copy is missing, it is built from the upstream BFCL
        files on GitHub and cached.

        Args:
            dataset: BFCL category name, e.g. ``"BFCL_v3_simple"``.
            dataset_path: Directory holding the converted ``{dataset}.json``.
            tokenizer: Used to count question tokens; may be ``None``.
            use_stag: Whether requests carry a structural-tag response format.
            api_endpoint: Stored on the instance; not otherwise used here.
            model: Model name; substrings of it select the prompt/tag format.
        """
        self.tokenizer = tokenizer
        self.use_stag = use_stag
        self.api_endpoint = api_endpoint
        self.model = model
        self.gorilla_data = []
        dataset_file = f"{dataset_path}/{dataset}.json"
        if os.path.exists(dataset_file):
            with open(dataset_file, mode="r", encoding="utf-8") as file:
                self.gorilla_data = json.load(file)
        elif not self._download_bfcl(dataset, dataset_file):
            # Download or parsing failed; leave the dataset empty.
            return
        if self.tokenizer is not None:
            for item in self.gorilla_data:
                item["num_tokens"] = sum(
                    len(
                        self.tokenizer.encode(
                            message["content"], add_special_tokens=False
                        )
                    )
                    for message in item["question"]
                )

    def _fetch_jsonl(self, url: str, label: str, dataset: str) -> List[Dict[str, Any]]:
        """Download a JSON-lines file; lines that fail to parse are reported and skipped.

        Raises:
            requests.RequestException: on HTTP errors, network errors or timeout.
        """
        response = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT_S)
        response.raise_for_status()
        records = []
        for line in response.text.strip().split("\n"):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"Error parsing {label} line in {dataset}.json: {e}")
        return records

    def _download_bfcl(self, dataset: str, dataset_file: str) -> bool:
        """Fetch and convert the BFCL function and answer files.

        The converted entries are stored in ``self.gorilla_data`` and cached
        at ``dataset_file``.  Returns False (leaving ``self.gorilla_data``
        empty) when either file could not be downloaded or yielded no records.
        """
        print(f"Downloading {dataset}.json from GitHub...")
        try:
            functions_data = self._fetch_jsonl(
                f"{self._BFCL_BASE_URL}/{dataset}.json", "function", dataset
            )
            answers_data = self._fetch_jsonl(
                f"{self._BFCL_BASE_URL}/possible_answer/{dataset}.json",
                "answer",
                dataset,
            )
            print(
                f"Successfully downloaded {dataset}.json: {len(functions_data)} functions, {len(answers_data)} answers"
            )
        except requests.RequestException as e:
            print(f"Error downloading {dataset}.json: {e}")
            functions_data = []
            answers_data = []
        if not functions_data or not answers_data:
            print(f"Skipping {dataset}.json - failed to download data")
            return False

        print(f"Processing {dataset}.json...")
        answers_by_id = {item["id"]: item for item in answers_data}
        next_id = 0  # dense, 0-based ids over the entries actually kept
        for item in functions_data:
            item_id = item["id"]
            # Only the first element of the upstream "question" field is used.
            question = item["question"][0]
            if item_id not in answers_by_id:
                print(f"Warning: No answer found for item {item_id}")
                continue
            if "function" not in item or not item["function"]:
                print(f"Warning: No function definition for item {item_id}")
                continue
            tool = [
                {"type": "function", "function": func} for func in item["function"]
            ]
            self.map_type_values(tool)
            answer = answers_by_id[item_id]
            if "ground_truth" not in answer or not answer["ground_truth"]:
                print(f"Warning: No ground truth for item {item_id}")
                continue
            ideal_call = []
            for ground_truth in answer["ground_truth"]:
                # Each ground-truth entry is {function_name: params}.
                function_name = list(ground_truth.keys())[0]
                ideal_call.append(
                    {"name": function_name, "arguments": ground_truth[function_name]}
                )
            self.gorilla_data.append(
                {
                    "id": next_id,
                    "question": question,
                    "tool": tool,
                    "ideal_call": ideal_call,
                    "source": f"{dataset}.json",
                }
            )
            next_id += 1

        cache_dir = os.path.dirname(dataset_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(dataset_file, mode="w", encoding="utf-8") as file:
            json.dump(self.gorilla_data, file, ensure_ascii=False, indent=4)
        return True

    def map_type_values(self, data: Any) -> None:
        """Rewrite Gorilla type names to JSON-Schema names, in place.

        Walks nested dicts and lists.  Every string stored under a ``"type"``
        key is replaced by its ``GORILLA_TO_OPENAPI`` mapping; names not in the
        mapping (e.g. the tool-level ``"function"``) are left untouched.  A
        ``"type"`` key holding a dict (such as a parameter named "type") is
        recursed into rather than rewritten.
        """
        if isinstance(data, dict):
            for key, value in data.items():
                if key == "type" and isinstance(value, str):
                    data[key] = GORILLA_TO_OPENAPI.get(value, value)
                elif isinstance(value, (dict, list)):
                    self.map_type_values(value)
        elif isinstance(data, list):
            for element in data:
                self.map_type_values(element)

    def gen_warmup_dataset(self):
        """Generate a warmup dataset for the benchmark.

        Appends a shallow copy of every current entry, doubling the dataset.
        """
        length = len(self.gorilla_data)
        for i in range(length):
            self.gorilla_data.append(self.gorilla_data[i].copy())

    def generate_request_records(
        self,
        input_len: Optional[int],
        output_len: Optional[int],
        input_len_std: float = 0.0,
        output_len_std: float = 0.0,
    ) -> List[RequestRecord]:
        """Build one chat-completion RequestRecord per dataset entry.

        ``output_len`` becomes ``max_tokens`` (1024 when ``None``).
        ``input_len`` and the ``*_std`` arguments are accepted for interface
        compatibility and are not used.
        """
        output_length = output_len if output_len is not None else 1024
        request_records = []
        for entry in self.gorilla_data:
            response_format = self._response_format(entry)
            messages = self._build_messages(entry)
            request_records.append(
                RequestRecord(
                    request_id=entry["id"],
                    chat_cmpl=ChatCompletionRequest(
                        messages=messages,
                        response_format=response_format,
                        model="",
                        max_tokens=output_length,
                        debug_config=DebugConfig(grammar_execution_mode="constraint"),
                    ),
                    metrics=Metrics(
                        success=False,
                        start_time=0,
                        finish_time=0,
                        end_to_end_latency_s=0,
                        # num_tokens exists only when a tokenizer was supplied.
                        input_tokens=entry.get("num_tokens", 0),
                    ),
                )
            )
        return request_records

    @staticmethod
    def _parameters_schema(tool: Dict[str, Any]) -> str:
        """Serialize a tool's parameter schema (properties/required/type) to JSON.

        A missing ``required`` list is treated as empty.
        """
        parameters = tool["function"]["parameters"]
        return json.dumps(
            {
                "properties": parameters["properties"],
                "required": parameters.get("required", []),
                "type": parameters["type"],
            }
        )

    def _response_format(self, entry: Dict[str, Any]) -> Dict[str, Any]:
        """Return the ``response_format`` for one entry.

        This is plain text unless ``use_stag`` is set; otherwise it is a
        structural tag matching the model's tool-call syntax.
        """
        if not self.use_stag:
            return {
                "type": "text",
            }
        if "Llama-3" in self.model:
            # Llama-3 emits bare JSON: {"name": ..., "parameters": {...}}.
            return {
                "type": "structural_tag",
                "tags": [
                    {
                        "begin": '{{"name": "{func_name}", "parameters":'.format(
                            func_name=tool["function"]["name"]
                        ),
                        "schema": self._parameters_schema(tool),
                        "end": "}",
                    }
                    for tool in entry["tool"]
                ],
                "triggers": ['{"name":'],
            }
        if "Qwen3.6" in self.model:
            # Qwen3.6 uses the XML <tool_call><function=...> format; the
            # structural tag is produced by xgrammar's "qwen_3_5" template.
            return get_model_structural_tag(
                model="qwen_3_5",
                tools=entry["tool"],
                tool_choice="auto",
                reasoning=True,
            ).model_dump()
        # Default: Qwen2-style JSON wrapped in <tool_call> ... </tool_call>.
        return {
            "type": "structural_tag",
            "tags": [
                {
                    "begin": '<tool_call>\n{{"name": "{func_name}", "arguments":'.format(
                        func_name=tool["function"]["name"]
                    ),
                    "schema": self._parameters_schema(tool),
                    "end": "}\n</tool_call>",
                }
                for tool in entry["tool"]
            ],
            "triggers": ["<tool_call>"],
        }

    def _build_messages(self, entry: Dict[str, Any]) -> List[ChatCompletionMessage]:
        """Render an entry's tools and question into the chat prompt for ``self.model``."""
        if "Llama-3.1" in self.model:
            messages = [
                ChatCompletionMessage(
                    content="",
                    role="system",
                ),
                ChatCompletionMessage(
                    content=(
                        "Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\n"
                        'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.\n\n'
                    ),
                    role="user",
                ),
            ]
            for tool in entry["tool"]:
                messages[1].content += f"{json.dumps(tool)}\n\n"
            for message in entry["question"]:
                if message["role"] == "system":
                    messages[0].content += message["content"]
                else:
                    messages[1].content += message["content"]
            return messages

        if "Qwen3.6" in self.model:
            tools_str = "".join(
                f"{json.dumps(tool, indent=4)}\n" for tool in entry["tool"]
            )
            messages = [
                ChatCompletionMessage(
                    content=(
                        "# Tools\n\n"
                        "You have access to the following functions:\n\n"
                        f"<tools>\n{tools_str}</tools>\n\n"
                        "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
                        "<tool_call>\n"
                        "<function=example_function_name>\n"
                        "<parameter=example_parameter_1>\n"
                        "value_1\n"
                        "</parameter>\n"
                        "<parameter=example_parameter_2>\n"
                        "This is the value for the second parameter\n"
                        "that can span\n"
                        "multiple lines\n"
                        "</parameter>\n"
                        "</function>\n"
                        "</tool_call>\n\n"
                        "<IMPORTANT>\n"
                        "Reminder:\n"
                        "- Function calls MUST follow the specified format: "
                        "an inner <function=...></function> block must be nested within "
                        "<tool_call></tool_call> XML tags\n"
                        "- Required parameters MUST be specified\n"
                        "- You may provide optional reasoning for your function call in natural language "
                        "BEFORE the function call, but NOT after\n"
                        "- If there is no function call available, answer the question like normal with "
                        "your current knowledge and do not tell the user about function calls\n"
                        "</IMPORTANT>"
                    ),
                    role="system",
                ),
                ChatCompletionMessage(content="", role="user"),
            ]
            for message in entry["question"]:
                if message["role"] == "system":
                    messages[0].content += f"\n\n{message['content']}"
                else:
                    messages[1].content += message["content"]
            return messages

        if "Qwen2" in self.model:
            tools_str = "".join(
                f"{json.dumps(tool, indent=4)}\n" for tool in entry["tool"]
            )
            messages = [
                ChatCompletionMessage(
                    content=(
                        "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
                        "# Tools\n\n"
                        "You may call one or more functions to assist with the user query.\n\n"
                        "You are provided with function signatures within <tools></tools> XML tags:\n"
                        f"<tools>\n{tools_str}</tools>\n\n"
                        "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
                        '<tool_call>\n{"name": <function-name>, "arguments": <args-json-object>}\n</tool_call>'
                    ),
                    role="system",
                ),
                ChatCompletionMessage(content="", role="user"),
            ]
            for message in entry["question"]:
                if message["role"] == "system":
                    messages[0].content += message["content"]
                else:
                    messages[1].content += message["content"]
            return messages

        # Generic fallback prompt.  NOTE(review): the pieces below are
        # concatenated without separators, the example ends in "}}}}", and
        # "If a you" is a typo.  They are left unchanged here because editing
        # the prompt would change benchmark results.
        messages = [
            ChatCompletionMessage(
                content=(
                    "Tool Instructions:"
                    "You have access to the following tool functions:"
                    f"{entry['tool']}"
                    "If a you choose to call a function, you should ONLY reply in the following format:"
                    '`{"name": func_name, "parameters": parameters(JSON dict)}`'
                    "Here is an example,"
                    '`{"name": "get_time", "parameters": {"location": "Pittsburgh"}}}}`'
                    "Reminder:"
                    "- Function calls MUST follow the specified format"
                    "- Required parameters MUST be specified"
                    "- You should not repeat or miss the call"
                    "- You should response with at least one function calling"
                ),
                role="system",
            )
        ]
        for message in entry["question"]:
            if message["role"] == "system":
                messages[0].content += message["content"]
            else:
                messages.append(
                    ChatCompletionMessage(
                        content=message["content"], role=message["role"]
                    )
                )
        return messages
# BFCL v3 categories this module can load: the simple / multiple / parallel
# splits, first in their original form and then in their "live" form.
SUPPORTED_DATASET = [
    f"BFCL_v3_{prefix}{split}"
    for prefix in ("", "live_")
    for split in ("simple", "multiple", "parallel")
]
def create_dataset(  # pylint: disable=too-many-return-statements,too-many-branches
    args: argparse.Namespace, tokenizer: AutoTokenizer
) -> Dataset:
    """Create a dataset instance with regard to the specified dataset kind and file path.

    Raises:
        TypeError: ``args.dataset_path`` is set but is not a string.
        ValueError: the dataset is not supported, ``args.dataset_path`` is
            missing, or ``args.apply_chat_template`` is not ``False``
            (GorillaDataset builds its own prompts).
    """
    if args.dataset_path is not None and not isinstance(args.dataset_path, str):
        raise TypeError(
            f"Invalid dataset path {args.dataset_path}. Please use a string."
        )
    if args.dataset not in SUPPORTED_DATASET:
        raise ValueError(f"Unrecognized dataset {args.dataset}")
    if args.dataset_path is None:
        raise ValueError(
            "Gorilla dataset requires dataset path. "
            'Please specify it with "--dataset-path".'
        )
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if args.apply_chat_template is not False:
        raise ValueError("Gorilla dataset does not support applying chat template")
    return GorillaDataset(
        args.dataset,
        args.dataset_path,
        tokenizer,
        args.use_stag,
        args.api_endpoint,
        args.model,
    )
"""Benchmark dataset classes."""
import argparse
import json
import os
import requests
import random
from datetime import datetime
import re
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd # pylint: disable=import-error
from datasets import load_dataset # pylint: disable=import-error
from transformers import AutoTokenizer # pylint: disable=import-error
from openai_protocol import (
ChatCompletionMessage,
ChatCompletionRequest,
DebugConfig,
)
from request_record import GroupedRequestRecord, Metrics, RequestRecord
from xgrammar import get_model_structural_tag
class Dataset:  # pylint: disable=too-few-public-methods
    """Base class for benchmark datasets.

    Subclasses implement :meth:`generate_request_records`.  The class-level
    flags below describe properties of a dataset that the benchmark driver
    can consult.
    """

    # Truncation limit of 100k.
    truncate_length = 100_000
    # Set for datasets whose requests share a common prefix: they need fake
    # warmup requests so the common prefix is not prefilled into the engine.
    require_fake_warmup: bool = False
    # Set when the dataset already carries timestamps, so the benchmark can
    # replay requests according to them.
    timestamp_available: bool = False

    def generate_request_records(
        self,
        input_len: Optional[int],
        output_len: Optional[int],
        input_len_std: float = 0.0,
        output_len_std: float = 0.0,
    ) -> List[RequestRecord]:
        """Return the raw, unprocessed request records (subclasses must override)."""
        raise NotImplementedError()
# Maps type names found in Gorilla/BFCL function specs (including Java-style
# names such as ArrayList or HashMap) to JSON-Schema / OpenAPI type names.
GORILLA_TO_OPENAPI = dict(
    integer="integer",
    number="number",
    float="number",
    string="string",
    boolean="boolean",
    bool="boolean",
    array="array",
    list="array",
    dict="object",
    object="object",
    tuple="array",
    any="string",
    byte="integer",
    short="integer",
    long="integer",
    double="number",
    char="string",
    ArrayList="array",
    Array="array",
    HashMap="object",
    Hashtable="object",
    Queue="array",
    Stack="array",
    Any="string",
    String="string",
    Bigint="integer",
)
from enum import IntEnum
# Error categories reported by the checkers.  Values are ordered integers;
# NONE (the largest) denotes "no error" and is the default on Error.
Err_type = IntEnum(
    "Err_type",
    [
        ("FORMAT_ERROR", 0),
        ("CALL_NUMBER_ERROR", 1),
        ("FUNC_NAME_ERROR", 2),
        ("PARA_KEY_ERROR", 3),
        ("TYPE_ERROR", 4),
        ("ENUM_ERROR", 5),
        ("PARA_VALUE_ERROR", 6),
        ("NONE", 7),
    ],
)
class Error:
    """Outcome of a single check: a message plus an :class:`Err_type`.

    Attributes:
        message: Human-readable description of the failure ("" by default).
        error_type: Category of the failure; the default ``Err_type.NONE``
            means no error.
    """

    def __init__(self, message: str = "", err_type: Err_type = Err_type.NONE):
        self.message = message
        self.error_type = err_type

    def __str__(self) -> str:
        # Error objects get interpolated directly into f-strings when failure
        # text from several candidates is concatenated (the accuracy checker
        # does f"[ideal {j}] {err}").  Without this, the default
        # "<... Error object at 0x...>" repr would be embedded instead.
        return self.message

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.message!r}, {self.error_type!r})"
class GorillaDataset(Dataset): # pylint: disable=too-few-public-methods
"""The dataset class for Gorilla dataset.
Reference: https://github.com/ShishirPatil/gorilla
"""
def __init__(
self,
dataset: str,
dataset_path: str,
tokenizer: AutoTokenizer,
use_stag: bool,
api_endpoint: str,
model: str,
) -> None:
self.tokenizer = tokenizer
self.use_stag = use_stag
self.api_endpoint = api_endpoint
self.model = model
self.gorilla_data = []
base_url = "https://raw.githubusercontent.com/ShishirPatil/gorilla/main/berkeley-function-call-leaderboard/data"
id = 0
dataset_file = f"{dataset_path}/{dataset}.json"
if os.path.exists(dataset_file):
with open(dataset_file, mode="r", encoding="utf-8") as file:
self.gorilla_data = json.load(file)
else:
function_url = f"{base_url}/{dataset}.json"
answer_url = f"{base_url}/possible_answer/{dataset}.json"
print(f"Downloading {dataset}.json from GitHub...")
functions_data = []
answers_data = []
try:
function_response = requests.get(function_url)
function_response.raise_for_status()
function_text = function_response.text
for line in function_text.strip().split("\n"):
if line.strip():
try:
functions_data.append(json.loads(line))
except json.JSONDecodeError as e:
print(f"Error parsing function line in {dataset}.json: {e}")
answer_response = requests.get(answer_url)
answer_response.raise_for_status()
answer_text = answer_response.text
for line in answer_text.strip().split("\n"):
if line.strip():
try:
answers_data.append(json.loads(line))
except json.JSONDecodeError as e:
print(f"Error parsing answer line in {dataset}.json: {e}")
print(
f"Successfully downloaded {dataset}.json: {len(functions_data)} functions, {len(answers_data)} answers"
)
except requests.RequestException as e:
print(f"Error downloading {dataset}.json: {e}")
functions_data = []
answers_data = []
if not functions_data or not answers_data:
print(f"Skipping {dataset}.json - failed to download data")
return
print(f"Processing {dataset}.json...")
answers_by_id = {item["id"]: item for item in answers_data}
for item in functions_data:
item_id = item["id"]
question = item["question"][0]
if item_id not in answers_by_id:
print(f"Warning: No answer found for item {item_id}")
continue
if "function" not in item or not item["function"]:
print(f"Warning: No function definition for item {item_id}")
continue
tool = [
{"type": "function", "function": func} for func in item["function"]
]
self.map_type_values(tool)
answer = answers_by_id[item_id]
if "ground_truth" not in answer or not answer["ground_truth"]:
print(f"Warning: No ground truth for item {item_id}")
continue
ideal_call = []
for ground_truth in answer["ground_truth"]:
function_name = list(ground_truth.keys())[0]
params = ground_truth[function_name]
ideal_call.append({"name": function_name, "arguments": params})
self.gorilla_data.append(
{
"id": id,
"question": question,
"tool": tool,
"ideal_call": ideal_call,
"source": f"{dataset}.json",
}
)
id += 1
with open(dataset_file, mode="w", encoding="utf-8") as file:
json.dump(self.gorilla_data, file, ensure_ascii=False, indent=4)
if self.tokenizer is not None:
for item in self.gorilla_data:
num_tokens = 0
for message in item["question"]:
num_tokens += len(
tokenizer.encode(message["content"], add_special_tokens=False)
)
item["num_tokens"] = num_tokens
def gen_warmup_dataset(self):
    """Double ``self.gorilla_data`` in place with shallow copies of its entries.

    The copies are appended after the originals in the same order, so the
    second half of the list mirrors the first half.
    """
    # Build the copies first so the list is not iterated while it grows.
    duplicates = [entry.copy() for entry in self.gorilla_data]
    self.gorilla_data.extend(duplicates)
def generate_request_records(
    self,
    input_len: Optional[int],
    output_len: Optional[int],
    input_len_std: float = 0.0,
    output_len_std: float = 0.0,
) -> List[RequestRecord]:
    """Build one chat-completion ``RequestRecord`` per entry of ``self.gorilla_data``.

    The prompt layout and, when ``self.use_stag`` is set, the structural-tag
    ``response_format`` are chosen by substring-matching ``self.model``
    (``"Llama-3"`` / ``"Llama-3.1"``, ``"Qwen3.6"``, ``"Qwen2"``, otherwise a
    generic JSON tool-call prompt). Without ``use_stag`` the response format is
    plain text.

    Args:
        input_len: Not used by this dataset; kept for interface compatibility.
        output_len: ``max_tokens`` for every request; 1024 when ``None``.
        input_len_std: Not used by this dataset.
        output_len_std: Not used by this dataset.

    Returns:
        The request records, in the order of ``self.gorilla_data``. Each
        record's ``metrics.input_tokens`` is ``entry["num_tokens"]``, which is
        only filled when the dataset was built with a tokenizer.
    """
    # The model family is a property of the run, not of an entry: decide it once.
    is_llama3_model = "Llama-3" in self.model
    is_llama31_model = "Llama-3.1" in self.model
    is_qwen36_model = "Qwen3.6" in self.model
    is_qwen2_model = "Qwen2" in self.model
    output_length = output_len if output_len is not None else 1024

    def _json_args_structural_tag(tools, begin_template, end, trigger):
        """Structural tag constraining each tool's JSON arguments to its schema.

        ``begin_template`` is formatted with ``func_name``; the argument object
        generated after it must match the tool's parameter schema, then ``end``.
        """
        tags = []
        for tool in tools:
            begin = begin_template.format(func_name=tool["function"]["name"])
            parameters = tool["function"]["parameters"]
            tags.append(
                {
                    "begin": begin,
                    "schema": json.dumps(
                        {
                            "properties": parameters["properties"],
                            "required": parameters["required"],
                            "type": parameters["type"],
                        }
                    ),
                    "end": end,
                }
            )
        return {"type": "structural_tag", "tags": tags, "triggers": [trigger]}

    def _response_format(entry):
        """Response format for one entry (plain text unless ``use_stag``)."""
        if not self.use_stag:
            return {"type": "text"}
        if is_llama3_model:
            return _json_args_structural_tag(
                entry["tool"],
                '{{"name": "{func_name}", "parameters":',
                "}",
                '{"name":',
            )
        if is_qwen36_model:
            # The Qwen3.6 prompt asks for the XML <tool_call><function=...>
            # format; get_model_structural_tag builds the matching tag (the
            # "qwen_3_5" preset is used for Qwen3.6).
            return get_model_structural_tag(
                model="qwen_3_5",
                tools=entry["tool"],
                tool_choice="auto",
                reasoning=True,
            ).model_dump()
        return _json_args_structural_tag(
            entry["tool"],
            '<tool_call>\n{{"name": "{func_name}", "arguments":',
            "}\n</tool_call>",
            "<tool_call>",
        )

    def _messages(entry):
        """Chat messages embedding the tool definitions and the question."""
        if is_llama31_model:
            messages = [
                ChatCompletionMessage(
                    content="",
                    role="system",
                ),
                ChatCompletionMessage(
                    content=(
                        "Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\n"
                        'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.\n\n'
                    ),
                    role="user",
                ),
            ]
            for tool in entry["tool"]:
                messages[1].content += f"{json.dumps(tool)}\n\n"
            for message in entry["question"]:
                if message["role"] == "system":
                    messages[0].content += message["content"]
                else:
                    messages[1].content += message["content"]
            return messages
        if is_qwen36_model:
            tools_str = "".join(
                f"{json.dumps(tool, indent=4)}\n" for tool in entry["tool"]
            )
            messages = [
                ChatCompletionMessage(
                    content=(
                        "# Tools\n\n"
                        "You have access to the following functions:\n\n"
                        f"<tools>\n{tools_str}</tools>\n\n"
                        "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
                        "<tool_call>\n"
                        "<function=example_function_name>\n"
                        "<parameter=example_parameter_1>\n"
                        "value_1\n"
                        "</parameter>\n"
                        "<parameter=example_parameter_2>\n"
                        "This is the value for the second parameter\n"
                        "that can span\n"
                        "multiple lines\n"
                        "</parameter>\n"
                        "</function>\n"
                        "</tool_call>\n\n"
                        "<IMPORTANT>\n"
                        "Reminder:\n"
                        "- Function calls MUST follow the specified format: "
                        "an inner <function=...></function> block must be nested within "
                        "<tool_call></tool_call> XML tags\n"
                        "- Required parameters MUST be specified\n"
                        "- You may provide optional reasoning for your function call in natural language "
                        "BEFORE the function call, but NOT after\n"
                        "- If there is no function call available, answer the question like normal with "
                        "your current knowledge and do not tell the user about function calls\n"
                        "</IMPORTANT>"
                    ),
                    role="system",
                ),
                ChatCompletionMessage(content="", role="user"),
            ]
            for message in entry["question"]:
                if message["role"] == "system":
                    messages[0].content += f"\n\n{message['content']}"
                else:
                    messages[1].content += message["content"]
            return messages
        if is_qwen2_model:
            tools_str = "".join(
                f"{json.dumps(tool, indent=4)}\n" for tool in entry["tool"]
            )
            messages = [
                ChatCompletionMessage(
                    content=(
                        "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
                        "# Tools\n\n"
                        "You may call one or more functions to assist with the user query.\n\n"
                        "You are provided with function signatures within <tools></tools> XML tags:\n"
                        f"<tools>\n{tools_str}</tools>\n\n"
                        "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
                        '<tool_call>\n{"name": <function-name>, "arguments": <args-json-object>}\n</tool_call>'
                    ),
                    role="system",
                ),
                ChatCompletionMessage(content="", role="user"),
            ]
            for message in entry["question"]:
                if message["role"] == "system":
                    messages[0].content += message["content"]
                else:
                    messages[1].content += message["content"]
            return messages
        # Generic JSON tool-call prompt. The example below is a plain (non-f)
        # string, so its braces must not be doubled.
        messages = [
            ChatCompletionMessage(
                content=(
                    "Tool Instructions:"
                    "You have access to the following tool functions:"
                    f"{entry['tool']}"
                    "If a you choose to call a function, you should ONLY reply in the following format:"
                    '`{"name": func_name, "parameters": parameters(JSON dict)}`'
                    "Here is an example,"
                    '`{"name": "get_time", "parameters": {"location": "Pittsburgh"}}`'
                    "Reminder:"
                    "- Function calls MUST follow the specified format"
                    "- Required parameters MUST be specified"
                    "- You should not repeat or miss the call"
                    "- You should response with at least one function calling"
                ),
                role="system",
            )
        ]
        for message in entry["question"]:
            if message["role"] == "system":
                messages[0].content += message["content"]
            else:
                messages.append(
                    ChatCompletionMessage(
                        content=message["content"], role=message["role"]
                    )
                )
        return messages

    request_records = []
    for entry in self.gorilla_data:
        response_format = _response_format(entry)
        messages = _messages(entry)
        request_records.append(
            RequestRecord(
                request_id=entry["id"],
                chat_cmpl=ChatCompletionRequest(
                    messages=messages,
                    response_format=response_format,
                    model="",
                    max_tokens=output_length,
                    debug_config=DebugConfig(grammar_execution_mode="constraint"),
                ),
                metrics=Metrics(
                    success=False,
                    start_time=0,
                    finish_time=0,
                    end_to_end_latency_s=0,
                    input_tokens=entry["num_tokens"],
                ),
            )
        )
    return request_records
# BFCL v3 dataset names accepted by create_dataset: the simple, multiple and
# parallel categories, first in their plain form and then their "live_" form.
SUPPORTED_DATASET = [
    f"BFCL_v3_{variant}{category}"
    for variant in ("", "live_")
    for category in ("simple", "multiple", "parallel")
]
def create_dataset(
    args: argparse.Namespace, tokenizer: AutoTokenizer
) -> Dataset:
    """Create the dataset instance selected by ``args.dataset``.

    Only the BFCL datasets listed in ``SUPPORTED_DATASET`` are recognized; they
    are loaded by ``GorillaDataset`` from ``args.dataset_path``.

    Raises:
        TypeError: ``args.dataset_path`` is given but is not a string.
        ValueError: the dataset is not recognized, no dataset path is given,
            or ``args.apply_chat_template`` is not ``False``.
    """
    if args.dataset_path is not None and not isinstance(args.dataset_path, str):
        raise TypeError(
            f"Invalid dataset path {args.dataset_path}. Please use a string."
        )
    if args.dataset not in SUPPORTED_DATASET:
        raise ValueError(f"Unrecognized dataset {args.dataset}")
    if args.dataset_path is None:
        raise ValueError(
            "Gorilla dataset requires dataset path. "
            'Please specify it with "--dataset-path".'
        )
    # GorillaDataset assembles the chat messages itself. Raise instead of
    # asserting so the check survives `python -O`.
    if args.apply_chat_template is not False:
        raise ValueError("Gorilla dataset does not support applying chat template")
    return GorillaDataset(
        args.dataset,
        args.dataset_path,
        tokenizer,
        args.use_stag,
        args.api_endpoint,
        args.model,
    )
"""Benchmark request processors."""
import argparse
import asyncio
import concurrent.futures
import copy
import os
import random
import time
from typing import Any, Callable, Dict, List, Optional
import numpy as np
import requests
from tqdm import tqdm
from transformers import AutoTokenizer # pylint: disable=import-error
from api_endpoint import APIEndPoint
from dataset import Dataset
from openai_protocol import ChatCompletionMessage, ChatCompletionRequest, DebugConfig
from request_record import GroupedRequestRecord, RequestRecord
import logging
logger = logging.getLogger(__name__)
class RequestProcessor:  # pylint: disable=too-few-public-methods
    """Base class for one stage of the benchmark request pipeline.

    A processor is a callable that receives a list of ``RequestRecord`` and
    returns the resulting list, which may be transformed, filtered or
    executed. Subclasses must override ``__call__``.
    """

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        raise NotImplementedError
class LogMessage(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Pass-through stage that logs a fixed message at INFO level when invoked."""

    def __init__(self, message: str) -> None:
        self.message = message

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        logger.log(logging.INFO, self.message)
        # The records themselves are passed on untouched.
        return request_records
class SampleRequests(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Select ``num_requests`` records out of the given request list.

    With ``take_first_x_requests=True`` (the default), the records (or record
    groups) are taken in order. Otherwise they are shuffled first. Samples are
    deep copies, so the caller's records are never mutated. The sampled
    records are renumbered with ``request_id`` ``0 .. num_requests - 1``.
    """

    def __init__(self, num_requests: int, take_first_x_requests: bool = True) -> None:
        self.num_requests = num_requests
        # If `take_first_x_requests` is True, the first `num_requests` requests
        # are returned and sampling will not happen.
        self.take_first_x_requests = take_first_x_requests

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        assert len(request_records) > 0, "Empty input request record."
        # We expect the input request records to be all grouped or all plain.
        if isinstance(request_records[0], GroupedRequestRecord):
            assert all(
                isinstance(record, GroupedRequestRecord) for record in request_records
            )
            return self._sample_from_grouped_request_records(request_records)
        assert all(
            not isinstance(record, GroupedRequestRecord) for record in request_records
        )
        return self._sample_from_plain_request_records(request_records)

    def _sample_from_plain_request_records(
        self, request_records: List[RequestRecord]
    ) -> List[RequestRecord]:
        """Take or randomly sample plain records.

        Random sampling may repeat records when more are requested than exist.
        """
        samples: List[RequestRecord] = []
        if self.take_first_x_requests:
            if len(request_records) < self.num_requests:
                raise ValueError(
                    f"Insufficient requests. Requiring {self.num_requests} requests "
                    f"but only {len(request_records)} are available."
                )
            samples = copy.deepcopy(list(request_records[: self.num_requests]))
        else:
            while len(samples) < self.num_requests:
                # Create a new list so that the in-place shuffle does not mutate the input list.
                records = list(request_records)
                random.shuffle(records)
                # Only deep-copy as many records as are still needed.
                samples += copy.deepcopy(records[: self.num_requests - len(samples)])
        for i, record in enumerate(samples):
            record.request_id = i
        return samples

    def _sample_from_grouped_request_records(
        self, grouped_request_records: List[GroupedRequestRecord]
    ) -> List[RequestRecord]:
        """Flatten whole groups, in order or shuffled, until enough records are taken.

        Groups are consumed one at a time, so records sharing a common prefix
        stay together. Only the last group used may be truncated.
        """
        num_total_available_requests = sum(
            len(record.records) for record in grouped_request_records
        )
        if self.num_requests > num_total_available_requests:
            raise ValueError(
                "Due to the existence of shared common prefixes, we do not allow "
                "benchmarking with requests more than the available requests in the dataset. "
                f"The required number of requests {self.num_requests} exceeds the "
                f"number of total available requests {num_total_available_requests}."
            )
        # Create a new list so that the in-place shuffle does not mutate the input list.
        groups = list(grouped_request_records)
        if not self.take_first_x_requests:
            random.shuffle(groups)
        remaining = self.num_requests
        samples: List[RequestRecord] = []
        # Iterate `groups` (not the input list) so the shuffle above takes effect.
        for group in groups:
            if remaining == 0:
                break
            num_used_requests = min(len(group.records), remaining)
            # Deep-copy so renumbering below does not touch the caller's records.
            samples += copy.deepcopy(group.records[:num_used_requests])
            remaining -= num_used_requests
        for i, record in enumerate(samples):
            record.request_id = i
        return samples
class AttachModelName(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Set ``chat_cmpl.model`` of every request record to one fixed model name."""

    def __init__(self, model: str) -> None:
        self.model = model

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        model_name = self.model
        for record in request_records:
            record.chat_cmpl.model = model_name
        return request_records
class AttachRequestRateTimestamp(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Assign Poisson-process send timestamps for a target request rate.

    The first record is scheduled at 0. Each following record is offset from
    its predecessor by an exponentially distributed gap with mean
    ``1 / request_rate``, drawn from NumPy's global random generator.
    Records must not already carry a timestamp.
    """

    def __init__(self, request_rate: np.float32) -> None:
        self.request_rate = request_rate

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        elapsed = 0.0
        for record in request_records:
            assert (
                record.timestamp is None
            ), "The request record already has a timestamp"
            record.timestamp = elapsed
            gap = np.random.exponential(1.0 / self.request_rate)
            elapsed += float(gap)
        return request_records
class AttachExecutionFeature(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Record the benchmark's execution setting on every record's metrics.

    The ``exec_feature`` dict passed to the constructor (for example
    ``{"request_rate": ...}``) is assigned to ``metrics.exec_feature`` of each
    record. Every record must already have ``metrics``.
    """

    def __init__(self, exec_feature: Dict[str, Any]) -> None:
        self.exec_feature = exec_feature

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        for record in request_records:
            metrics = record.metrics
            assert metrics is not None
            metrics.exec_feature = self.exec_feature
        return request_records
class AttachStreamFlag(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Set ``chat_cmpl.stream`` on every record, unless ``stream`` is ``None``.

    With ``stream=None`` the records are returned unchanged.
    """

    def __init__(self, stream: Optional[bool]) -> None:
        self.stream = stream

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        if self.stream is not None:
            for record in request_records:
                record.chat_cmpl.stream = self.stream
        return request_records
class AttachSamplingOptions(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Attach the sampling options shared by every request.

    Sets ``temperature`` and ``top_p``, zeroes the frequency and presence
    penalties, and forces ``tool_choice="none"``, because the dataset embeds
    the tool definitions in the prompt text. When ``ignore_eos`` is true, it
    also turns on ``debug_config.ignore_eos`` while keeping any other debug
    settings the request already has.
    """

    def __init__(self, temperature: float, top_p: float, ignore_eos: bool) -> None:
        self.temperature = temperature
        self.top_p = top_p
        self.ignore_eos = ignore_eos

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        for request_record in request_records:
            chat_cmpl = request_record.chat_cmpl
            chat_cmpl.temperature = self.temperature
            chat_cmpl.top_p = self.top_p
            chat_cmpl.frequency_penalty = 0.0
            chat_cmpl.presence_penalty = 0.0
            chat_cmpl.tool_choice = "none"
            if self.ignore_eos:
                if chat_cmpl.debug_config is None:
                    chat_cmpl.debug_config = DebugConfig(ignore_eos=True)
                else:
                    # Update in place rather than replacing the config, so that
                    # settings such as grammar_execution_mode are not lost.
                    chat_cmpl.debug_config.ignore_eos = True
        return request_records
class ScaleTimestamp(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Multiply every record's ``timestamp`` by a constant factor.

    Raises:
        ValueError: a record reached this stage without a timestamp. Records
            before it in the list have already been scaled.
    """

    def __init__(self, timestamp_scale: float):
        self.timestamp_scale = timestamp_scale

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        for record in request_records:
            if record.timestamp is None:
                raise ValueError(
                    f"The timestamp of request {record} has not been initialized."
                )
            record.timestamp = record.timestamp * self.timestamp_scale
        return request_records
class MetricAnalyzer(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Derive token counts and latency metrics from the raw benchmark results.

    Output token counts are recomputed with ``tokenizer``, not taken from the
    server. The returned list contains only the successful records. Records
    that failed, or whose whole output arrived in the first chunk, are left
    out; the latter are marked unsuccessful and given an ``error_msg``.
    """

    def __init__(self, tokenizer: AutoTokenizer) -> None:
        self.tokenizer = tokenizer

    def _num_tokens(self, text: str) -> int:
        """Count the tokens in ``text``, excluding special tokens."""
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        analyzed: List[RequestRecord] = []
        for record in request_records:
            metrics = record.metrics
            if not metrics.success:
                assert record.error_msg is not None
                continue
            metrics.output_tokens = self._num_tokens(record.output_str)
            first_chunk_tokens = self._num_tokens(record.first_chunk_output_str)
            # At least one token must arrive after the first chunk, otherwise
            # time-per-output-token below would divide by zero.
            if metrics.output_tokens <= first_chunk_tokens:
                metrics.success = False
                record.error_msg = (
                    f"Total output token num ({metrics.output_tokens}) equals "
                    f'the first chunk output token. Output text "{record.output_str}", '
                    f'first chunk output text "{record.first_chunk_output_str}"'
                )
                continue
            assert metrics.input_tokens > 0, "Invalid prompt tokens"
            metrics.inter_token_latency_s = (
                metrics.end_to_end_latency_s / metrics.output_tokens
            )
            if metrics.time_to_first_token_s is None:
                metrics.time_to_first_token_s = 0
            decode_latency = (
                metrics.end_to_end_latency_s - metrics.time_to_first_token_s
            )
            metrics.time_per_output_token_s = decode_latency / (
                metrics.output_tokens - first_chunk_tokens
            )
            analyzed.append(record)
        return analyzed
class WarmupAndRun(RequestProcessor):  # pylint: disable=too-few-public-methods,line-too-long
    """Run warmup requests through ``pipeline`` first, then the benchmark requests.

    Warmup requests come from one of two sources:

    * By default, they are the trailing ``num_warmup_requests`` input records.
    * With ``fake_warmup``, they are synthetic prompts built from the first
      benchmark record.

    Warmup results are discarded. When ``cuda_profile_url`` is set, the
    server's CUDA profiler is started before the benchmark run and stopped
    after it.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        num_warmup_requests: int,
        num_benchmark_requests: int,
        pipeline: RequestProcessor,
        cuda_profile_url: Optional[str],
        fake_warmup: bool = False,
    ) -> None:
        self.num_warmup_requests = num_warmup_requests
        self.num_benchmark_requests = num_benchmark_requests
        self.pipeline = pipeline
        self.cuda_profile_url = cuda_profile_url
        self.fake_warmup = fake_warmup

    def generate_fake_warmup_requests(  # pylint: disable=missing-function-docstring
        self, num_warmup_requests: int, example_request: RequestRecord
    ) -> List[RequestRecord]:
        records = []
        for _ in range(num_warmup_requests):
            record = copy.deepcopy(example_request)
            record.chat_cmpl = ChatCompletionRequest(
                messages=[
                    {
                        "role": "user",
                        "content": "Please output arbitrary coherent sentences. Do not output eos token.",  # pylint: disable=line-too-long
                    }
                ],
                model="",
                max_tokens=128,
            )
            records.append(record)
        return records

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        # Warmup
        if self.fake_warmup:
            assert len(request_records) == self.num_benchmark_requests
            benchmark_requests = request_records
            example_request = benchmark_requests[0]
            warmup_requests = self.generate_fake_warmup_requests(
                self.num_warmup_requests, example_request=example_request
            )
        else:
            assert (
                len(request_records)
                == self.num_warmup_requests + self.num_benchmark_requests
            )
            # The warmup requests are the trailing records. Split at an explicit
            # index: `request_records[:-k]` would be empty for k == 0.
            benchmark_requests = request_records[: self.num_benchmark_requests]
            warmup_requests = request_records[self.num_benchmark_requests :]
        # Give every timestamped warmup request timestamp 0 so they are all
        # scheduled together.
        for request_record in warmup_requests:
            request_record.timestamp = (
                0 if request_record.timestamp is not None else None
            )
        warmup_requests = self._process_warmup_requests(warmup_requests)
        # Some executors reject an empty request list, so skip an empty warmup.
        if warmup_requests:
            logger.info("Warmup with %d request(s)...", self.num_warmup_requests)
            self.pipeline(warmup_requests)
        # Then run benchmark
        if self.cuda_profile_url is not None:
            cuda_profiler_start_url = (
                self.cuda_profile_url + "/debug/cuda_profiler_start"
            )
            cuda_profiler_start_response = requests.post(
                cuda_profiler_start_url, timeout=60
            )
            assert cuda_profiler_start_response.status_code == 200
        logger.info("Warmup finished. Start benchmarking...")
        updated_request_records = self.pipeline(benchmark_requests)
        if self.cuda_profile_url is not None:
            cuda_profiler_stop_url = self.cuda_profile_url + "/debug/cuda_profiler_stop"
            cuda_profiler_stop_response = requests.post(
                cuda_profiler_stop_url, timeout=60
            )
            assert cuda_profiler_stop_response.status_code == 200
        return updated_request_records

    def _process_warmup_requests(
        self, warmup_requests: List[RequestRecord]
    ) -> List[RequestRecord]:
        """Give warmup requests increasing output lengths and permissive sampling."""
        # NOTE: to warm up the server for as many different batch sizes as
        # possible, the first request gets 128 output tokens and every
        # follow-up request one more than its predecessor. A high temperature
        # and top-p avoid early stops as much as possible.
        for i, request_record in enumerate(warmup_requests):
            request_record.chat_cmpl.max_tokens = 128 + i
            request_record.chat_cmpl.temperature = 2.0
            request_record.chat_cmpl.top_p = 1.0
        return warmup_requests
class SequentialProcessor(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Chain processors: each stage's output is the next stage's input, in order."""

    processors: List[RequestProcessor]

    def __init__(self, *processors: RequestProcessor) -> None:
        self.processors = list(processors)

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        records = request_records
        for stage in self.processors:
            records = stage(records)
        return records
class Executor(RequestProcessor):  # pylint: disable=too-few-public-methods
    """Base class for processors that actually send the requests.

    Each subclass implements one benchmark mode (how and when requests are
    sent). They share an API-endpoint factory, a worker-process count and a
    switch for the tqdm progress bar.
    """

    def __init__(
        self,
        f_create_api_endpoint: Callable[[], APIEndPoint],
        num_processes: int,
        disable_tqdm: bool,
    ) -> None:
        self.f_create_api_endpoint = f_create_api_endpoint
        self.num_processes = num_processes
        self.disable_tqdm = disable_tqdm

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        raise NotImplementedError
class FixedConcurrentRequestExecutor(Executor):  # pylint: disable=too-few-public-methods
    """Benchmark executor that keeps a fixed number of requests in flight.

    Requests are split round-robin across worker processes, and the total
    concurrency is split across those processes as well. Inside each process,
    every concurrency slot is an asyncio task that sends its next request as
    soon as the previous one finishes. With ``multi_round``, each slot
    prepends its own conversation history to the next request's messages.

    Raises:
        ValueError: ``num_concurrent_requests`` is not positive.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        f_create_api_endpoint: Callable[[], APIEndPoint],
        num_processes: Optional[int],
        disable_tqdm: bool,
        num_concurrent_requests: int,
        multi_round: bool,
    ) -> None:
        if num_concurrent_requests <= 0:
            raise ValueError(
                "num_concurrent_requests must be positive, "
                f"got {num_concurrent_requests}."
            )
        if num_processes is None:
            # We assign each process at most 32 concurrent requests to send
            # so that the asyncio pressure will not be too much.
            num_processes = min((num_concurrent_requests + 31) // 32, 10)
        # A process whose concurrency share is 0 would never send its partition
        # of requests, so never use more processes than concurrent requests.
        num_processes = min(num_processes, num_concurrent_requests)
        super().__init__(f_create_api_endpoint, num_processes, disable_tqdm)
        self.num_concurrent_requests = num_concurrent_requests
        self.multi_round = multi_round

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        partitions: List[List[RequestRecord]] = [
            request_records[slice(i, len(request_records), self.num_processes)]
            for i in range(self.num_processes)
        ]
        # Package "tokenizers" reports warnings with multiprocessing.
        # We disable "TOKENIZERS_PARALLELISM" to depress the warnings.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        pbar = None if self.disable_tqdm else tqdm(total=len(request_records))
        results: List[RequestRecord] = []
        try:
            with concurrent.futures.ProcessPoolExecutor(
                max_workers=self.num_processes
            ) as pool:
                futures = [
                    pool.submit(
                        FixedConcurrentRequestExecutor._process_task,
                        self.f_create_api_endpoint,
                        partition,
                        # Spread the total concurrency as evenly as possible;
                        # the first (total % processes) processes get one extra.
                        self.num_concurrent_requests // self.num_processes
                        + int(i < self.num_concurrent_requests % self.num_processes),
                        self.multi_round,
                    )
                    for i, partition in enumerate(partitions)
                ]
                for future in concurrent.futures.as_completed(futures):
                    partition_results = future.result()
                    results.extend(partition_results)
                    if pbar is not None:
                        pbar.update(len(partition_results))
        finally:
            if pbar is not None:
                pbar.close()
        return results

    @staticmethod
    def _process_task(
        f_create_api_endpoint: Callable[[], APIEndPoint],
        request_records: List[RequestRecord],
        num_concurrent_requests: int,
        multi_round: bool,
    ) -> List[RequestRecord]:
        """Worker-process entry point: send one partition with the given concurrency."""
        if len(request_records) == 0:
            return []
        # One conversation history per concurrency slot (used with multi_round).
        chat_history: List[List[ChatCompletionMessage]] = [
            [] for _ in range(num_concurrent_requests)
        ]

        async def process_task_impl(
            f_create_api_endpoint: Callable[[], APIEndPoint],
            request_records: List[RequestRecord],
            num_concurrent_requests: int,
            multi_round: bool,
        ) -> List[RequestRecord]:
            api_endpoint = f_create_api_endpoint()
            updated_request_records: List[RequestRecord] = [
                None for _ in request_records
            ]
            async with api_endpoint:
                num_sent_request = 0

                async def _task(i: int) -> None:
                    nonlocal num_sent_request
                    while True:
                        # No await between the check and the increment, so the
                        # tasks of this event loop never claim the same index.
                        if num_sent_request == len(request_records):
                            break
                        idx = num_sent_request
                        num_sent_request += 1
                        request = request_records[idx]
                        if multi_round:
                            request.chat_cmpl.messages = (
                                chat_history[i] + request.chat_cmpl.messages
                            )
                        updated_request_records[idx] = await api_endpoint(request)
                        if multi_round:
                            chat_history[i] = updated_request_records[
                                idx
                            ].chat_cmpl.messages + [
                                ChatCompletionMessage(
                                    content=updated_request_records[idx].output_str,
                                    role="assistant",
                                )
                            ]

                tasks = [
                    asyncio.create_task(_task(i))
                    for i in range(num_concurrent_requests)
                ]
                await asyncio.gather(*tasks)
            return updated_request_records

        return asyncio.run(
            process_task_impl(
                f_create_api_endpoint,
                request_records,
                num_concurrent_requests,
                multi_round,
            )
        )
class FixTimestampExecutor(Executor):  # pylint: disable=too-few-public-methods
    """The benchmark executor of fixing the timestamps of sending requests.

    Every record must carry a ``timestamp``, in seconds relative to the other
    records. Records are sorted by timestamp and split round-robin across
    worker processes. Each record is launched at its timestamp's offset from
    the earliest timestamp, measured from a wall-clock base time that all
    processes share.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        f_create_api_endpoint: Callable[[], APIEndPoint],
        num_processes: Optional[int],
        disable_tqdm: bool,
        max_schedule_gap: float,
        num_requests: int,
    ) -> None:
        # Within this class, `num_requests` only sizes the default process count.
        if num_processes is None:
            # We assign each process at most 32 requests to send
            # so that the asyncio pressure will not be too much.
            num_processes = min((num_requests + 31) // 32, 10)
        super().__init__(f_create_api_endpoint, num_processes, disable_tqdm)
        # Seconds added to every launch time. The scheduling loop also wakes
        # up this long before each launch.
        self.max_schedule_gap = max_schedule_gap
        self.num_requests = num_requests

    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
        """Send all records at their timestamps; returns records in completion order per process.

        Note: sorts ``request_records`` in place, so the caller's list is reordered.
        """
        assert len(request_records) > 0
        assert all(
            request_record.timestamp is not None for request_record in request_records
        )
        # Sort the request records in timestamp ascending order before partitioning.
        request_records.sort(key=lambda request_record: request_record.timestamp)
        base_timestamp = request_records[0].timestamp
        partitions: List[List[RequestRecord]] = [
            request_records[slice(i, len(request_records), self.num_processes)]
            for i in range(self.num_processes)
        ]
        # Captured once, before spawning workers, so every process schedules
        # against the same wall-clock origin.
        base_sys_time = time.time()
        # Package "tokenizers" reports warnings with multiprocessing.
        # We disable "TOKENIZERS_PARALLELISM" to depress the warnings.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        pbar = None if self.disable_tqdm else tqdm(total=len(request_records))
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=self.num_processes
        ) as pool:
            futures = [
                pool.submit(
                    FixTimestampExecutor._process_task,
                    self.f_create_api_endpoint,
                    partition,
                    base_timestamp,
                    base_sys_time,
                    self.max_schedule_gap,
                )
                for partition in partitions
            ]
            results: List[RequestRecord] = []
            for future in concurrent.futures.as_completed(futures):
                partition_results = future.result()
                results.extend(partition_results)
                if pbar is not None:
                    pbar.update(len(partition_results))
            return results

    @staticmethod
    def _process_task(
        f_create_api_endpoint: Callable[[], APIEndPoint],
        request_records: List[RequestRecord],
        base_timestamp: float,
        base_sys_time: float,
        max_schedule_gap: float,
    ) -> List[RequestRecord]:
        """Worker-process entry point: launch each record of one partition at its scheduled time.

        A record's wall-clock launch time is
        ``base_sys_time + max_schedule_gap + (timestamp - base_timestamp)``.
        It is converted to this process's event-loop clock before scheduling.
        """
        if len(request_records) == 0:
            return []

        async def process_task_impl(
            f_create_api_endpoint: Callable[[], APIEndPoint],
            request_records: List[RequestRecord],
            base_timestamp: float,
            base_sys_time: float,
            max_schedule_gap: float,
        ) -> List[RequestRecord]:
            api_endpoint = f_create_api_endpoint()
            loop = asyncio.get_running_loop()
            # Get the delta time to convert system time to the loop time.
            # We must use the system time `time.time()` which is consistent across processes.
            loop_sys_delta_time = loop.time() - time.time()
            # Appended in completion order, not in scheduling order.
            updated_request_records: List[RequestRecord] = []
            async with api_endpoint:

                async def _task(request_record: RequestRecord) -> None:
                    updated_request_records.append(await api_endpoint(request_record))

                tasks = []
                for request_record in request_records:
                    launch_time = (
                        (request_record.timestamp - base_timestamp)
                        + (base_sys_time + max_schedule_gap)
                        + loop_sys_delta_time
                    )
                    # The record is passed as a call_at argument rather than
                    # captured by the lambda, so each callback gets its own record.
                    loop.call_at(
                        launch_time,
                        lambda record: tasks.append(asyncio.create_task(_task(record))),
                        request_record,
                    )
                    # Sleep to allow runs of other scheduled tasks if any.
                    # Wakes up max_schedule_gap before this record's launch time.
                    await asyncio.sleep(
                        max(launch_time - loop.time() - max_schedule_gap, 0)
                    )
                # Sleep until all the tasks are launched.
                # `launch_time` is the last (latest) record's launch time here.
                await asyncio.sleep(launch_time - loop.time() + max_schedule_gap)
                # Wait for all tasks to be scheduled
                assert len(tasks) == len(request_records)
                await asyncio.gather(*tasks)
            assert len(updated_request_records) == len(request_records)
            return updated_request_records

        return asyncio.run(
            process_task_impl(
                f_create_api_endpoint,
                request_records,
                base_timestamp,
                base_sys_time,
                max_schedule_gap,
            )
        )
def create_pipelines(  # pylint: disable=too-many-branches
    args: argparse.Namespace,
    f_create_api_endpoint: Callable[[], APIEndPoint],
    dataset: Dataset,
) -> List[RequestProcessor]:
    """Create the request-processing pipelines selected by ``args``.

    Exactly one benchmark mode is used:

    * ``args.num_concurrent_requests``: one pipeline per concurrency level,
      run by ``FixedConcurrentRequestExecutor``.
    * ``args.request_rate``: one pipeline per request rate, with
      exponentially spaced timestamps, run by ``FixTimestampExecutor``.
    * Otherwise: dataset replay of the dataset's own timestamps, scaled by
      ``args.replay_timestamp_scale`` (default 1.0).

    Every pipeline does the same steps:

    1. Sample the benchmark requests, plus the warmup requests unless the
       dataset synthesizes its own (``require_fake_warmup``).
    2. Attach the model name, stream flag, sampling options and execution
       feature.
    3. Run the warmup, then the benchmark.

    Raises:
        ValueError: conflicting options, or options missing for the selected mode.
    """
    cuda_profile_url = f"http://{args.host}:{args.port}" if args.cuda_profile else None
    pipelines: List[RequestProcessor] = []
    if args.num_concurrent_requests is not None:
        if args.request_rate is not None:
            raise ValueError(
                'Both "num_concurrent_requests" and "request_rate" are specified. '
                "Please specify only one of them."
            )
        if args.replay_timestamp_scale is not None:
            raise ValueError(
                "Dataset replay is unsupported when fixing number of concurrent requests."
            )
        for num_concurrent_requests in args.num_concurrent_requests:
            num_warmup_requests = (
                args.num_warmup_requests
                if args.num_warmup_requests is not None
                else num_concurrent_requests
            )
            # With fake warmup, WarmupAndRun synthesizes the warmup requests and
            # expects exactly the benchmark requests as input (as in the other modes).
            num_samples = (
                args.num_requests
                if dataset.require_fake_warmup
                else args.num_requests + num_warmup_requests
            )
            pipelines.append(
                SequentialProcessor(
                    LogMessage(
                        f"Fixing number of concurrent requests: {num_concurrent_requests}"
                    ),
                    SampleRequests(num_samples),
                    AttachModelName(args.tokenizer),
                    AttachStreamFlag(args.stream),
                    AttachSamplingOptions(
                        args.temperature, args.top_p, args.ignore_eos
                    ),
                    AttachExecutionFeature(
                        {"num_concurrent_requests": num_concurrent_requests}
                    ),
                    WarmupAndRun(
                        num_warmup_requests=num_warmup_requests,
                        num_benchmark_requests=args.num_requests,
                        pipeline=FixedConcurrentRequestExecutor(
                            f_create_api_endpoint,
                            args.num_process_workers,
                            args.disable_tqdm,
                            num_concurrent_requests,
                            args.multi_round,
                        ),
                        cuda_profile_url=cuda_profile_url,
                        fake_warmup=dataset.require_fake_warmup,
                    ),
                )
            )
        return pipelines
    if args.request_rate is not None:
        if args.num_warmup_requests is None:
            raise ValueError(
                "Please specify the number of warmup requests via "
                '"--num-warmup-requests" when fixing request rate.'
            )
        if args.replay_timestamp_scale is not None:
            raise ValueError("Dataset replay is unsupported when fixing request rates.")
        num_total_requests = int(
            args.num_requests
            if not args.per_gpu_workload
            else args.num_requests * args.num_gpus
        )
        if dataset.require_fake_warmup:
            num_samples = num_total_requests
        else:
            num_samples = num_total_requests + args.num_warmup_requests
        return [
            SequentialProcessor(
                LogMessage(f"Fixing request rate: {request_rate}"),
                SampleRequests(num_samples),
                AttachModelName(args.tokenizer),
                AttachRequestRateTimestamp(
                    request_rate
                    if not args.per_gpu_workload
                    else request_rate * args.num_gpus
                ),
                AttachStreamFlag(args.stream),
                AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos),
                AttachExecutionFeature({"request_rate": float(request_rate)}),
                WarmupAndRun(
                    num_warmup_requests=args.num_warmup_requests,
                    num_benchmark_requests=num_total_requests,
                    pipeline=FixTimestampExecutor(
                        f_create_api_endpoint,
                        args.num_process_workers,
                        args.disable_tqdm,
                        args.max_schedule_gap,
                        # Size the default process pool by the requests actually sent.
                        num_total_requests,
                    ),
                    cuda_profile_url=cuda_profile_url,
                    fake_warmup=dataset.require_fake_warmup,
                ),
            )
            for request_rate in args.request_rate
        ]
    # Default: dataset replay mode
    # The dataset must come with timestamps.
    if not dataset.timestamp_available:
        raise ValueError(
            "The dataset does not have timestamps, so dataset replay is unsupported. "
            'Please specify one of "num_concurrent_requests" '
            'and "request_rate".'
        )
    if args.per_gpu_workload:
        raise ValueError(
            "Fixing per-GPU workload is not compatible with dataset replay."
        )
    if args.num_warmup_requests is None:
        raise ValueError(
            "Please specify the number of warmup requests via "
            '"--num-warmup-requests" for dataset replay.'
        )
    timestamp_scale = args.replay_timestamp_scale or 1.0
    if dataset.require_fake_warmup:
        num_samples = args.num_requests
    else:
        num_samples = args.num_requests + args.num_warmup_requests
    return [
        SequentialProcessor(
            LogMessage(f"Dataset replay with time scaling of {timestamp_scale}"),
            SampleRequests(num_samples, take_first_x_requests=True),
            AttachModelName(args.tokenizer),
            ScaleTimestamp(timestamp_scale),
            AttachStreamFlag(args.stream),
            AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos),
            AttachExecutionFeature({"timestamp_scale": timestamp_scale}),
            WarmupAndRun(
                num_warmup_requests=args.num_warmup_requests,
                num_benchmark_requests=args.num_requests,
                pipeline=FixTimestampExecutor(
                    f_create_api_endpoint,
                    args.num_process_workers,
                    args.disable_tqdm,
                    args.max_schedule_gap,
                    args.num_requests,
                ),
                cuda_profile_url=cuda_profile_url,
                fake_warmup=dataset.require_fake_warmup,
            ),
        )
    ]
"""Benchmark request records."""
from typing import Any, Dict, List, Optional, Tuple, Union
import pandas as pd # pylint: disable=import-error
from pydantic import BaseModel
from openai_protocol import ChatCompletionRequest
import logging
logger = logging.getLogger(__name__)
class ServerMetrics(BaseModel):
    """The metrics from the server side.

    One instance describes a single request as reported by the server.
    Field declaration order is observable: ``_compute_metrics_statistics``
    iterates the model fields in this order, which fixes the key order of
    the summarized ``server_metrics`` report.
    """

    # Token counts for the request.
    input_tokens: int
    prefill_tokens: int
    output_tokens: int
    # Latency fields; the ``_s`` suffix and pretty_print_report's ``* 1000``
    # "ms" conversion suggest seconds -- TODO confirm against the server.
    end_to_end_latency_s: float
    # A rate (prefill tokens per second), not a duration.
    prefill_tokens_per_s: float
    inter_token_latency_s: float
    time_per_output_token_s: float
    # Optional: defaults to None when the server does not report it.
    time_to_first_token_s: Optional[float] = None
class Metrics(BaseModel):
    """Client-side metrics collected for a single request.

    ``_compute_metrics_statistics`` skips ``success``, ``start_time``,
    ``finish_time``, ``server_metrics`` and ``exec_feature`` and summarizes
    every remaining field (declaration order fixes the report key order).
    """

    # Presumably whether the request completed without error -- confirm with
    # the code that fills it in.
    success: bool
    # Start/finish timestamps on a shared clock: generate_metrics_summary
    # takes max(finish_time) - min(start_time) as the benchmark duration.
    start_time: float
    finish_time: float
    end_to_end_latency_s: float
    # Optional per-request measurements; None values are ignored by the
    # statistics. Token counts are summed by generate_metrics_summary.
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    inter_token_latency_s: Optional[float] = None
    time_per_output_token_s: Optional[float] = None
    time_to_first_token_s: Optional[float] = None
    # Metrics reported by the server for the same request, when available.
    server_metrics: Optional[ServerMetrics] = None
    # Free-form execution-feature tags; the report copies the first record's
    # value. Presumably populated by pipeline steps such as
    # AttachExecutionFeature -- confirm.
    exec_feature: Optional[Dict[str, Any]] = None
class RequestRecord(BaseModel):
    """The request records collected from LLM inference requests.

    Pairs the chat-completion request that is sent with the output and
    metrics that come back for it.
    """

    # Optional identifier for the request.
    request_id: Optional[int] = None
    # The chat-completion request payload (required).
    chat_cmpl: ChatCompletionRequest
    # Generated text; None until a response has been recorded.
    output_str: Optional[str] = None
    # Presumably the text of the first streamed chunk -- confirm with the
    # API endpoint implementation.
    first_chunk_output_str: str = ""
    # Optional scheduling timestamp; presumably seconds -- confirm with the
    # dataset loader.
    timestamp: Optional[float] = None
    # Per-request metrics. generate_metrics_summary dereferences this
    # without a None check, so records passed to it must have metrics.
    metrics: Optional[Metrics] = None
    # Presumably set when the request fails -- confirm with the endpoint.
    error_msg: Optional[str] = None
class GroupedRequestRecord(RequestRecord):
    """The data structure for request record groups.
    For datasets that have common prefix sharing, the request records
    that share a same common prefix will be wrapped in a GroupedRequestRecord
    at the beginning.

    Because this subclasses RequestRecord, a group also carries the required
    ``chat_cmpl`` field of its own.
    """

    # The individual request records that belong to this group.
    records: List[RequestRecord]
def generate_metrics_summary(
    request_records: List[RequestRecord],
    num_total_requests: int,
    num_gpus: int,
) -> Dict[str, Any]:
    """Computes summary statistics across all metrics collected.

    Parameters
    ----------
    request_records : List[RequestRecord]
        The completed request records. Every record must carry a non-None
        ``metrics`` (it is dereferenced unconditionally).
    num_total_requests : int
        The number of requests issued; must be >= ``len(request_records)``.
    num_gpus : int
        The number of GPUs, used to derive the per-GPU throughputs.

    Returns
    -------
    report : Dict[str, Any]
        ``exec_feature`` (taken from the first record, or None when nothing
        completed) followed by the per-field statistics, the aggregate
        counts/throughputs and, when any record has server metrics, a nested
        ``server_metrics`` statistics report.
    """
    num_completed_requests = len(request_records)
    assert num_completed_requests <= num_total_requests
    request_metrics = [record.metrics for record in request_records]
    # Wall-clock span from the earliest start to the latest finish; a tiny
    # placeholder keeps the divisions below finite when nothing completed.
    duration = (
        max(metrics.finish_time for metrics in request_metrics)
        - min(metrics.start_time for metrics in request_metrics)
        if num_completed_requests > 0
        else 1e-5
    )

    report = _compute_metrics_statistics(request_metrics)
    report["num_gpus"] = num_gpus
    report["duration"] = duration
    report["num_total_requests"] = num_total_requests
    report["num_completed_requests"] = num_completed_requests
    report["request_throughput"] = num_completed_requests / duration
    # Per-request dump is debugging output, not part of the report.
    logger.debug("Per-request metrics: %s", request_metrics)

    # Token counts are Optional on Metrics; count a missing value as 0
    # instead of letting ``sum`` raise TypeError on None.
    total_input_tokens = sum(metric.input_tokens or 0 for metric in request_metrics)
    total_output_tokens = sum(
        metric.output_tokens or 0 for metric in request_metrics
    )
    report["total_input_tokens"] = total_input_tokens
    report["total_output_tokens"] = total_output_tokens
    report["input_token_throughput"] = total_input_tokens / duration
    report["input_token_throughput_per_gpu"] = (
        report["input_token_throughput"] / num_gpus
    )
    report["output_token_throughput"] = total_output_tokens / duration
    report["output_token_throughput_per_gpu"] = (
        report["output_token_throughput"] / num_gpus
    )

    # Generate the server metrics statistics (only for records that have them).
    server_metrics = [
        metric.server_metrics for metric in request_metrics if metric.server_metrics
    ]
    server_report = _compute_metrics_statistics(server_metrics)
    if server_report:
        report["server_metrics"] = server_report

    # Put exec_feature first so it leads the printed/serialized report.
    return {
        "exec_feature": (
            request_records[0].metrics.exec_feature
            if num_completed_requests > 0
            else None
        ),
        **report,
    }
def _compute_metrics_statistics(
    metrics: List[Union[Metrics, ServerMetrics]],
) -> Dict[str, Any]:
    """
    Compute the statistics of the metrics.

    Parameters
    ----------
    metrics : List[Union[Metrics, ServerMetrics]]
        The list of metrics to get the statistics. All entries are assumed
        to share the model class of ``metrics[0]``, whose fields (in
        declaration order) determine the report keys.

    Returns
    -------
    report : Dict
        For each summarized field, a dict with ``quantiles``
        (p25/p50/p75/p90/p95/p99), ``mean``, ``min``, ``max`` and
        ``stddev``. None values are ignored; a field with no values at all
        yields NaN statistics. An empty input returns ``{}``.
    """
    if not metrics:
        return {}
    # Bookkeeping / nested fields that are not scalar measurements.
    skipped_keys = {
        "success",
        "start_time",
        "finish_time",
        "server_metrics",
        "exec_feature",
    }
    quantile_levels = [0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
    report: Dict[str, Any] = {}
    df = pd.DataFrame([metric.model_dump() for metric in metrics])
    # Read ``model_fields`` from the class: instance access is deprecated
    # since Pydantic 2.11.
    for key in type(metrics[0]).model_fields:
        if key in skipped_keys or key not in df.columns:
            continue
        # A column whose values are all None has object dtype; coercing to
        # numeric turns it into NaN floats so the statistics below are NaN
        # rather than failing. Numeric columns are returned unchanged.
        series = pd.to_numeric(df[key], errors="coerce").dropna()
        report[key] = {
            "quantiles": {
                f"p{round(q * 100)}": v
                for q, v in series.quantile(quantile_levels).items()
            },
            "mean": series.mean(),
            "min": series.min(),
            "max": series.max(),
            "stddev": series.std(),
        }
    return report
def convert_reports_to_df(reports: List[Dict[str, Any]]) -> pd.DataFrame:
    """Turn benchmark reports into a DataFrame with one row per report.

    Nested dictionaries are flattened into dotted column names, e.g.
    ``{"a": {"b": 1}}`` becomes the column ``"a.b"``.
    """

    def _flatten(node: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        # Depth-first walk that keeps the original key order.
        flat: Dict[str, Any] = {}
        for name, val in node.items():
            path = name if not prefix else f"{prefix}.{name}"
            if not isinstance(val, dict):
                flat[path] = val
                continue
            flat.update(_flatten(val, path))
        return flat

    rows = [_flatten(rep) for rep in reports]
    return pd.DataFrame(rows)
def _print_latency_section(title: str, stats: Dict[str, Any]) -> None:
    """Print one latency statistics section, scaling values by 1000 (to ms)."""
    print(f" {title} ".center(50, "-"))
    print(f"{'Mean:':<40} {stats['mean'] * 1000:<10.2f}")
    print(f"{'Stddev:':<40} {stats['stddev'] * 1000:<10.2f}")
    for quantile in ("p25", "p50", "p75", "p90", "p95", "p99"):
        label = f"{quantile.upper()}:"
        print(f"{label:<40} {stats['quantiles'][quantile] * 1000:<10.2f}")
    print(f"{'Min:':<40} {stats['min'] * 1000:<10.2f}")
    print(f"{'Max:':<40} {stats['max'] * 1000:<10.2f}")


def _print_token_count_section(title: str, stats: Dict[str, Any]) -> None:
    """Print one token-count statistics section (raw values, no scaling)."""
    print(f" {title} ".center(50, "-"))
    print(f"{'Mean:':<40} {stats['mean']:<1}")
    print(f"{'Stddev:':<40} {stats['stddev']:<1}")
    for quantile in ("p25", "p50", "p95"):
        label = f"{quantile.upper()}:"
        print(f"{label:<40} {stats['quantiles'][quantile]:<1}")
    print(f"{'Min:':<40} {stats['min']:<1}")
    print(f"{'Max:':<40} {stats['max']:<1}")


def pretty_print_report(
    report: Dict[str, Any],
) -> None:
    """Pretty print the metrics report.

    Prints the client-side report and then, if the report contains a nested
    ``server_metrics`` entry, the server-side statistics.
    """

    def _print(report: Dict[str, Any], server_metrics: bool) -> None:
        title = "Benchmark Result"
        if server_metrics:
            title += " (server side)"
        print(f" {title} ".center(50, "="))
        if not server_metrics:
            # pylint: disable=line-too-long
            # fmt: off
            print(f"{'Total requests:':<40} {report['num_total_requests']:<10}")
            print(f"{'Completed requests:':<40} {report['num_completed_requests']:<10}")
            print(f"{'Duration (s):':<40} {report['duration']:<10.2f}")
            print(f"{'Num GPUs:':<40} {report['num_gpus']:<10}")
            print(f"{'Total input tokens:':<40} {report['total_input_tokens']:<10}")
            print(f"{'Total output tokens:':<40} {report['total_output_tokens']:<10}")
            print(f"{'Request throughput (req/s):':<40} {report['request_throughput']:<10.2f}")
            print(f"{'Input token throughput (tok/s):':<40} {report['input_token_throughput']:<10.2f}")
            print(f"{'Input token throughput per GPU (tok/s):':<40} {report['input_token_throughput_per_gpu']:<10.2f}")
            print(f"{'Output token throughput (tok/s):':<40} {report['output_token_throughput']:<10.2f}")
            print(f"{'Output token throughput per GPU (tok/s):':<40} {report['output_token_throughput_per_gpu']:<10.2f}")
            # fmt: on
            # pylint: enable=line-too-long
            # Only the client-side report has aggregate keys such as
            # ``num_completed_requests``; the nested server report holds
            # per-field statistics only, so this check must stay here.
            if report["num_completed_requests"] == 0:
                return
        _print_latency_section("Time to First Token (TTFT, ms)", report["time_to_first_token_s"])
        _print_latency_section("Time per Output Token (TPOT, ms)", report["time_per_output_token_s"])
        _print_latency_section("Inter-Token Latency (ms)", report["inter_token_latency_s"])
        _print_latency_section("End-to-End Latency (ms)", report["end_to_end_latency_s"])
        _print_token_count_section("Input Tokens", report["input_tokens"])
        _print_token_count_section("Output Tokens", report["output_tokens"])
        print("=" * 50)

    _print(report, server_metrics=False)
    if "server_metrics" in report:
        _print(report["server_metrics"], server_metrics=True)
#!/usr/bin/env bash
# Launch an SGLang server, run the tool-calling accuracy benchmark without
# and with structural tags, then summarize the raw results.
#
# Abort on the first failing command, unset variable, or pipeline failure so
# that a failed benchmark run does not silently feed partial data to check.py.
set -euo pipefail

export SERVER_ADDR="127.0.0.1"
export SERVER_PORT="8000"
export MODEL_PATH="Qwen/Qwen3.6-27B" # or the path of other model
export MODEL="Qwen3.6-27B" # or other model names
export TOKENIZER="$MODEL_PATH" # or the path of other tokenizer
export DATA_PATH="./data/dataset"
export ACC_RAW="./data/accuracy_raw"
export ACC_SUM="./data/accuracy_summary"
export DATASET="BFCL_v3_simple"
export REQUEST_NUM=100
export N_GPU=1

python -m sglang.launch_server --model-path "$MODEL_PATH" \
  --host "$SERVER_ADDR" --port "$SERVER_PORT" &
SERVER_PID=$!
# Always stop the server on exit. "|| true" keeps errexit from aborting the
# cleanup when the server has already exited (kill fails, wait returns its status).
trap 'kill "$SERVER_PID" 2>/dev/null || true; wait "$SERVER_PID" 2>/dev/null || true' EXIT

echo "Waiting for sglang on $SERVER_ADDR:$SERVER_PORT ..."
READY=0
for _ in $(seq 1 600); do
  # Fail fast if the server process died instead of waiting out the timeout.
  if ! kill -0 "$SERVER_PID" 2>/dev/null; then
    echo "sglang exited before listening on $SERVER_ADDR:$SERVER_PORT" >&2
    exit 1
  fi
  # Readiness probe: succeeds once something accepts TCP connections.
  if bash -c "echo >/dev/tcp/$SERVER_ADDR/$SERVER_PORT" 2>/dev/null; then
    READY=1
    break
  fi
  sleep 1
done
if [ "$READY" != 1 ]; then
  echo "Timeout waiting for sglang to listen on $SERVER_ADDR:$SERVER_PORT" >&2
  exit 1
fi

# Baseline run (no structural tag).
python accuracy.py --model "$MODEL" --tokenizer "$TOKENIZER" \
  --dataset "$DATASET" --num-requests "$REQUEST_NUM" \
  --dataset-path "$DATA_PATH" --num-gpus "$N_GPU" \
  --num-warmup-requests 1 --request-rate inf \
  --host "$SERVER_ADDR" --port "$SERVER_PORT" --api-endpoint sglang --output "$ACC_RAW" \
  --temperature 0.001 --top-p 0.9

# Same run with the structural tag enabled.
python accuracy.py --model "$MODEL" --tokenizer "$TOKENIZER" \
  --dataset "$DATASET" --num-requests "$REQUEST_NUM" \
  --dataset-path "$DATA_PATH" --num-gpus "$N_GPU" \
  --num-warmup-requests 1 --request-rate inf \
  --host "$SERVER_ADDR" --port "$SERVER_PORT" --api-endpoint sglang --output "$ACC_RAW" \
  --temperature 0.001 --top-p 0.9 \
  --use-stag

# Summarize all raw outputs.
python check.py --dataset ALL --model ALL --dataset-path "$DATA_PATH" \
  --output-root "$ACC_RAW" --final-root "$ACC_SUM"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment