Last active
May 27, 2023 15:36
-
-
Save dodysw/e0af41d990def8bbeb8934eeabd77688 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script adapts the command-line interface code from https://github.com/openai/whisper to run on https://github.com/guillaumekln/faster-whisper instead.
Initialize the models with these steps:
mkdir -p ~/.cache/faster-whisper | |
# assuming a python venv installed on venv dir.... | |
venv/bin/python -m pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/refs/heads/master.tar.gz" | |
venv/bin/ct2-transformers-converter --model openai/whisper-large-v2 --output_dir ~/.cache/faster-whisper/whisper-large-v2-ct2 --copy_files tokenizer.json --quantization float16 | |
venv/bin/ct2-transformers-converter --model openai/whisper-medium --output_dir ~/.cache/faster-whisper/whisper-medium-ct2 --copy_files tokenizer.json --quantization float16 | |
# then use it like you used whisper script: | |
venv/bin/python fasterw.py audio.wav <same whisper parameters> | |
""" | |
import argparse
import codecs
import dataclasses
import json
import logging
import os.path
import sys
import typing

import faster_whisper
import ffmpeg
import numpy as np
import torch
# Directory holding the converted CTranslate2 models when --model_dir is not given;
# cli() expands the "~" with os.path.expanduser.
default_model_dir = "~/.cache/faster-whisper"

# --model name -> sub-directory of the model dir that holds the converted model.
# The names match the --output_dir values of the conversion commands in the module docstring.
model_map = dict(
    medium="whisper-medium-ct2",
    large="whisper-large-v2-ct2",
)
# Whisper language code -> lower-case English language name.
# Used for the --language choices in cli() and inverted to build TO_LANGUAGE_CODE.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",  # the only three-letter code in the table
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}
# Reverse lookup: lower-case language name -> language code. Every name from LANGUAGES
# is included, plus a handful of alternative names (aliases) for the same languages.
TO_LANGUAGE_CODE = {name: code for code, name in LANGUAGES.items()}
TO_LANGUAGE_CODE.update(
    burmese="my",
    valencian="ca",
    flemish="nl",
    haitian="ht",
    letzeburgesch="lb",
    pushto="ps",
    panjabi="pa",
    moldavian="ro",
    moldovan="ro",
    sinhalese="si",
    castilian="es",
)
def available_models():
    """Return the model names accepted by --model, as a view of ``model_map``'s keys."""
    names = model_map.keys()
    return names
def optional_int(string):
    """argparse ``type=`` converter: an int, or None for the literal text "None"."""
    if string == "None":
        return None
    return int(string)
def optional_float(string):
    """argparse ``type=`` converter: a float, or None for the literal text "None"."""
    if string == "None":
        return None
    return float(string)
def str2bool(string):
    """argparse ``type=`` converter accepting exactly "True" or "False".

    Raises ValueError for any other text (case-sensitive), which argparse reports
    as an invalid argument value.
    """
    lookup = {"True": True, "False": False}
    if string not in lookup:
        raise ValueError(f"Expected one of {set(lookup)}, got {string}")
    return lookup[string]
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    """Render a non-negative duration in seconds as ``[HH:]MM:SS<marker>mmm``.

    The value is rounded to the nearest millisecond. The hours field is emitted
    when ``always_include_hours`` is set or when the duration reaches one hour.
    ``decimal_marker`` separates seconds from milliseconds ("." for VTT, "," for SRT).
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, millis = divmod(remainder, 1_000)
    show_hours = always_include_hours or hours > 0
    prefix = f"{hours:02d}:" if show_hours else ""
    return f"{prefix}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
class ResultWriter:
    """Base class for writers that save a transcription result to a file in ``output_dir``.

    The output file is named after the audio file's base name (directory and last
    extension removed) with the subclass's ``extension`` appended, and is written as
    UTF-8. Subclasses set ``extension`` and implement ``write_result``.
    """

    extension: str

    def __init__(self, output_dir: str):
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str):
        stem, _ = os.path.splitext(os.path.basename(audio_path))
        target = os.path.join(self.output_dir, f"{stem}.{self.extension}")
        with open(target, "w", encoding="utf-8") as out:
            self.write_result(result, file=out)

    def write_result(self, result: dict, file: typing.TextIO):
        """Serialize ``result`` into the already-open text stream ``file``."""
        raise NotImplementedError
class WriteTXT(ResultWriter):
    """Write the plain transcript: each segment's stripped text on its own line."""

    extension: str = "txt"

    def write_result(self, result: dict, file: typing.TextIO):
        lines = (seg.text.strip() for seg in result["segments"])
        for line in lines:
            print(line, file=file, flush=True)
