Benchmark WER and RTFx for OpenAI Whisper.
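For reference: WER (word error rate) is the word-level edit distance between a normalized prediction and its reference, computed below with the Hugging Face evaluate library; RTFx (inverse real-time factor) is total audio duration divided by total inference time, so higher means faster than real time. A minimal sketch of both metrics with illustrative toy values (the strings and durations are made up, not benchmark outputs):

# Sketch of the two reported metrics, matching the definitions used in the script.
import evaluate

wer_metric = evaluate.load("wer")

# WER: 1 substitution ("the" -> "a") over 6 reference words ~= 0.167
wer = wer_metric.compute(
    references=["the cat sat on the mat"],
    predictions=["the cat sat on a mat"],
)

# RTFx: seconds of audio transcribed per second of inference time.
audio_seconds, inference_seconds = 120.0, 8.0
rtfx = audio_seconds / inference_seconds  # 15.0 -> 15x faster than real time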
OPENAI_SRC_PATH = "/admin/home/eustache_lebihan/dev/benchmark-whisper/whisper"

import sys

sys.path.insert(0, OPENAI_SRC_PATH)

import argparse

import evaluate
import torch
import wandb
from datasets import load_dataset
from git import GitCommandError, Repo
from tqdm import tqdm
from transformers import WhisperProcessor
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer

import whisper  # resolved from OPENAI_SRC_PATH inserted above, not the pip package

wer_metric = evaluate.load("wer")
def get_latest_commit_and_remote(directory_path):
    """Return the latest commit hash touching directory_path and its origin remote URL."""
    try:
        repo = Repo(directory_path, search_parent_directories=True)
        commits = list(repo.iter_commits(paths=directory_path, max_count=1))
        latest_commit = commits[0].hexsha if commits else None
        remote_url = repo.remotes.origin.url if repo.remotes else None
        return latest_commit, remote_url
    except GitCommandError as e:
        print(f"Error: {e}")
        return None, None
class Benchmark:
    # Decoding defaults for whisper.transcribe; the very loose logprob and
    # compression-ratio thresholds effectively disable fallback re-decoding.
    DEFAULT_GEN_KWARGS = {
        "language": "en",
        "task": "transcribe",
        "fp16": False,
        "condition_on_previous_text": True,
        "no_speech_threshold": 0.6,
        "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        "logprob_threshold": -1000,
        "compression_ratio_threshold": 1000,
        "sample_len": 224,  # match OAI
    }

    def __init__(self, model_id, torch_device="cuda:0", short_form=False, gen_kwargs=None):
        self.model = whisper.load_model(model_id)
        self.torch_device = torch_device
        self.model.to(torch_device)
        # The Transformers processor is only used for its English text normalizer.
        self.processor = WhisperProcessor.from_pretrained(f"openai/whisper-{model_id}")
        self.normalizer = EnglishTextNormalizer(
            self.processor.tokenizer.english_spelling_normalizer
        )
        self.short_form = short_form
        # Avoid a mutable default argument; fall back to the class-level defaults.
        self.gen_kwargs = gen_kwargs if gen_kwargs is not None else dict(self.DEFAULT_GEN_KWARGS)
    def infer(self, audio_array):
        # Time the transcription with CUDA events so GPU work is fully captured.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start_event.record()
        outputs = self.model.transcribe(
            audio_array.astype("float32"),
            **self.gen_kwargs,
        )
        end_event.record()
        torch.cuda.synchronize()
        transcript = outputs["text"]
        elapsed_time = start_event.elapsed_time(end_event) / 1000  # ms -> s
        return transcript, elapsed_time
    def benchmark_dataset(
        self, dataset_name, dataset_config, dataset_split, audio_column, label_column, short_form
    ):
        times = []
        predictions = []
        labels = []
        durations = []
        dataset = load_dataset(dataset_name, dataset_config, split=dataset_split, trust_remote_code=True)
        # Keep only samples of at most 30 s (short form) or strictly longer (long form).
        if short_form:
            dataset = dataset.filter(
                lambda x: len(x["audio"]["array"]) / x["audio"]["sampling_rate"] <= 30
            )
        else:
            dataset = dataset.filter(
                lambda x: len(x["audio"]["array"]) / x["audio"]["sampling_rate"] > 30
            )
        for sample in tqdm(dataset, desc=f"Benchmarking {dataset_name}"):
            transcript, inference_time = self.infer(sample[audio_column]["array"])
            duration = sample[audio_column]["array"].shape[0] / sample[audio_column]["sampling_rate"]
            durations.append(duration)
            times.append(inference_time)
            predictions.append(transcript)
            labels.append(sample[label_column])
        normalized_predictions = [self.normalizer(pred) for pred in predictions]
        normalized_labels = [self.normalizer(label) for label in labels]
        # Drop pairs where normalization yields an empty string; WER is undefined there.
        correct_idxs = [
            i
            for i in range(len(normalized_labels))
            if len(normalized_labels[i]) > 0 and len(normalized_predictions[i]) > 0
        ]
        normalized_predictions = [normalized_predictions[i] for i in correct_idxs]
        normalized_labels = [normalized_labels[i] for i in correct_idxs]
        try:
            wer = wer_metric.compute(
                references=normalized_labels, predictions=normalized_predictions
            )
        except Exception as e:
            print(f"Error computing WER: {e}")
            return
        wer = round(100 * wer, 2)
        # RTFx: seconds of audio processed per second of inference time.
        rtfx = round(sum(durations) / sum(times), 2)
        print(f"{dataset_name}: WER: {wer}%, RTFx: {rtfx}")
        wandb.log({
            f"{dataset_name.replace('/', '_')}_{dataset_config}_{dataset_split}/wer": wer,
            f"{dataset_name.replace('/', '_')}_{dataset_config}_{dataset_split}/rtfx": rtfx,
        })
        table = wandb.Table(
            columns=["Prediction", "Target"],
            data=[[predictions[i], labels[i]] for i in correct_idxs],
        )
        wandb.log({
            f"{dataset_name.replace('/', '_')}_{dataset_config}_{dataset_split}/predictions": table
        })
    def benchmark_model(
        self,
        dataset_names,
        dataset_configs,
        dataset_splits,
        audio_columns,
        label_columns,
        short_form,
    ):
        for dataset_name, dataset_config, dataset_split, audio_column, label_column in zip(
            dataset_names, dataset_configs, dataset_splits, audio_columns, label_columns
        ):
            self.benchmark_dataset(
                dataset_name, dataset_config, dataset_split, audio_column, label_column, short_form
            )
def main():
    parser = argparse.ArgumentParser(description="Benchmark OpenAI Whisper (reference implementation)")
    parser.add_argument('--model_id', type=str, required=True, help='OpenAI Whisper model name, e.g. large-v3')
    parser.add_argument('--dataset_names', type=str, nargs='+', required=True, help='Names of the datasets to benchmark')
    parser.add_argument('--dataset_configs', type=str, nargs='+', required=True, help='Configs of the datasets to benchmark')
    parser.add_argument('--dataset_splits', type=str, nargs='+', required=True, help='Splits of the datasets to benchmark')
    parser.add_argument('--audio_columns', type=str, nargs='+', required=True, help='Audio columns of the datasets')
    parser.add_argument('--label_columns', type=str, nargs='+', required=True, help='Label columns of the datasets')
    parser.add_argument('--short_form', action='store_true', help='Benchmark only samples of at most 30 seconds')
    parser.add_argument('--wandb_id', type=str, help='Weights & Biases run ID')
    args = parser.parse_args()
    # Record the exact whisper commit under benchmark for reproducibility.
    latest_commit, remote_url = get_latest_commit_and_remote(OPENAI_SRC_PATH)
    config = {
        "commit": latest_commit,
        "remote_url": remote_url,
        **vars(args),
    }
    wandb.init(project="benchmark-whisper-short", name=f"openai-{args.model_id}", config=config, id=args.wandb_id)
    # Initialize benchmark
    benchmark = Benchmark(
        model_id=args.model_id,
        short_form=args.short_form,
    )
    # Run benchmark
    benchmark.benchmark_model(
        dataset_names=args.dataset_names,
        dataset_configs=args.dataset_configs,
        dataset_splits=args.dataset_splits,
        audio_columns=args.audio_columns,
        label_columns=args.label_columns,
        short_form=args.short_form,
    )


if __name__ == "__main__":
    main()
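For a quick local sanity check before submitting the Slurm jobs below, the Benchmark class can also be driven directly. A minimal sketch, assuming the file is saved as benchmark_openai_whisper.py (the name the Slurm scripts invoke) and a CUDA GPU is available; wandb runs in disabled mode so nothing is logged remotely:

# Hypothetical smoke test; dataset/column names taken from the short-form Slurm script.
import wandb
from benchmark_openai_whisper import Benchmark

# mode="disabled" turns wandb.log calls into no-ops.
wandb.init(project="benchmark-whisper-short", name="openai-tiny-smoke", mode="disabled")

bench = Benchmark(model_id="tiny")  # "tiny" keeps download and runtime small
bench.benchmark_dataset(
    dataset_name="google/fleurs",
    dataset_config="en_us",
    dataset_split="test",
    audio_column="audio",
    label_column="transcription",
    short_form=True,
)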
#!/bin/bash
#SBATCH --job-name=benchmark-openai
#SBATCH --ntasks=1
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --cpus-per-task=48
#SBATCH --mem-per-cpu=11G
#SBATCH --gres=gpu:h100:1
#SBATCH --time=1:00:00
#SBATCH --requeue
#SBATCH --output=/admin/home/eustache_lebihan/dev/logs/%x-%j.out

# Long-form benchmark: no --short_form flag, so only samples longer than 30 s are kept.
python benchmark_openai_whisper.py \
    --model_id large-v3 \
    --dataset_names distil-whisper/tedlium-long-form distil-whisper/meanwhile \
    --dataset_configs default default \
    --dataset_splits test test \
    --audio_columns audio audio \
    --label_columns text text
#!/bin/bash
#SBATCH --job-name=benchmark-openai
#SBATCH --ntasks=1
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --cpus-per-task=48
#SBATCH --mem-per-cpu=11G
#SBATCH --gres=gpu:h100:1
#SBATCH --time=4:00:00
#SBATCH --requeue
#SBATCH --output=/admin/home/eustache_lebihan/dev/logs/%x-%j.out

# Short-form benchmark: --short_form keeps only samples of at most 30 s.
python benchmark_openai_whisper.py \
    --model_id large-v3 \
    --dataset_names edinburghcstr/ami distil-whisper/chime4 google/fleurs \
    --dataset_configs ihm 1-channel en_us \
    --dataset_splits test test test \
    --audio_columns audio audio audio \
    --label_columns text text transcription \
    --short_form