Skip to content

Instantly share code, notes, and snippets.

@dodysw
Last active May 27, 2023 15:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dodysw/e0af41d990def8bbeb8934eeabd77688 to your computer and use it in GitHub Desktop.
Save dodysw/e0af41d990def8bbeb8934eeabd77688 to your computer and use it in GitHub Desktop.
"""
This script adapts the CLI from https://github.com/openai/whisper to use https://github.com/guillaumekln/faster-whisper instead.
Prepare the models with these steps:
mkdir -p ~/.cache/faster-whisper
# assuming a Python virtualenv has been created in the venv/ directory:
venv/bin/python -m pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/refs/heads/master.tar.gz"
venv/bin/ct2-transformers-converter --model openai/whisper-large-v2 --output_dir ~/.cache/faster-whisper/whisper-large-v2-ct2 --copy_files tokenizer.json --quantization float16
venv/bin/ct2-transformers-converter --model openai/whisper-medium --output_dir ~/.cache/faster-whisper/whisper-medium-ct2 --copy_files tokenizer.json --quantization float16
# then use it like you used whisper script:
venv/bin/python fasterw.py audio.wav <same whisper parameters>
"""
import argparse
import json
import logging
import os.path
import sys
import typing
import faster_whisper
import ffmpeg
import numpy as np
import torch
# Directory that holds the CTranslate2-converted models (see the module docstring).
default_model_dir = "~/.cache/faster-whisper"
# --model choice -> name of the converted-model sub-directory inside the model dir.
model_map = dict(
    medium="whisper-medium-ct2",
    large="whisper-large-v2-ct2",
)
# Whisper language code -> lowercase English language name (mirrors the table in
# openai-whisper). The codes are accepted directly by --language; the names are
# accepted in title case through TO_LANGUAGE_CODE.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}
# Reverse lookup (lowercase language name -> code), extended with common
# alternative names for some of the languages.
TO_LANGUAGE_CODE = {name: code for code, name in LANGUAGES.items()}
TO_LANGUAGE_CODE.update(
    burmese="my",
    valencian="ca",
    flemish="nl",
    haitian="ht",
    letzeburgesch="lb",
    pushto="ps",
    panjabi="pa",
    moldavian="ro",
    moldovan="ro",
    sinhalese="si",
    castilian="es",
)
def available_models():
    """Return the model names accepted by ``--model`` (a live view of ``model_map``'s keys)."""
    names = model_map.keys()
    return names
def optional_int(string):
    """argparse ``type=`` helper: the literal ``"None"`` maps to None, anything else goes through ``int``."""
    if string == "None":
        return None
    return int(string)
def optional_float(string):
    """argparse ``type=`` helper: the literal ``"None"`` maps to None, anything else goes through ``float``."""
    if string == "None":
        return None
    return float(string)
