Skip to content

Instantly share code, notes, and snippets.

@alexlyzhov
Created January 24, 2023 04:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alexlyzhov/72f194f4512f9abfba5f95617f81aac4 to your computer and use it in GitHub Desktop.
Whisper to json
# based on https://github.com/ANonEntity/WhisperWithVAD
import torch
import whisper
import os
import ffmpeg
import srt
from tqdm import tqdm
import datetime
import urllib.request
import json
from glob import glob
model_size = 'medium'
language = "english" # @param {type:"string"}
translation_mode = "End-to-end Whisper (default)" # @param ["End-to-end Whisper (default)", "Whisper -> DeepL", "No translation"]
# @markdown Advanced settings:
deepl_authkey = "" # @param {type:"string"}
chunk_threshold = 3.0 # @param {type:"number"}
max_attempts = 1 # @param {type:"integer"}

# Configuration
# Validate with explicit raises: `assert` statements are skipped under `python -O`.
# (`not x >= y` keeps the original semantics, e.g. NaN is still rejected.)
if not max_attempts >= 1:
    raise ValueError(f"max_attempts must be >= 1, got {max_attempts!r}")
if not chunk_threshold >= 0.1:
    raise ValueError(f"chunk_threshold must be >= 0.1 seconds, got {chunk_threshold!r}")
if language == "":
    raise ValueError("language must not be empty")

# translation_mode -> (Whisper task passed to transcribe(), whether to run DeepL).
# NOTE(review): no visible code reads run_deepl or deepl_authkey, so the
# "Whisper -> DeepL" mode currently just transcribes — confirm / implement.
_TRANSLATION_MODES = {
    "End-to-end Whisper (default)": ("translate", False),
    "Whisper -> DeepL": ("transcribe", True),
    "No translation": ("transcribe", False),
}
try:
    task, run_deepl = _TRANSLATION_MODES[translation_mode]
except KeyError:
    raise ValueError("Invalid translation mode") from None
# Directory scanned for recordings; chunk WAVs go in a temp subdirectory of it.
recordings_dir = '/Users/alexlyzhov/Documents/recordings'
audio_extensions = ('.mp3', '.wav')
# All .mp3 files first, then all .wav files (same order as globbing each pattern).
inputs = [
    path
    for ext in audio_extensions
    for path in glob(os.path.join(recordings_dir, '*' + ext))
]
tmp_path = os.path.join(recordings_dir, 'tmp', 'vad_chunks')

# An existing file with any of these extensions marks an input as already done.
OUTPUT_EXTENSIONS = ('.srt', '.vtt', '.txt')


def has_existing_transcript(audio_file):
    """Return True if a subtitle/transcript file already exists for *audio_file*.

    Both naming schemes are checked: extension replaced (``talk.srt``) and
    extension appended (``talk.mp3.srt``), for each of OUTPUT_EXTENSIONS.
    """
    stem = os.path.splitext(audio_file)[0]
    return any(
        os.path.exists(base + ext)
        for base in (stem, audio_file)
        for ext in OUTPUT_EXTENSIONS
    )


# Only inputs without any existing transcript are processed.
todo_inputs = [path for path in inputs if not has_existing_transcript(path)]
print(todo_inputs)
def encode(audio_path):
    """Re-encode *audio_path* as 16 kHz mono 16-bit PCM WAV for the VAD step.

    The output is written to ``<tmp_path>/silero_temp.wav``, overwriting any
    previous file. ``tmp_path`` (including missing parent directories) is
    created if needed.

    Args:
        audio_path: path of any audio file ffmpeg can decode.
    """
    print("Encoding audio...")
    # makedirs, not mkdir: mkdir raises FileNotFoundError when a parent
    # directory (e.g. ".../recordings/tmp") does not exist yet.
    os.makedirs(tmp_path, exist_ok=True)
    ffmpeg.input(audio_path).output(
        os.path.join(tmp_path, "silero_temp.wav"),
        ar="16000",  # 16 kHz: the rate the VAD step reads the file at
        ac="1",  # mono
        acodec="pcm_s16le",
        map_metadata="-1",  # drop metadata copied from the source file
        fflags="+bitexact",
    ).overwrite_output().run(quiet=True)
# Sampling rate used when reading/writing audio for VAD (encode() outputs 16 kHz).
VAD_SR = 16000
# Padding added around each detected speech region, in samples.
VAD_HEAD_PAD = 3200  # 0.2 s before
VAD_TAIL_PAD = 20800  # 1.3 s after

# Phrases that look like Whisper hallucinations (video outros, subscribe
# prompts, ...). Each phrase found in a segment lowers its avg_logprob, making
# the avg_logprob < -1.0 filter more likely to drop it.
SUPPRESS_LOW = [
    "Thank you",
    "Thanks for",
    "ike and ",
    "Bye.",
    "Bye!",
    "Bye bye!",
    "lease sub",
    "The end.",
    "視聴",
]
SUPPRESS_HIGH = [
    "ubscribe",
    "my channel",
    "the channel",
    "our channel",
    "ollow me on",
    "for watching",
    "hank you for watching",
    "for your viewing",
    "r viewing",
    "Amara",
    "next video",
    "full video",
    "ranslation by",
    "ranslated by",
    "ee you next week",
    "ご視聴",
    "視聴ありがとうございました",
]