class SubtitlesWriter(ResultWriter):
    """Shared cue generation for subtitle formats (WebVTT, SRT).

    Subclasses set ``always_include_hours`` and ``decimal_marker``, which control how
    cue timestamps are rendered by ``format_timestamp``.
    """

    always_include_hours: bool
    decimal_marker: str

    @staticmethod
    def _escape_cue_text(text: str) -> str:
        """Remove every "-->" from cue text, since that token separates the cue timings.

        A single ``replace`` is not enough: "--->" would become "-->" again, so
        repeat until no occurrence is left (each pass shortens the string).
        """
        while "-->" in text:
            text = text.replace("-->", "->")
        return text

    def iterate_result(self, result: dict):
        """Yield ``(start, end, text)`` cues with timestamps already formatted.

        A segment without word timings (``segment.words`` empty or None) becomes a
        single cue. With word timings, each word gets its own cue in which that word
        is wrapped in ``<u>...</u>``; any gap before a word, and any remainder after
        the last word up to the segment end, is covered by a cue showing the plain
        segment text.
        """
        for segment in result["segments"]:
            segment_start = self.format_timestamp(segment.start)
            segment_end = self.format_timestamp(segment.end)
            segment_text = self._escape_cue_text(segment.text.strip())

            if word_timings := segment.words:
                all_words = [timing.word for timing in word_timings]
                all_words[0] = all_words[0].strip()  # remove the leading space, if any
                last = segment_start
                for i, this_word in enumerate(word_timings):
                    start = self.format_timestamp(this_word.start)
                    end = self.format_timestamp(this_word.end)
                    if last != start:
                        yield last, start, segment_text
                    highlighted = "".join(
                        f"<u>{word}</u>" if j == i else word
                        for j, word in enumerate(all_words)
                    )
                    # Escape the joined text: "-->" may sit inside one word or span two
                    # adjacent words. The <u>/</u> tags contain no "-" and are unaffected.
                    yield start, end, self._escape_cue_text(highlighted)
                    last = end
                if last != segment_end:
                    yield last, segment_end, segment_text
            else:
                yield segment_start, segment_end, segment_text

    def format_timestamp(self, seconds: float):
        """Format ``seconds`` using this writer's hours/decimal-marker settings."""
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )
class WriteVTT(SubtitlesWriter):
    """Write cues as a WebVTT file.

    Timestamps use "." before the milliseconds, and the hours field appears only
    when a time reaches one hour.
    """

    extension: str = "vtt"
    always_include_hours: bool = False
    decimal_marker: str = "."

    def write_result(self, result: dict, file: typing.TextIO):
        print("WEBVTT", end="\n\n", file=file)
        for cue_start, cue_end, cue_text in self.iterate_result(result):
            print(
                f"{cue_start} --> {cue_end}",
                cue_text,
                sep="\n",
                end="\n\n",
                file=file,
                flush=True,
            )
class WriteSRT(SubtitlesWriter):
    """Write cues as an SRT file.

    Cues are numbered from 1, and timestamps always include hours with "," before
    the milliseconds.
    """

    extension: str = "srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def write_result(self, result: dict, file: typing.TextIO):
        cue_number = 0
        for cue_start, cue_end, cue_text in self.iterate_result(result):
            cue_number += 1
            print(
                cue_number,
                f"{cue_start} --> {cue_end}",
                cue_text,
                sep="\n",
                end="\n\n",
                file=file,
                flush=True,
            )
class WriteTSV(ResultWriter):
    """
    Write a transcript as tab-separated values, after a "start/end/text" header row:
    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
    Integer milliseconds avoid locale settings that render the decimal point of a
    float as a comma, and they are cheaper to parse and store (e.g. from C++).
    Tabs inside the text are replaced by spaces so that each row keeps three columns.
    """

    extension: str = "tsv"

    def write_result(self, result: dict, file: typing.TextIO):
        print("start", "end", "text", sep="\t", file=file)
        for seg in result["segments"]:
            start_ms = round(1000 * seg.start)
            end_ms = round(1000 * seg.end)
            text = seg.text.strip().replace("\t", " ")
            print(start_ms, end_ms, text, sep="\t", file=file, flush=True)