def str2bool(string):
    """argparse ``type=`` helper that converts exactly ``"True"`` / ``"False"`` to a bool.

    Raises:
        ValueError: for any other input (the comparison is case-sensitive).
    """
    choices = {"True": True, "False": False}
    if string not in choices:
        raise ValueError(f"Expected one of {set(choices.keys())}, got {string}")
    return choices[string]
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    """Render *seconds* as ``[HH:]MM:SS<decimal_marker>mmm``.

    The value is rounded to whole milliseconds first. The hours field is
    included when *always_include_hours* is set or the value is at least
    one hour.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, millis = divmod(remainder, 1_000)
    prefix = f"{hours:02d}:" if (always_include_hours or hours > 0) else ""
    return f"{prefix}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
class ResultWriter:
    """Base writer: writes a result to ``<output_dir>/<audio stem>.<extension>``.

    Subclasses set ``extension`` and implement ``write_result``.
    """

    extension: str

    def __init__(self, output_dir: str):
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str):
        # Drop the directory and the final extension of the input file.
        stem, _ext = os.path.splitext(os.path.basename(audio_path))
        target = os.path.join(self.output_dir, stem + "." + self.extension)
        with open(target, "w", encoding="utf-8") as handle:
            self.write_result(result, file=handle)

    def write_result(self, result: dict, file: typing.TextIO):
        raise NotImplementedError
class WriteTXT(ResultWriter):
    """Plain-text writer: one stripped segment text per line, without timestamps."""

    extension: str = "txt"

    def write_result(self, result: dict, file: typing.TextIO):
        stripped_lines = (segment.text.strip() for segment in result["segments"])
        for line in stripped_lines:
            print(line, file=file, flush=True)
class SubtitlesWriter(ResultWriter):
    """Shared base for timed-cue subtitle writers (VTT, SRT).

    Subclasses set ``extension``, ``always_include_hours`` and
    ``decimal_marker`` and implement ``write_result`` using ``iterate_result``.
    """

    # Whether every timestamp carries an HH: field (passed to format_timestamp).
    always_include_hours: bool
    # Separator between seconds and milliseconds, e.g. "." for VTT, "," for SRT.
    decimal_marker: str

    def iterate_result(self, result: dict):
        """Yield ``(start, end, text)`` cues with already-formatted timestamps.

        Without word timings, one cue is yielded per segment. With word timings
        (``segment.words`` non-empty), one cue is yielded per word, showing the
        whole segment with that word wrapped in ``<u>...</u>``. Plain-text cues
        fill any gap before a word and after the last word. Gaps are detected
        by comparing the formatted strings, i.e. at millisecond resolution.
        """
        for segment in result["segments"]:
            segment_start = self.format_timestamp(segment.start)
            segment_end = self.format_timestamp(segment.end)
            # The writers use "-->" as the cue timing separator, so keep it out of cue text.
            segment_text = segment.text.strip().replace("-->", "->")
            if word_timings := segment.words:
                all_words = [timing.word for timing in word_timings]
                all_words[0] = all_words[0].strip() # remove the leading space, if any
                last = segment_start
                for i, this_word in enumerate(word_timings):
                    start = self.format_timestamp(this_word.start)
                    end = self.format_timestamp(this_word.end)
                    # Cover the gap since the previous cue with the un-highlighted text.
                    if last != start:
                        yield last, start, segment_text
                    yield start, end, "".join(
                        [
                            f"<u>{word}</u>" if j == i else word
                            for j, word in enumerate(all_words)
                        ]
                    )
                    last = end
                # Trailing gap between the last word and the end of the segment.
                if last != segment_end:
                    yield last, segment_end, segment_text
            else:
                yield segment_start, segment_end, segment_text

    def format_timestamp(self, seconds: float):
        """Format *seconds* with this writer's hour and decimal-marker settings."""
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )
class WriteVTT(SubtitlesWriter):
    """WebVTT writer: ``WEBVTT`` header, then ``MM:SS.mmm`` cues (hours only when needed)."""

    extension: str = "vtt"
    always_include_hours: bool = False
    decimal_marker: str = "."

    def write_result(self, result: dict, file: typing.TextIO):
        print("WEBVTT\n", file=file)
        cues = self.iterate_result(result)
        for cue in cues:
            cue_start, cue_end, cue_text = cue
            print(f"{cue_start} --> {cue_end}\n{cue_text}\n", file=file, flush=True)
