Created
December 9, 2023 15:34
-
-
Save vanatteveldt/6a3bd47dd6216c087d39a2decc541a0e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import sys | |
import numpy | |
from pyannote.audio.pipelines.utils.hook import ProgressHook | |
import collections | |
import whisper | |
from pyannote.audio import Pipeline | |
import torch | |
from pyannote.audio import Audio | |
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding | |
from scipy.spatial.distance import cdist | |
# Input debate recording; the CSV output path is derived from it.
# NOTE(review): `input` shadows the Python builtin — later blocks reference this
# module-level name, so it is kept as-is here.
input = "eenvandaagdebate.wav"
output = input.replace(".wav", ".csv")
print(f"Transcribing {input} to {output}")
print("Loading models")
# Diarization pipeline: splits the recording into speaker turns.
# SECURITY: hard-coded HuggingFace token published in a public gist — this
# credential should be revoked and read from an environment variable instead.
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token="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",
)
# Speaker-embedding model used to match diarized voices to reference clips.
embedding = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb", device=torch.device("cuda")
)
# Whisper "large" for Dutch transcription (GPU strongly recommended).
model = whisper.load_model("large")
pipeline.to(torch.device("cuda"))
# Audio loader: 16 kHz mono, as expected by the embedding model.
audio = Audio(sample_rate=16000, mono="downmix")
print("Getting reference embeddings")
# One reference WAV per politician; each file yields a voice embedding that the
# diarized (anonymous) speakers are later matched against by cosine distance.
speakers_data = {
    "jetten": "jetten.wav",
    "timmermans": "timmermans.wav",
    "omtzigt": "omtzigt.wav",
    "vanderplas": "vanderplas.wav",
    "wilders": "wilders.wav",
    "yesilgoz": "yesilgoz.wav",
}
# Build the name -> embedding lookup table.
speaker_embeddings = {}
for politician, wav_file in speakers_data.items():
    print(politician)
    waveform = audio(wav_file)[0]
    # [None] adds the batch dimension the embedding model expects.
    speaker_embeddings[politician] = embedding(waveform[None])
# Free GPU memory held by intermediate tensors before diarization.
torch.cuda.empty_cache()
print("Diarization")
with ProgressHook() as progress:
    diarization = pipeline(input, hook=progress)
# Flatten the diarization result into (segment, speaker-label) pairs...
spreekbeurten = []
for segment, _, label in diarization.itertracks(yield_label=True):
    spreekbeurten.append((segment, label))
# ...and group all segments by their (anonymous) speaker label.
beurten_per_spreker = collections.defaultdict(list)
for segment, label in spreekbeurten:
    beurten_per_spreker[label].append(segment)
# Map each anonymous diarization label to (best-guess name, cosine distance).
speaker_names = {}


def mean(values):
    """Arithmetic mean of a non-empty sequence of numbers."""
    return sum(values) / len(values)


for i, (speaker, turns) in enumerate(beurten_per_spreker.items()):
    print(
        f"Speaker {i}/{len(beurten_per_spreker)}: {speaker} - Calculating embeddings for {len(turns)} segments"
    )
    # Embed only segments longer than 0.5s; very short clips give unreliable embeddings.
    embeddings = [
        embedding(audio.crop(input, turn)[0][None])
        for turn in turns
        if (turn.end - turn.start) > 0.5
    ]
    if not embeddings:
        # No usable segments: mark as unknown.
        # BUG FIX: the original fell through here and crashed on mean([]) below.
        speaker_names[speaker] = "-", 1
        continue
    # Average cosine distance from this speaker's segments to each reference voice.
    guesses = {}
    for speaker_name, reference_embedding in speaker_embeddings.items():
        distances = [
            cdist(e, reference_embedding, metric="cosine")[0][0]
            for e in embeddings
            if not numpy.isnan(e).any()  # skip embeddings containing NaN
        ]
        if distances:  # guard: all embeddings may have been NaN
            guesses[speaker_name] = mean(distances)
    if not guesses:
        # Every embedding was NaN — cannot identify this speaker.
        speaker_names[speaker] = "-", 1
        continue
    # Smallest mean cosine distance wins.
    name, dist = min(guesses.items(), key=lambda item: item[1])
    speaker_names[speaker] = name, dist
    print(f"Best guess: {name}, score: {dist:.3f}")
# (identifies speakers based on minimum distances)
print("Whispering")
# Dutch prompt priming Whisper with the participants' names and debate topics,
# improving proper-noun spelling in the transcript.
prompt = "Het debat gaat tussen de leiders van zes Nederlandse prominente politieke partijen. Dilan Yesilgöz (VVD), Frans Timmermans (GroenLinks/PvdA), Geert Wilders (PVV), Rob Jetten (D66), Pieter Omtzigt (NSC) en Caroline van der Plas (BBB). Deze politieke leiders zullen discussiëren over onderwerpen als immigratie, klimaatverandering, economie en sociaal-economische zekerheid."
# Full transcription; result["segments"] carries per-segment timestamps and text.
result = model.transcribe(input, language="nl", initial_prompt=prompt)
print("Result")
def get_speaker(start, end, turns=None):
    """Return the (turn, speaker) pair that best overlaps [start, end].

    Args:
        start: segment start time in seconds.
        end: segment end time in seconds.
        turns: optional list of (turn, speaker) pairs sorted by turn.start;
            defaults to the module-level ``spreekbeurten``.

    Returns:
        The (turn, speaker) with the largest temporal overlap, or
        (None, None) when no turn overlaps the segment at all.
    """
    if turns is None:
        turns = spreekbeurten
    max_d = 0
    best_speaker = None, None
    for turn, speaker in turns:
        if turn.start > end:
            # Turns are sorted by start time; nothing later can overlap.
            break
        if turn.end < start:
            continue
        # Length of the intersection between the turn and the segment.
        d = min(turn.end, end) - max(turn.start, start)
        if d > max_d:
            max_d = d
            best_speaker = (turn, speaker)
    return best_speaker
# Write one CSV row per Whisper segment, joined with its diarized speaker.
# FIX: use a context manager so the file is flushed and closed, and pass
# newline="" as the csv module requires (avoids blank lines on Windows).
with open(output, "w", newline="") as outfile:
    w = csv.writer(outfile)
    w.writerow(
        [
            "segment_start",
            "segment_end",
            "turn_start",
            "turn_end",
            "speakernum",
            "speakername",
            "name_confidence",
            "text",
        ]
    )
    for segment in result["segments"]:
        turn, speaker = get_speaker(segment["start"], segment["end"])
        if speaker is None:
            # FIX: segments with no overlapping diarization turn previously
            # crashed on speaker_names[None] / turn.start — emit placeholders.
            name, dist = "-", 1
            turn_start = turn_end = ""
        else:
            name, dist = speaker_names[speaker]
            turn_start, turn_end = turn.start, turn.end
        w.writerow(
            [
                segment["start"],
                segment["end"],
                turn_start,
                turn_end,
                speaker,
                name,
                1 - dist,  # distance inverted into a rough confidence score
                segment["text"],
            ]
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment