@litagin02
Last active March 13, 2024 12:07

from pathlib import Path
from typing import Any, Optional

from torch.utils.data import Dataset
from tqdm import tqdm


# Dataset class needed so the HF pipeline can be iterated lazily and report progress
class StrListDataset(Dataset[str]):
    def __init__(self, original_list: list[str]) -> None:
        self.original_list = original_list

    def __len__(self) -> int:
        return len(self.original_list)

    def __getitem__(self, i: int) -> str:
        return self.original_list[i]


def transcribe_files_with_hf_whisper(
    audio_files: list[Path],
    model_id: str,
    initial_prompt: Optional[str] = None,
    language: str = "ja",
    batch_size: int = 16,
    num_beams: int = 1,
    device: str = "cuda",
    pbar: Optional[tqdm] = None,
) -> list[str]:
    import torch
    from transformers import WhisperProcessor, pipeline

    processor: WhisperProcessor = WhisperProcessor.from_pretrained(model_id)
    generate_kwargs: dict[str, Any] = {
        "language": language,
        "do_sample": False,
        "num_beams": num_beams,
    }
    if initial_prompt is not None:
        prompt_ids: torch.Tensor = processor.get_prompt_ids(
            initial_prompt, return_tensors="pt"
        )
        prompt_ids = prompt_ids.to(device)
        generate_kwargs["prompt_ids"] = prompt_ids
    pipe = pipeline(
        model=model_id,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=batch_size,
        torch_dtype=torch.float16,
        device=device,
        generate_kwargs=generate_kwargs,
    )
    dataset = StrListDataset([str(f) for f in audio_files])
    results: list[str] = []
    for whisper_result in pipe(dataset):
        text: str = whisper_result["text"]
        # For some reason " {initial_prompt}" is prepended to the output text,
        # so strip it from the beginning
        # cf. https://github.com/huggingface/transformers/issues/27594
        if initial_prompt is not None and text.startswith(f" {initial_prompt}"):
            text = text[len(f" {initial_prompt}") :]
        results.append(text)
        if pbar is not None:
            pbar.update(1)
    if pbar is not None:
        pbar.close()
    return results
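

A minimal usage sketch, not part of the original gist: the wavs/ input directory is hypothetical and openai/whisper-large-v3 is just one possible model_id. The tqdm bar passed as pbar is advanced once per file inside the transcription loop above.

if __name__ == "__main__":
    # Hypothetical example: transcribe every .wav file in a local "wavs" directory
    audio_files = sorted(Path("wavs").glob("*.wav"))
    with tqdm(total=len(audio_files), desc="Transcribing") as progress:
        texts = transcribe_files_with_hf_whisper(
            audio_files=audio_files,
            model_id="openai/whisper-large-v3",
            language="ja",
            batch_size=16,
            pbar=progress,
        )
    for path, text in zip(audio_files, texts):
        print(path.name, text)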