Benchmark WER and RTFx for the Transformers implementation of Whisper.
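The script reports two metrics per dataset: WER (word error rate), computed with the evaluate library after applying OpenAI's English text normalizer to both references and predictions, and RTFx (inverse real-time factor), the total audio duration divided by the total inference time, so higher means faster than real time. A minimal sketch of both computations, with toy strings and made-up timings:

import evaluate

wer_metric = evaluate.load("wer")
# One substitution over a six-word reference -> WER ~= 0.167
wer = wer_metric.compute(
    references=["the cat sat on the mat"],
    predictions=["the cat sat on a mat"],
)
# 100 s of audio transcribed in 5 s -> RTFx = 20
rtfx = 100.0 / 5.0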
TRANSFORMERS_SRC_PATH = "/admin/home/eustache_lebihan/dev/benchmark-whisper/transformers-fix/src"

import sys

sys.path.insert(0, TRANSFORMERS_SRC_PATH)

import argparse
import os

import evaluate
import torch
import wandb
from datasets import load_dataset
from git import GitCommandError, Repo
from tqdm import tqdm
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer

wer_metric = evaluate.load("wer")
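# The benchmark runs against a local transformers checkout (prepended to
# sys.path above); the helper below records that checkout's commit hash and
# remote URL so each W&B run can be traced back to the exact source measured.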
def get_latest_commit_and_remote(directory_path):
    try:
        repo = Repo(directory_path, search_parent_directories=True)
        commits = list(repo.iter_commits(paths=directory_path, max_count=1))
        latest_commit = commits[0].hexsha if commits else None
        remote_url = repo.remotes.origin.url if repo.remotes else None
        return latest_commit, remote_url
    except GitCommandError as e:
        print(f"Error: {e}")
        return None, None
class Benchmark:
    def __init__(
        self,
        model_id,
        dtype,
        attn_implementation,
        torch_device="cuda:0",
        short_form=False,
        gen_kwargs=None,  # None sentinel instead of a mutable default argument
    ):
        if gen_kwargs is None:
            gen_kwargs = {
                "language": "en",
                "task": "transcribe",
                "condition_on_prev_tokens": False,  # follow OpenAI default
                "return_timestamps": True,  # follow OpenAI default
                "no_speech_threshold": 0.6,
                "compression_ratio_threshold": 1.35,
                "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # follow OpenAI default
                "logprob_threshold": -1.0,  # follow OpenAI default
                "max_new_tokens": 445,  # cap generation length to avoid crashes in the current implementation
            }
        self.model = WhisperForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=dtype,
            attn_implementation=attn_implementation,
        )
        self.torch_device = torch_device
        self.dtype = dtype
        self.model.to(torch_device, dtype=dtype)
        self.processor = WhisperProcessor.from_pretrained(model_id)
        self.normalizer = EnglishTextNormalizer(
            self.processor.tokenizer.english_spelling_normalizer
        )
        self.short_form = short_form
        self.gen_kwargs = gen_kwargs
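    # generate() launches GPU work asynchronously, so the measurement below
    # brackets it with CUDA events and explicit synchronization; feature
    # extraction and decoding are deliberately left out of the timed region.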
    def infer(self, audio_array):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        if self.short_form:
            inputs = self.processor(audio_array, return_tensors="pt", sampling_rate=16000)
        else:
            # Long-form: pad to the full audio length instead of truncating to 30 s.
            inputs = self.processor(
                audio_array,
                return_tensors="pt",
                sampling_rate=16000,
                padding="longest",
                truncation=False,
            )
        # Cast input features to the model dtype so fp16/bf16 runs don't hit a dtype mismatch.
        inputs = inputs.to(self.torch_device, dtype=self.dtype)
        torch.cuda.synchronize()
        start_event.record()
        outputs = self.model.generate(
            inputs.input_features,
            **self.gen_kwargs,
        )
        end_event.record()
        torch.cuda.synchronize()
        transcript = self.processor.decode(outputs[0])
        elapsed_time = start_event.elapsed_time(end_event) / 1000  # elapsed_time() returns milliseconds
        return transcript, elapsed_time
    def benchmark_dataset(
        self, dataset_name, dataset_config, dataset_split, audio_column, label_column, short_form
    ):
        times = []
        predictions = []
        labels = []
        durations = []
        dataset = load_dataset(dataset_name, dataset_config, split=dataset_split, trust_remote_code=True)
        # Short-form keeps clips of at most 30 s (a single Whisper window);
        # long-form keeps only clips longer than 30 s.
        if short_form:
            dataset = dataset.filter(lambda x: len(x[audio_column]["array"]) / x[audio_column]["sampling_rate"] <= 30)
        else:
            dataset = dataset.filter(lambda x: len(x[audio_column]["array"]) / x[audio_column]["sampling_rate"] > 30)
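        # Batch size 1: each sample is transcribed on its own, and only the
        # generate() call inside infer() contributes to the timing totals.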
        for sample in tqdm(dataset, desc=f"Benchmarking {dataset_name}"):
            transcript, inference_time = self.infer(sample[audio_column]["array"])
            duration = sample[audio_column]["array"].shape[0] / sample[audio_column]["sampling_rate"]
            durations.append(duration)
            times.append(inference_time)
            predictions.append(transcript)
            labels.append(sample[label_column])
        normalized_predictions = [self.normalizer(pred) for pred in predictions]
        normalized_labels = [self.normalizer(label) for label in labels]
        # Drop pairs where normalization leaves an empty string; WER is not
        # well defined against an empty reference.
        correct_idxs = [
            i
            for i in range(len(normalized_labels))
            if len(normalized_labels[i]) > 0 and len(normalized_predictions[i]) > 0
        ]
        normalized_predictions = [normalized_predictions[i] for i in correct_idxs]
        normalized_labels = [normalized_labels[i] for i in correct_idxs]
        wer = wer_metric.compute(
            references=normalized_labels, predictions=normalized_predictions
        )
        wer = round(100 * wer, 2)
        rtfx = round(sum(durations) / sum(times), 2)
        print(f"{dataset_name}: WER: {wer}%, RTFx: {rtfx}")
        run_prefix = f"{dataset_name.replace('/', '_')}_{dataset_config}_{dataset_split}"
        wandb.log({
            f"{run_prefix}/wer": wer,
            f"{run_prefix}/rtfx": rtfx,
        })
        table = wandb.Table(
            columns=["Prediction", "Target"],
            data=[[predictions[i], labels[i]] for i in correct_idxs],
        )
        wandb.log({f"{run_prefix}/predictions": table})
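    # The per-dataset argument lists are zipped positionally, so they must all
    # have the same length; see the CLI arguments in main() below.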
    def benchmark_model(
        self,
        dataset_names,
        dataset_configs,
        dataset_splits,
        audio_columns,
        label_columns,
        short_form,
    ):
        for dataset_name, dataset_config, dataset_split, audio_column, label_column in zip(
            dataset_names, dataset_configs, dataset_splits, audio_columns, label_columns
        ):
            self.benchmark_dataset(
                dataset_name, dataset_config, dataset_split, audio_column, label_column, short_form
            )
def main():
    parser = argparse.ArgumentParser(description="Benchmark Whisper models using Transformers")
    parser.add_argument('--model_id', type=str, required=True, help='Hugging Face model ID')
    parser.add_argument('--dataset_names', type=str, nargs='+', required=True, help='Names of the datasets to benchmark')
    parser.add_argument('--dataset_configs', type=str, nargs='+', required=True, help='Configs of the datasets to benchmark')
    parser.add_argument('--dataset_splits', type=str, nargs='+', required=True, help='Splits of the datasets to benchmark')
    parser.add_argument('--audio_columns', type=str, nargs='+', required=True, help='Audio columns of the datasets')
    parser.add_argument('--label_columns', type=str, nargs='+', required=True, help='Label columns of the datasets')
    parser.add_argument('--dtype', type=str, choices=['float16', 'bfloat16', 'float32'], default='float32', help='Data type for the model')
    parser.add_argument('--attn_implementation', type=str, choices=['eager', 'sdpa', 'flash_attention_2'], default='sdpa', help='Attention implementation to use')
    parser.add_argument('--short_form', action='store_true', help='Benchmark only clips of at most 30 s (otherwise only clips longer than 30 s)')
    parser.add_argument('--wandb_id', type=str, help='Weights & Biases run ID')
    args = parser.parse_args()

    latest_commit, remote_url = get_latest_commit_and_remote(os.path.dirname(TRANSFORMERS_SRC_PATH))
    config = {
        "commit": latest_commit,
        "remote_url": remote_url,
        **vars(args),
    }
    wandb.init(project="benchmark-whisper-short", name=f"transformer-{args.model_id}", config=config, id=args.wandb_id)

    # Convert the dtype string to a torch dtype
    dtype = getattr(torch, args.dtype)

    # Initialize the benchmark
    benchmark = Benchmark(
        model_id=args.model_id,
        dtype=dtype,
        attn_implementation=args.attn_implementation,
        short_form=args.short_form,
    )

    # Run the benchmark
    benchmark.benchmark_model(
        dataset_names=args.dataset_names,
        dataset_configs=args.dataset_configs,
        dataset_splits=args.dataset_splits,
        audio_columns=args.audio_columns,
        label_columns=args.label_columns,
        short_form=args.short_form,
    )


if __name__ == "__main__":
    main()
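The two SLURM scripts below drive the benchmark on a single H100. The first covers the long-form datasets (clips longer than 30 s); the second passes --short_form and restricts each test split to its first 100 samples. Submit either one with sbatch.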
#!/bin/bash
#SBATCH --job-name=benchmark-transformers
#SBATCH --ntasks=1
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --cpus-per-task=48
#SBATCH --mem-per-cpu=11G
#SBATCH --gres=gpu:h100:1
#SBATCH --time=6:00:00
#SBATCH --requeue
#SBATCH --output=/admin/home/eustache_lebihan/dev/logs/%x-%j.out

python benchmark_transformers_whisper.py \
    --model_id openai/whisper-large-v3 \
    --dataset_names distil-whisper/tedlium-long-form distil-whisper/meanwhile \
    --dataset_configs default default \
    --dataset_splits test test \
    --audio_columns audio audio \
    --label_columns text text
#!/bin/bash
#SBATCH --job-name=benchmark-transformers
#SBATCH --ntasks=1
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --cpus-per-task=48
#SBATCH --mem-per-cpu=11G
#SBATCH --gres=gpu:h100:1
#SBATCH --time=1:00:00
#SBATCH --requeue
#SBATCH --output=/admin/home/eustache_lebihan/dev/logs/%x-%j.out

# test[:100] keeps only the first 100 samples of each split for a quick run.
python benchmark_transformers_whisper.py \
    --model_id openai/whisper-large-v3 \
    --dataset_names edinburghcstr/ami distil-whisper/chime4 google/fleurs \
    --dataset_configs ihm 1-channel en_us \
    --dataset_splits test[:100] test[:100] test[:100] \
    --audio_columns audio audio audio \
    --label_columns text text transcription \
    --short_form