class WriteJSON(ResultWriter):
    """Write the transcription result as JSON.

    Segment and word objects are converted to plain dicts first. Otherwise ``json``
    would write NamedTuples as positional arrays without field names, and it would
    raise TypeError for dataclass instances.
    """

    extension: str = "json"

    def write_result(self, result: dict, file: typing.TextIO):
        json.dump(self._to_jsonable(result), file)

    @classmethod
    def _to_jsonable(cls, value):
        """Recursively turn NamedTuples, dataclasses, dicts and sequences into JSON values."""
        if hasattr(value, "_asdict"):  # NamedTuple: keep the field names
            value = value._asdict()
        elif dataclasses.is_dataclass(value) and not isinstance(value, type):
            value = {field.name: getattr(value, field.name) for field in dataclasses.fields(value)}
        if isinstance(value, dict):
            return {key: cls._to_jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [cls._to_jsonable(item) for item in value]
        return value
def get_writer(output_format: str, output_dir: str) -> typing.Callable[[dict, str], None]:
    """Return a callable ``writer(result, audio_path)`` that saves ``result`` in ``output_dir``.

    ``output_format`` is "txt", "vtt", "srt", "tsv", "json" or "all". "all" writes
    txt, vtt, srt and tsv; JSON is written only when requested explicitly.
    Raises KeyError for any other format name.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        def write_all(result: dict, audio_path: str):
            for writer in all_writers:
                writer(result, audio_path)

        return write_all

    if output_format == "json":
        # The CLI offers "json" as a choice, so it must map to a writer instead of
        # failing with KeyError. It stays out of "all" so the default outputs are unchanged.
        return WriteJSON(output_dir)

    return writers[output_format](output_dir)
def _console_encoding():
    """Return the normalized codec name of stdout's encoding, falling back to "utf-8".

    On Python 3, sys.getdefaultencoding() returns "utf-8", so it cannot reveal whether
    the console can display a character. The stream's own encoding can.
    """
    encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
    try:
        return codecs.lookup(encoding).name
    except LookupError:
        return "utf-8"


# Encoding of the console that transcript lines are printed to.
system_encoding = _console_encoding()


def make_safe(string, encoding=None):
    """Return ``string`` with every character not encodable in ``encoding`` replaced by '?'.

    This avoids UnicodeEncodeError when printing to a non-UTF-8 console
    (https://github.com/openai/whisper/discussions/729). ``encoding`` defaults to the
    console encoding detected at import time. UTF-8 can encode any code point, so in
    that case the string is returned unchanged without a round trip.
    """
    encoding = encoding or system_encoding
    if codecs.lookup(encoding).name == "utf-8":
        return string
    return string.encode(encoding, errors="replace").decode(encoding)
def load_audio(file: str, sound_filter):
    """Decode ``file`` with ffmpeg into a mono 16 kHz float32 waveform, optionally filtered.

    Adapted from openai-whisper's ``load_audio``. Compared with the original, it:

    1) always runs ffmpeg's ``aresample`` filter with ``async=3600`` (timestamp
       compensation) so that gaps in the stream timestamps (PTS) are filled with
       silence. This only matters for containers that store PTS, such as mp4, mkv or ts.
       NOTE(review): an earlier comment described 3600 as "up to 3600 seconds"; ffmpeg
       documents ``async`` as a compensation rate, not a duration — confirm intent.
    2) normalizes the volume, which helps with soft voices or speakers far from the
       microphone. ``sound_filter == 1`` applies ffmpeg ``speechnorm`` (aggressive
       speech normalization) and ``sound_filter == 2`` applies ffmpeg ``loudnorm``.
       Any other value applies no normalization.
    3) like the original, downmixes to one channel and resamples to 16 kHz signed
       16-bit PCM, the input format expected by the model encoder.

    Returns a 1-D float32 numpy array with samples in [-1.0, 1.0).
    Raises RuntimeError (chained from ffmpeg.Error, carrying ffmpeg's stderr) if
    ffmpeg fails.
    """
    stream = ffmpeg.input(file, threads=0).audio.filter("aresample", **{"async": 3600})
    if sound_filter == 1:
        stream = stream.filter("speechnorm", **{"e": 50, "r": 0.0001, "l": 1})
    elif sound_filter == 2:
        stream = stream.filter("loudnorm")
    stream = stream.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
    print('sound filering using:', ' '.join(stream.get_args()))
    try:
        out, _ = stream.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
    except ffmpeg.Error as e:
        # ffmpeg may emit non-UTF-8 bytes (e.g. file names in a legacy locale); decode
        # leniently so the real error is not masked by a UnicodeDecodeError.
        details = (e.stderr or b"").decode(errors="replace")
        raise RuntimeError(f"Failed to load audio: {details}") from e
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser: openai-whisper's CLI options plus faster-whisper extras."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="medium", choices=available_models(), help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help=f"the path to save model files; uses {default_model_dir} by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for inference, e.g. cpu, cuda or cuda:1")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs (created if missing)")
    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; 'all' produces txt, vtt, srt and tsv")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=1, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=1, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16 (falls back to fp32 on CPU)")
    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used for CPU inference (passed to faster-whisper as cpu_threads)")
    # faster-whisper specific params
    parser.add_argument("--vad_filter", type=str2bool, default=False, help="Enable vad filter")
    parser.add_argument("--vad_min_silence_duration_ms", type=optional_int, default=None, help="Minimum silence duration in ms (default 2000ms)")
    parser.add_argument("--vad_threshold", type=optional_float, default=None, help="Silero VAD outputs speech probabilities threshold, above which, audio chunk is considered speech (default 0.5)")
    # forked faster-whisper specific
    parser.add_argument("--sound_filter", type=optional_int, default=0, help="audio preprocessing with ffmpeg: 0 disables it; 1 fills PTS gaps with silence and applies aggressive speech normalization (speechnorm); 2 fills PTS gaps with silence and applies loudness normalization (loudnorm)")
    return parser


def _parse_device(device: str) -> typing.Tuple[str, int]:
    """Split a "cuda:N" device spec into ("cuda", N); any other spec gets index 0."""
    if device.startswith("cuda:"):
        name, index = device.split(":")
        return name, int(index)
    return device, 0


def _normalize_language(language: typing.Optional[str]) -> typing.Optional[str]:
    """Convert a --language choice into a Whisper language code.

    argparse accepts both codes ("en") and title-cased names ("English", "Castilian"),
    but the model must receive a code, so names are mapped through TO_LANGUAGE_CODE.
    None (automatic language detection) is passed through unchanged.
    """
    if language is None:
        return None
    language = language.lower()
    if language not in LANGUAGES:
        language = TO_LANGUAGE_CODE[language]
    return language


def cli():
    """Command-line entry point: transcribe each audio file with a faster-whisper model.

    Every segment is printed to stdout (made safe for the console encoding). After
    each file is finished, its segments are written to --output_dir in the format(s)
    chosen with --output_format.
    """
    args = _build_parser().parse_args()

    model_dir = os.path.expanduser(args.model_dir or default_model_dir)
    model_path = os.path.join(model_dir, model_map[args.model])

    # if cuda:1 -> device=cuda, device_index=1
    device, device_index = _parse_device(args.device)

    # Fallback schedule: start at --temperature and step up by the increment until 1.0
    # (inclusive), used whenever decoding fails the thresholds.
    temperature = args.temperature
    if (increment := args.temperature_increment_on_fallback) is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    logger = faster_whisper.utils.get_logger()
    if args.verbose:
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        handler.setFormatter(formatter)
        logger.setLevel(logging.DEBUG)
        logger.addHandler(handler)

    # Honor --fp16. CTranslate2 rejects an explicit float16 compute type on CPU, so
    # fall back to float32 there instead of failing to load the model.
    compute_type = "float16" if args.fp16 else "float32"
    if compute_type == "float16" and device == "cpu":
        logger.warning("FP16 is not supported on CPU; using FP32 instead")
        compute_type = "float32"

    # --threads accepts "None"; map it to 0, the parser's default, since cpu_threads needs an int.
    model = faster_whisper.WhisperModel(
        model_path,
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        cpu_threads=args.threads or 0,
    )

    # The writers open files inside output_dir, so it must exist.
    os.makedirs(args.output_dir, exist_ok=True)
    writer = get_writer(args.output_format, args.output_dir)

    vad_parameters = {}
    if args.vad_min_silence_duration_ms is not None:
        vad_parameters['min_silence_duration_ms'] = args.vad_min_silence_duration_ms
    if args.vad_threshold is not None:
        vad_parameters['threshold'] = args.vad_threshold
    vad_parameters = vad_parameters or None

    language = _normalize_language(args.language)
    suppress_tokens = [int(token) for token in args.suppress_tokens.split(",") if token.strip()]
    # optional_int allows "None"; treat it as a single beam / single candidate (greedy).
    beam_size = args.beam_size if args.beam_size is not None else 1
    best_of = args.best_of if args.best_of is not None else 1
    sound_filter = args.sound_filter or 0

    for audio_path in args.audio:
        audio = audio_path
        if sound_filter > 0:
            audio = load_audio(audio_path, sound_filter)
        segments, _ = model.transcribe(
            audio, language=language, task=args.task, beam_size=beam_size, best_of=best_of, patience=args.patience, length_penalty=args.length_penalty,
            temperature=temperature, compression_ratio_threshold=args.compression_ratio_threshold, log_prob_threshold=args.logprob_threshold,
            no_speech_threshold=args.no_speech_threshold, condition_on_previous_text=args.condition_on_previous_text, initial_prompt=args.initial_prompt,
            suppress_tokens=suppress_tokens,
            word_timestamps=args.word_timestamps, prepend_punctuations=args.prepend_punctuations, append_punctuations=args.append_punctuations,
            vad_filter=args.vad_filter, vad_parameters=vad_parameters,
        )
        all_segments = []
        for segment in segments:
            line = f"[{format_timestamp(segment.start)} --> {format_timestamp(segment.end)}] {segment.text}"
            print(make_safe(line))
            all_segments.append(segment)
        writer({"segments": all_segments}, audio_path)
# Run the command-line interface only when executed as a script, not when imported.
if __name__ == "__main__":
    cli()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment