Benchmark WER and RTFx for OpenAI Whisper.
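For reference, here is a minimal sketch of how the two reported metrics are computed: WER via the Hugging Face `evaluate` package (the same metric the script below loads), and RTFx as total audio duration divided by total inference time, so higher means faster than real time. All values in this sketch are made-up examples, not benchmark results.

# Illustrative sketch only -- not part of the benchmark script below.
import evaluate

wer_metric = evaluate.load("wer")

predictions = ["the cat sat on the mat"]   # made-up example transcript
references = ["the cat sat on a mat"]      # made-up example reference
wer = 100 * wer_metric.compute(references=references, predictions=predictions)

durations = [30.0, 25.5]  # seconds of audio per sample (made-up)
times = [1.2, 1.0]        # seconds of inference per sample (made-up)
rtfx = sum(durations) / sum(times)  # total audio duration / total inference time

print(f"WER: {wer:.2f}%, RTFx: {rtfx:.2f}")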
OPENAI_SRC_PATH = "/admin/home/eustache_lebihan/dev/benchmark-whisper/whisper"

import sys

# Import whisper from the local openai/whisper checkout rather than any installed package.
sys.path.insert(0, OPENAI_SRC_PATH)

import argparse
import os

import evaluate
import torch
import wandb
from datasets import load_dataset
from git import Repo, GitCommandError
from tqdm import tqdm
from transformers import WhisperProcessor
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer

import whisper

wer_metric = evaluate.load("wer")
def get_latest_commit_and_remote(directory_path):
    """Return the latest commit hash and origin remote URL for the given repository path."""
    try:
        repo = Repo(directory_path, search_parent_directories=True)
        commits = list(repo.iter_commits(paths=directory_path, max_count=1))
        latest_commit = commits[0].hexsha if commits else None
        remote_url = repo.remotes.origin.url if repo.remotes else None
        return latest_commit, remote_url
    except GitCommandError as e:
        print(f"Error: {e}")
        return None, None
class Benchmark:
    def __init__(
        self,
        model_id,
        torch_device="cuda:0",
        short_form=False,
        gen_kwargs={
            "language": "en",
            "task": "transcribe",
            "fp16": False,
            "condition_on_previous_text": True,
            "no_speech_threshold": 0.6,
            "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            "logprob_threshold": -1000,
            "compression_ratio_threshold": 1000,
            "sample_len": 224,  # match OAI
        },
    ):
        self.model = whisper.load_model(model_id)
        self.torch_device = torch_device
        self.model.to(torch_device)
        self.processor = WhisperProcessor.from_pretrained(f"openai/whisper-{model_id}")
        self.normalizer = EnglishTextNormalizer(
            self.processor.tokenizer.english_spelling_normalizer
        )
        self.short_form = short_form
        self.gen_kwargs = gen_kwargs
    def infer(self, audio_array):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start_event.record()
        outputs = self.model.transcribe(
            audio_array.astype("float32"),
            **self.gen_kwargs,
        )
        end_event.record()
        torch.cuda.synchronize()
        transcript = outputs["text"]
        elapsed_time = start_event.elapsed_time(end_event) / 1000  # convert ms to s
        return transcript, elapsed_time
    def benchmark_dataset(
        self, dataset_name, dataset_config, dataset_split, audio_column, label_column, short_form
    ):
        times = []
        predictions = []
        labels = []
        durations = []

        dataset = load_dataset(dataset_name, dataset_config, split=dataset_split, trust_remote_code=True)

        # Keep only clips <= 30 s for the short-form benchmark, > 30 s otherwise.
        if short_form:
            dataset = dataset.filter(lambda x: len(x["audio"]["array"]) / x["audio"]["sampling_rate"] <= 30)
        else:
            dataset = dataset.filter(lambda x: len(x["audio"]["array"]) / x["audio"]["sampling_rate"] > 30)

        for sample in tqdm(dataset, desc=f"Benchmarking {dataset_name}"):
            transcript, inference_time = self.infer(sample[audio_column]["array"])
            duration = sample[audio_column]["array"].shape[0] / sample[audio_column]["sampling_rate"]
            durations.append(duration)
            times.append(inference_time)
            predictions.append(transcript)
            labels.append(sample[label_column])

        normalized_predictions = [self.normalizer(pred) for pred in predictions]
        normalized_labels = [self.normalizer(label) for label in labels]

        # Drop samples whose normalized prediction or reference is empty.
        correct_idxs = [
            i
            for i in range(len(normalized_labels))
            if len(normalized_labels[i]) > 0 and len(normalized_predictions[i]) > 0
        ]
        normalized_predictions = [normalized_predictions[i] for i in correct_idxs]
        normalized_labels = [normalized_labels[i] for i in correct_idxs]

        try:
            wer = wer_metric.compute(
                references=normalized_labels, predictions=normalized_predictions
            )
        except Exception as e:
            print(f"Error computing WER: {e}")
            return

        wer = round(100 * wer, 2)
        rtfx = round(sum(durations) / sum(times), 2)  # total audio duration / total inference time
        print(f"{dataset_name}: WER: {wer}%, RTFx: {rtfx}")

        wandb.log({
            f"{dataset_name.replace('/', '_')}_{dataset_config}_{dataset_split}/wer": wer,
            f"{dataset_name.replace('/', '_')}_{dataset_config}_{dataset_split}/rtfx": rtfx,
        })
        table = wandb.Table(
            columns=["Prediction", "Target"],
            data=[[predictions[i], labels[i]] for i in correct_idxs],
        )
        wandb.log({
            f"{dataset_name.replace('/', '_')}_{dataset_config}_{dataset_split}/predictions": table
        })
    def benchmark_model(
        self,
        dataset_names,
        dataset_configs,
        dataset_splits,
        audio_columns,
        label_columns,
        short_form,
    ):
        for dataset_name, dataset_config, dataset_split, audio_column, label_column in zip(
            dataset_names, dataset_configs, dataset_splits, audio_columns, label_columns
        ):
            self.benchmark_dataset(
                dataset_name, dataset_config, dataset_split, audio_column, label_column, short_form
            )
def main():
    parser = argparse.ArgumentParser(description="Benchmark openai-whisper models (WER and RTFx)")
    parser.add_argument('--model_id', type=str, required=True, help='Whisper checkpoint name (e.g. large-v3)')
    parser.add_argument('--dataset_names', type=str, nargs='+', required=True, help='Names of the datasets to benchmark')
    parser.add_argument('--dataset_configs', type=str, nargs='+', required=True, help='Configs of the datasets to benchmark')
    parser.add_argument('--dataset_splits', type=str, nargs='+', required=True, help='Splits of the datasets to benchmark')
    parser.add_argument('--audio_columns', type=str, nargs='+', required=True, help='Audio columns of the datasets')
    parser.add_argument('--label_columns', type=str, nargs='+', required=True, help='Label columns of the datasets')
    parser.add_argument('--short_form', action='store_true', help='Benchmark short-form (<= 30 s) samples instead of long-form ones')
    parser.add_argument('--wandb_id', type=str, help='Weights & Biases run ID')
    args = parser.parse_args()

    latest_commit, remote_url = get_latest_commit_and_remote(OPENAI_SRC_PATH)
    config = {
        "commit": latest_commit,
        "remote_url": remote_url,
        **args.__dict__,
    }
    wandb.init(project="benchmark-whisper-short", name=f"openai-{args.model_id}", config=config, id=args.wandb_id)

    # Initialize benchmark
    benchmark = Benchmark(
        model_id=args.model_id,
        short_form=args.short_form,
    )

    # Run benchmark
    benchmark.benchmark_model(
        dataset_names=args.dataset_names,
        dataset_configs=args.dataset_configs,
        dataset_splits=args.dataset_splits,
        audio_columns=args.audio_columns,
        label_columns=args.label_columns,
        short_form=args.short_form,
    )


if __name__ == "__main__":
    main()
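For a quick run without SLURM, the Benchmark class above can also be driven directly from Python. This is a hypothetical sketch (the run name is made up; the dataset and column names mirror the short-form SLURM script below), assuming the definitions above are in scope, a CUDA device is available, and you are logged in to wandb:

# Hypothetical interactive usage of the Benchmark class defined above.
# benchmark_dataset logs metrics and a predictions table to wandb,
# so a run must be initialised first.
import wandb

wandb.init(project="benchmark-whisper-short", name="openai-large-v3-interactive")

bench = Benchmark(model_id="large-v3", short_form=True)
bench.benchmark_dataset(
    dataset_name="google/fleurs",  # example dataset from the short-form SLURM script
    dataset_config="en_us",
    dataset_split="test",
    audio_column="audio",
    label_column="transcription",
    short_form=True,
)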
#!/bin/bash
#SBATCH --job-name=benchmark-openai
#SBATCH --ntasks=1
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --cpus-per-task=48
#SBATCH --mem-per-cpu=11G
#SBATCH --gres=gpu:h100:1
#SBATCH --time=1:00:00
#SBATCH --requeue
#SBATCH --output=/admin/home/eustache_lebihan/dev/logs/%x-%j.out
python benchmark_openai_whisper.py \
--model_id large-v3 \
--dataset_names distil-whisper/tedlium-long-form distil-whisper/meanwhile \
--dataset_configs default default \
--dataset_splits test test \
--audio_columns audio audio \
--label_columns text text
#!/bin/bash
#SBATCH --job-name=benchmark-openai
#SBATCH --ntasks=1
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --cpus-per-task=48
#SBATCH --mem-per-cpu=11G
#SBATCH --gres=gpu:h100:1
#SBATCH --time=4:00:00
#SBATCH --requeue
#SBATCH --output=/admin/home/eustache_lebihan/dev/logs/%x-%j.out
python benchmark_openai_whisper.py \
--model_id large-v3 \
--dataset_names edinburghcstr/ami distil-whisper/chime4 google/fleurs \
--dataset_configs ihm 1-channel en_us \
--dataset_splits test test test \
--audio_columns audio audio audio \
--label_columns text text transcription \
--short_form