class WriteSRT(SubtitlesWriter):
    """SubRip writer: cues numbered from 1 with ``HH:MM:SS,mmm`` timestamps."""

    extension: str = "srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def write_result(self, result: dict, file: typing.TextIO):
        cue_number = 0
        for start, end, text in self.iterate_result(result):
            cue_number += 1
            print(f"{cue_number}\n{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteTSV(ResultWriter):
    """
    Tab-separated transcript writer.

    After a ``start<TAB>end<TAB>text`` header row, each segment becomes one row:
    start and end in whole milliseconds, followed by the stripped text with tabs
    turned into spaces. Integer milliseconds avoid locale-dependent decimal
    separators and are cheap to parse from other languages (e.g. C++).
    """

    extension: str = "tsv"

    def write_result(self, result: dict, file: typing.TextIO):
        print("start", "end", "text", sep="\t", file=file)
        for seg in result["segments"]:
            start_ms = round(1000 * seg.start)
            end_ms = round(1000 * seg.end)
            text = seg.text.strip().replace("\t", " ")
            print(start_ms, end_ms, text, sep="\t", file=file, flush=True)
class WriteJSON(ResultWriter):
    """Write the transcription result as a JSON document.

    The segments handed to the writers are attribute objects (``segment.start``,
    ``segment.words``, ...), not dicts. A plain ``json.dump`` would either emit
    them as field-less arrays (NamedTuple) or reject them with TypeError (other
    classes). They are therefore converted to dicts/lists first.
    """

    extension: str = "json"

    @staticmethod
    def _to_jsonable(value):
        """Recursively turn NamedTuples/dataclass instances into dicts and tuples into lists."""
        import dataclasses

        if hasattr(value, "_asdict"):
            # NamedTuple: keep the field names rather than serializing a bare array.
            value = value._asdict()
        elif dataclasses.is_dataclass(value) and not isinstance(value, type):
            value = {field.name: getattr(value, field.name) for field in dataclasses.fields(value)}
        if isinstance(value, dict):
            return {key: WriteJSON._to_jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [WriteJSON._to_jsonable(item) for item in value]
        return value

    def write_result(self, result: dict, file: typing.TextIO):
        json.dump(self._to_jsonable(result), file)
def get_writer(output_format: str, output_dir: str) -> typing.Callable[[dict, str], None]:
    """Return a callable ``writer(result, audio_path)`` for *output_format*.

    ``"all"`` writes the txt, vtt, srt and tsv files side by side. ``"json"``
    is offered by the CLI's ``--output_format`` choices but is produced only
    when requested explicitly; it is not part of ``"all"``.

    Raises:
        KeyError: if *output_format* is not a known format.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
    }
    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        def write_all(result: dict, audio_path: str):
            for writer in all_writers:
                writer(result, audio_path)

        return write_all
    if output_format == "json":
        # Previously "json" fell through to writers[...] and raised KeyError.
        return WriteJSON(output_dir)
    return writers[output_format](output_dir)
# print() encodes with stdout's encoding (e.g. cp1252 when output is redirected
# on Windows). sys.getdefaultencoding() is always "utf-8" on Python 3, so
# checking it never enabled the fallback below; check stdout's encoding instead.
system_encoding = getattr(sys.stdout, "encoding", None) or sys.getdefaultencoding()
if system_encoding.lower().replace("_", "-") not in ("utf-8", "utf8"):

    def make_safe(string):
        """Return *string* with characters stdout cannot encode replaced by '?'.

        This avoids UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
        """
        return string.encode(system_encoding, errors="replace").decode(system_encoding)

else:

    def make_safe(string):
        """Return *string* unchanged: UTF-8 can encode every Unicode code point."""
        return string
def load_audio(file: str, sound_filter):
    """Decode *file* with ffmpeg into a mono, 16 kHz float32 waveform.

    Adapted from openai-whisper's ``load_audio``. Compared with it:

    1. Applies ``aresample`` with ``async=3600``, intended to fill PTS gaps with
       silence, so the input should be a container that stores PTS (mp4, mkv,
       ts). Per the ffmpeg docs, values > 1 bound how many samples per second
       may be stretched or squeezed; confirm for your ffmpeg version.
    2. Optionally normalizes volume to help recognize soft or distant speech:
       ``sound_filter == 1`` uses ``speechnorm`` and ``sound_filter == 2`` uses
       ``loudnorm``. Any other value applies no normalization.
    3. As before, outputs 16-bit PCM, mono, 16 kHz, rescaled to float32 in [-1, 1).

    Raises:
        RuntimeError: if ffmpeg fails; the message includes ffmpeg's stderr.
    """
    stream = ffmpeg.input(file, threads=0).audio.filter("aresample", **{"async": 3600})
    if sound_filter == 1:
        stream = stream.filter("speechnorm", **{"e": 50, "r": 0.0001, "l": 1})
    elif sound_filter == 2:
        stream = stream.filter("loudnorm")
    stream = stream.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
    print("sound filtering using:", " ".join(stream.get_args()))
    try:
        out, _ = stream.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
    except ffmpeg.Error as e:
        # ffmpeg's stderr may contain bytes that are not valid UTF-8 (e.g. file names).
        stderr_text = (e.stderr or b"").decode(errors="replace")
        raise RuntimeError(f"Failed to load audio: {stderr_text}") from e
    # frombuffer yields a 1-D int16 view; astype copies it into a new float32 array.
    return np.frombuffer(out, np.int16).astype(np.float32) / 32768.0
def _build_parser() -> argparse.ArgumentParser:
    """Build the argument parser: openai-whisper's CLI options plus faster-whisper extras."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="medium", choices=available_models(), help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help=f"the path to save model files; uses {default_model_dir} by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=1, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=1, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by CTranslate2 for CPU inference; a non-zero value overrides OMP_NUM_THREADS")
    # faster-whisper specific params
    parser.add_argument("--vad_filter", type=str2bool, default=False, help="Enable vad filter")
    parser.add_argument("--vad_min_silence_duration_ms", type=optional_int, default=None, help="Minimum silence duration in ms (default 2000ms)")
    parser.add_argument("--vad_threshold", type=optional_float, default=None, help="Silero VAD outputs speech probabilities threshold, above which, audio chunk is considered speech (default 0.5)")
    # forked faster-whisper specific
    parser.add_argument("--sound_filter", type=optional_int, default=0, help="pre-process audio with ffmpeg before transcribing: 0 = none, 1 = fill pts gaps with silence and apply aggressive speech normalization (speechnorm), 2 = fill pts gaps with silence and apply loudness normalization (loudnorm)")
    return parser


def cli():
    """Command-line entry point: transcribe each audio file and write the results.

    Each segment is printed to stdout as the segments are iterated. The
    collected segments are then written by the writer selected with
    ``--output_format`` into ``--output_dir``.
    """
    args = _build_parser().parse_args()

    model_dir = os.path.expanduser(args.model_dir or default_model_dir)
    model_path = os.path.join(model_dir, model_map[args.model])

    # if cuda:1 -> device=cuda, device_index=1
    device = args.device
    device_index = 0
    if device.startswith("cuda:"):
        device, device_index = device.split(":")
        device_index = int(device_index)

    # Expand the starting temperature into the fallback schedule.
    temperature = args.temperature
    if (increment := args.temperature_increment_on_fallback) is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    # --language also offers names ("English"); faster-whisper's transcribe()
    # takes a language code, so map names to codes here.
    language = args.language
    if language is not None:
        language = TO_LANGUAGE_CODE.get(language.lower(), language)

    # --suppress_tokens used to be parsed but ignored; "-1" keeps the default special-token set.
    suppress_tokens = [int(token) for token in args.suppress_tokens.split(",") if token.strip()]

    logger = faster_whisper.utils.get_logger()
    if args.verbose:
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        handler.setFormatter(formatter)
        logger.setLevel(logging.DEBUG)
        logger.addHandler(handler)

    # float16 used to be hard-coded, ignoring --fp16; CTranslate2 rejects float16 on CPU.
    compute_type = "float16" if args.fp16 and device != "cpu" else "float32"
    model = faster_whisper.WhisperModel(
        model_path,
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        cpu_threads=args.threads or 0,  # "--threads None" is allowed by optional_int
    )

    # Writers open files inside output_dir; create it up front instead of failing after transcription.
    os.makedirs(args.output_dir, exist_ok=True)
    writer = get_writer(args.output_format, args.output_dir)

    vad_parameters = {}
    if args.vad_min_silence_duration_ms is not None:
        vad_parameters["min_silence_duration_ms"] = args.vad_min_silence_duration_ms
    if args.vad_threshold is not None:
        vad_parameters["threshold"] = args.vad_threshold
    if not vad_parameters:
        vad_parameters = None

    for audio_path in args.audio:
        audio = audio_path
        # "--sound_filter None" is allowed by optional_int; treat it like 0.
        if (args.sound_filter or 0) > 0:
            audio = load_audio(audio_path, args.sound_filter)
        segments, _info = model.transcribe(
            audio, language=language, task=args.task, beam_size=args.beam_size, best_of=args.best_of, patience=args.patience, length_penalty=args.length_penalty,
            temperature=temperature, compression_ratio_threshold=args.compression_ratio_threshold, log_prob_threshold=args.logprob_threshold,
            no_speech_threshold=args.no_speech_threshold, condition_on_previous_text=args.condition_on_previous_text, initial_prompt=args.initial_prompt,
            suppress_tokens=suppress_tokens,
            word_timestamps=args.word_timestamps, prepend_punctuations=args.prepend_punctuations, append_punctuations=args.append_punctuations,
            vad_filter=args.vad_filter, vad_parameters=vad_parameters,
        )
        # Per the faster-whisper docs, segments is lazy: decoding runs while iterating it.
        all_segments = []
        for segment in segments:
            line = f"[{format_timestamp(segment.start)} --> {format_timestamp(segment.end)}] {segment.text}"
            print(make_safe(line))
            all_segments.append(segment)
        writer({"segments": all_segments}, audio_path)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    cli()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment