Benchmark WER and RTFx for Transformers Whisper.
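WER (word error rate) is computed on text normalized with Whisper's EnglishTextNormalizer, and RTFx (inverse real-time factor) is the total audio duration divided by the total inference time, so higher means faster. A minimal sketch of the RTFx arithmetic, with made-up numbers for illustration:

# e.g. 2 hours of audio transcribed in 6 minutes of GPU time
rtfx = (2 * 3600) / (6 * 60)  # = 20.0, i.e. 20x faster than real time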
# benchmark_transformers_whisper.py
# Point Python at a local Transformers checkout before importing it.
TRANSFORMERS_SRC_PATH = "/admin/home/eustache_lebihan/dev/benchmark-whisper/transformers-fix/src"
import sys

sys.path.insert(0, TRANSFORMERS_SRC_PATH)

import argparse
import os

import evaluate
import torch
import wandb
from datasets import load_dataset
from git import GitCommandError, Repo
from tqdm import tqdm
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer

wer_metric = evaluate.load("wer")

def get_latest_commit_and_remote(directory_path):
    """Return the latest commit SHA and origin URL of the repo containing directory_path."""
    try:
        repo = Repo(directory_path, search_parent_directories=True)
        commits = list(repo.iter_commits(paths=directory_path, max_count=1))
        latest_commit = commits[0].hexsha if commits else None
        remote_url = repo.remotes.origin.url if repo.remotes else None
        return latest_commit, remote_url
    except GitCommandError as e:
        print(f"Error: {e}")
        return None, None

class Benchmark:
    def __init__(
        self,
        model_id,
        dtype,
        attn_implementation,
        torch_device="cuda:0",
        short_form=False,
        gen_kwargs={
            "language": "en",
            "task": "transcribe",
            "condition_on_prev_tokens": False,  # follows the OpenAI default
            "return_timestamps": True,  # follows the OpenAI default
            "no_speech_threshold": 0.6,
            "compression_ratio_threshold": 1.35,
            "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # OpenAI's fallback schedule
            "logprob_threshold": -1.0,  # follows the OpenAI default
            "max_new_tokens": 445,  # capped to stay under the decoder's max target positions, which crashes the current implementation
        },
    ):
        self.model = WhisperForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=dtype,
            attn_implementation=attn_implementation,
        )
        self.torch_device = torch_device
        self.model.to(torch_device, dtype=dtype)
        self.processor = WhisperProcessor.from_pretrained(model_id)
        self.normalizer = EnglishTextNormalizer(
            self.processor.tokenizer.english_spelling_normalizer
        )
        self.short_form = short_form
        self.gen_kwargs = gen_kwargs

    def infer(self, audio_array):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        if self.short_form:
            inputs = self.processor(audio_array, return_tensors="pt", sampling_rate=16000)
        else:
            # Long-form: pad to the full audio length instead of truncating to 30 s.
            inputs = self.processor(
                audio_array,
                return_tensors="pt",
                sampling_rate=16000,
                padding="longest",
                truncation=False,
            )
        # Cast the features to the model dtype so fp16/bf16 runs don't fail.
        inputs = inputs.to(self.torch_device, dtype=self.model.dtype)

        # CUDA kernels launch asynchronously, so bracket generate() with CUDA
        # events and synchronize before reading the elapsed time.
        torch.cuda.synchronize()
        start_event.record()
        outputs = self.model.generate(
            inputs.input_features,
            **self.gen_kwargs,
        )
        end_event.record()
        torch.cuda.synchronize()

        transcript = self.processor.decode(outputs[0], skip_special_tokens=True)
        elapsed_time = start_event.elapsed_time(end_event) / 1000  # elapsed_time() returns milliseconds
        return transcript, elapsed_time

    def benchmark_dataset(
        self, dataset_name, dataset_config, dataset_split, audio_column, label_column, short_form
    ):
        times = []
        predictions = []
        labels = []
        durations = []

        dataset = load_dataset(dataset_name, dataset_config, split=dataset_split, trust_remote_code=True)
        # Keep only samples at most 30 s long (short-form) or strictly longer (long-form).
        if short_form:
            dataset = dataset.filter(lambda x: len(x["audio"]["array"]) / x["audio"]["sampling_rate"] <= 30)
        else:
            dataset = dataset.filter(lambda x: len(x["audio"]["array"]) / x["audio"]["sampling_rate"] > 30)

        for sample in tqdm(dataset, desc=f"Benchmarking {dataset_name}"):
            transcript, inference_time = self.infer(sample[audio_column]["array"])
            duration = sample[audio_column]["array"].shape[0] / sample[audio_column]["sampling_rate"]
            durations.append(duration)
            times.append(inference_time)
            predictions.append(transcript)
            labels.append(sample[label_column])

        normalized_predictions = [self.normalizer(pred) for pred in predictions]
        normalized_labels = [self.normalizer(label) for label in labels]

        # Drop pairs where normalization left an empty reference or prediction.
        correct_idxs = [
            i
            for i in range(len(normalized_labels))
            if len(normalized_labels[i]) > 0 and len(normalized_predictions[i]) > 0
        ]
        normalized_predictions = [normalized_predictions[i] for i in correct_idxs]
        normalized_labels = [normalized_labels[i] for i in correct_idxs]

        wer = wer_metric.compute(
            references=normalized_labels, predictions=normalized_predictions
        )
        wer = round(100 * wer, 2)
        rtfx = round(sum(durations) / sum(times), 2)
        print(f"{dataset_name}: WER: {wer}%, RTFx: {rtfx}")

        run_prefix = f"{dataset_name.replace('/', '_')}_{dataset_config}_{dataset_split}"
        wandb.log({
            f"{run_prefix}/wer": wer,
            f"{run_prefix}/rtfx": rtfx,
        })
        table = wandb.Table(
            columns=["Prediction", "Target"],
            data=[[predictions[i], labels[i]] for i in correct_idxs],
        )
        wandb.log({f"{run_prefix}/predictions": table})

    def benchmark_model(
        self,
        dataset_names,
        dataset_configs,
        dataset_splits,
        audio_columns,
        label_columns,
        short_form,
    ):
        for dataset_name, dataset_config, dataset_split, audio_column, label_column in zip(
            dataset_names, dataset_configs, dataset_splits, audio_columns, label_columns
        ):
            self.benchmark_dataset(dataset_name, dataset_config, dataset_split, audio_column, label_column, short_form)

def main():
    parser = argparse.ArgumentParser(description="Benchmark Whisper models using Transformers")
    parser.add_argument('--model_id', type=str, required=True, help='Hugging Face model ID')
    parser.add_argument('--dataset_names', type=str, nargs='+', required=True, help='Names of the datasets to benchmark')
    parser.add_argument('--dataset_configs', type=str, nargs='+', required=True, help='Configs of the datasets to benchmark')
    parser.add_argument('--dataset_splits', type=str, nargs='+', required=True, help='Splits of the datasets to benchmark')
    parser.add_argument('--audio_columns', type=str, nargs='+', required=True, help='Audio columns of the datasets')
    parser.add_argument('--label_columns', type=str, nargs='+', required=True, help='Label columns of the datasets')
    parser.add_argument('--dtype', type=str, choices=['float16', 'bfloat16', 'float32'], default='float32', help='Data type for the model')
    parser.add_argument('--attn_implementation', type=str, choices=['eager', 'sdpa', 'flash_attention_2'], default='sdpa', help='Attention implementation to use')
    parser.add_argument('--short_form', action='store_true', help='Benchmark only samples of at most 30 seconds')
    parser.add_argument('--wandb_id', type=str, help='Weights & Biases run ID')
    args = parser.parse_args()

    # Record the exact Transformers commit being benchmarked, for reproducibility.
    latest_commit, remote_url = get_latest_commit_and_remote(os.path.dirname(TRANSFORMERS_SRC_PATH))
    config = {
        "commit": latest_commit,
        "remote_url": remote_url,
        **vars(args),
    }
    wandb.init(project="benchmark-whisper-short", name=f"transformer-{args.model_id}", config=config, id=args.wandb_id)

    # Convert the dtype string to a torch dtype.
    dtype = getattr(torch, args.dtype)

    # Initialize and run the benchmark.
    benchmark = Benchmark(
        model_id=args.model_id,
        dtype=dtype,
        attn_implementation=args.attn_implementation,
        short_form=args.short_form,
    )
    benchmark.benchmark_model(
        dataset_names=args.dataset_names,
        dataset_configs=args.dataset_configs,
        dataset_splits=args.dataset_splits,
        audio_columns=args.audio_columns,
        label_columns=args.label_columns,
        short_form=args.short_form,
    )


if __name__ == "__main__":
    main()
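The two SLURM scripts below submit the benchmark: the first covers the long-form datasets (samples longer than 30 s), the second the short-form subsets.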
#!/bin/bash
# Long-form benchmark: datasets whose samples are longer than 30 s.
#SBATCH --job-name=benchmark-transformers
#SBATCH --ntasks=1
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --cpus-per-task=48
#SBATCH --mem-per-cpu=11G
#SBATCH --gres=gpu:h100:1
#SBATCH --time=6:00:00
#SBATCH --requeue
#SBATCH --output=/admin/home/eustache_lebihan/dev/logs/%x-%j.out

python benchmark_transformers_whisper.py \
    --model_id openai/whisper-large-v3 \
    --dataset_names distil-whisper/tedlium-long-form distil-whisper/meanwhile \
    --dataset_configs default default \
    --dataset_splits test test \
    --audio_columns audio audio \
    --label_columns text text
#!/bin/bash
# Short-form benchmark: samples of at most 30 s, over the first 100 examples of each test split.
#SBATCH --job-name=benchmark-transformers
#SBATCH --ntasks=1
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --cpus-per-task=48
#SBATCH --mem-per-cpu=11G
#SBATCH --gres=gpu:h100:1
#SBATCH --time=1:00:00
#SBATCH --requeue
#SBATCH --output=/admin/home/eustache_lebihan/dev/logs/%x-%j.out

python benchmark_transformers_whisper.py \
    --model_id openai/whisper-large-v3 \
    --dataset_names edinburghcstr/ami distil-whisper/chime4 google/fleurs \
    --dataset_configs ihm 1-channel en_us \
    --dataset_splits test[:100] test[:100] test[:100] \
    --audio_columns audio audio audio \
    --label_columns text text transcription \
    --short_form
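To reproduce, save each script to its own file and submit it with sbatch; the file names here are illustrative, not from the original gist:

sbatch benchmark_whisper_long_form.slurm
sbatch benchmark_whisper_short_form.slurm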