def _pad_and_group(speech, num_samples):
    """Pad VAD speech regions and group them into chunks split at long pauses.

    *speech* is the list of ``{"start", "end"}`` sample-index dicts returned by
    ``get_speech_timestamps``; it is modified in place. Each region gets head
    and tail padding (clamped to the audio length), and any overlap with the
    previous region is trimmed. A new group starts whenever the gap to the
    previous region exceeds ``chunk_threshold`` seconds, which turns one long
    recording into several shorter Whisper inputs.

    Returns:
        A list of non-empty groups of regions; ``[]`` when *speech* is empty.
    """
    for i, region in enumerate(speech):
        region["start"] = max(0, region["start"] - VAD_HEAD_PAD)
        region["end"] = min(num_samples - 16, region["end"] + VAD_TAIL_PAD)
        if i > 0 and region["start"] < speech[i - 1]["end"]:
            region["start"] = speech[i - 1]["end"]  # remove overlap
    groups = []
    for i, region in enumerate(speech):
        if i == 0 or region["start"] > speech[i - 1]["end"] + chunk_threshold * VAD_SR:
            groups.append([])
        groups[-1].append(region)
    return groups


def _add_chunk_times(groups):
    """Convert region sample indices to seconds and record chunk-relative timing.

    Adds to every region ``chunk_start``/``chunk_end`` (its position, in
    seconds, inside the concatenated audio of its group) and ``offset`` (the
    amount to add to a chunk-relative time to get the source-file time).
    """
    for group in groups:
        time = 0.0
        offset = 0.0
        for j, region in enumerate(group):
            region["start"] /= VAD_SR
            region["end"] /= VAD_SR
            region["chunk_start"] = time
            time += region["end"] - region["start"]
            region["chunk_end"] = time
            # Accumulated silence removed before this region (leading gap for j == 0).
            if j == 0:
                offset += region["start"]
            else:
                offset += region["start"] - group[j - 1]["end"]
            region["offset"] = offset


def _to_source_time(chunk_time, group, default):
    """Map a time inside a group's concatenated audio back to the source file.

    Returns *default* when *chunk_time* falls outside every region of *group*.
    """
    for region in group:
        if region["chunk_start"] <= chunk_time <= region["chunk_end"]:
            return chunk_time + region["offset"]
    return default


def _transcribe_group(whisper_model, wav_path, group, index):
    """Run Whisper on one chunk file, retrying up to ``max_attempts`` times.

    A result is accepted when it has no segments, or when its last segment ends
    less than 10 s after the chunk's speech ends. Otherwise it is treated as
    ending in hallucinations and retried. Returns the last attempt's segments.
    """
    for attempt in range(max_attempts):
        result = whisper_model.transcribe(wav_path, task=task, language=language)
        segments = result["segments"]
        if not segments or segments[-1]["end"] < group[-1]["chunk_end"] + 10.0:
            break
        if attempt + 1 < max_attempts:
            print("Retrying chunk", index)
    return segments


# Models are loaded once, on first use, instead of once per input file.
vad_model = None
whisper_model = None
for audio_path in todo_inputs:
    encode(audio_path)
    print("Running VAD...")
    if vad_model is None:
        vad_model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=False
        )
        (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils
    # Generate VAD timestamps, pad them and split at long pauses
    temp_wav = os.path.join(tmp_path, "silero_temp.wav")
    wav = read_audio(temp_wav, sampling_rate=VAD_SR)
    groups = _pad_and_group(
        get_speech_timestamps(wav, vad_model, sampling_rate=VAD_SR), wav.shape[0]
    )
    # Write each group's speech regions, concatenated, to "<i>.wav"
    for i, group in enumerate(groups):
        save_audio(
            os.path.join(tmp_path, str(i) + ".wav"),
            collect_chunks(group, wav),
            sampling_rate=VAD_SR,
        )
    os.remove(temp_wav)
    _add_chunk_times(groups)
    if not groups:
        # Previously crashed (u[i][-1] on an empty group); now yields an empty SRT.
        print("No speech detected in", audio_path)

    # Run Whisper on each audio chunk
    print("Running Whisper...")
    if groups and whisper_model is None:
        whisper_model = whisper.load_model(model_size)
    subs = []
    segment_info = []
    for i, group in enumerate(tqdm(groups)):
        chunk_wav = os.path.join(tmp_path, str(i) + ".wav")
        segments = _transcribe_group(whisper_model, chunk_wav, group, i)
        os.remove(chunk_wav)  # don't leave stale chunk files behind
        chunk_end = group[-1]["chunk_end"]
        for r in segments:
            # Skip segments timestamped after the chunk has ended
            if r["start"] > chunk_end:
                continue
            # Reduce log probability for suspected hallucination phrases
            for phrase in SUPPRESS_LOW:
                if phrase in r["text"]:
                    r["avg_logprob"] -= 0.15
            for phrase in SUPPRESS_HIGH:
                if phrase in r["text"]:
                    r["avg_logprob"] -= 0.35
            # Keep segment info (without raw tokens) for debugging
            del r["tokens"]
            segment_info.append(r)
            # Skip if log prob is low or no-speech prob is high
            if r["avg_logprob"] < -1.0 or r["no_speech_prob"] > 0.7:
                continue
            start = _to_source_time(r["start"], group, r["start"] + group[0]["offset"])
            # Prevent overlapping subs: cut the previous subtitle short
            if subs and subs[-1].end.total_seconds() > start:
                subs[-1].end = datetime.timedelta(seconds=start)
            end = _to_source_time(r["end"], group, group[-1]["end"] + 0.5)
            subs.append(
                srt.Subtitle(
                    index=len(subs) + 1,
                    start=datetime.timedelta(seconds=start),
                    end=datetime.timedelta(seconds=end),
                    content=r["text"].strip(),
                )
            )
    # Debug dump of kept segments for the most recent file (overwritten per file)
    with open("segment_info.json", "w", encoding="utf8") as f:
        json.dump(segment_info, f, indent=4)
    out_path = os.path.splitext(audio_path)[0] + ".srt"
    with open(out_path, "w", encoding="utf8") as f:
        f.write(srt.compose(subs))
    print("\nDone! Subs written to", out_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment