Skip to content

Instantly share code, notes, and snippets.

@vanatteveldt
Created December 9, 2023 15:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vanatteveldt/6a3bd47dd6216c087d39a2decc541a0e to your computer and use it in GitHub Desktop.
# Standard library
import collections
import csv
import os
import sys

# Third-party
import numpy
import torch
import whisper
from pyannote.audio import Audio
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio.pipelines.utils.hook import ProgressHook
from scipy.spatial.distance import cdist
# --- Configuration and model loading -------------------------------------
# NOTE(review): `input` shadows the builtin of the same name; kept because
# later parts of this script reference it by this name.
input = "eenvandaagdebate.wav"
output = input.replace(".wav", ".csv")
print(f"Transcribing {input} to {output}")

print("Loading models")
# SECURITY FIX: a HuggingFace access token was hard-coded here. A secret
# committed to source must be considered leaked and should be revoked.
# Read the token from the environment instead (fails fast if unset).
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.environ["HF_TOKEN"],
)
# Speaker-embedding model used to match diarized clusters to reference clips.
embedding = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb", device=torch.device("cuda")
)
model = whisper.load_model("large")
pipeline.to(torch.device("cuda"))
# Loader that resamples to 16 kHz and downmixes to mono.
audio = Audio(sample_rate=16000, mono="downmix")
print("Getting reference embeddings")

# Reference audio clips, one per debate participant.
speakers_data = {
    "jetten": "jetten.wav",
    "timmermans": "timmermans.wav",
    "omtzigt": "omtzigt.wav",
    "vanderplas": "vanderplas.wav",
    "wilders": "wilders.wav",
    "yesilgoz": "yesilgoz.wav",
}


def _reference_embedding(wav_path):
    # Load a clip and compute its speaker embedding ([None] adds a batch axis).
    waveform = audio(wav_path)[0]
    return embedding(waveform[None])


# Compute one embedding per reference speaker.
speaker_embeddings = {}
for ref_name, ref_file in speakers_data.items():
    print(ref_name)
    speaker_embeddings[ref_name] = _reference_embedding(ref_file)

torch.cuda.empty_cache()

print("Diarization")
with ProgressHook() as hook:
    diarization = pipeline(input, hook=hook)
# Flatten the diarization into (turn, speaker-label) pairs.
spreekbeurten = [
    (beurt, spreker)
    for beurt, _track, spreker in diarization.itertracks(yield_label=True)
]

# Group the turns by diarized speaker label.
beurten_per_spreker = collections.defaultdict(list)
for beurt, spreker in spreekbeurten:
    beurten_per_spreker[spreker].append(beurt)
# Match each diarized speaker cluster to the closest reference speaker,
# using the mean cosine distance between the cluster's segment embeddings
# and each reference embedding.
speaker_names = {}


def _mean(values):
    """Arithmetic mean of a non-empty sequence of numbers."""
    return sum(values) / len(values)


for i, (speaker, turns) in enumerate(beurten_per_spreker.items()):
    print(
        f"Speaker {i}/{len(beurten_per_spreker)}: {speaker} - Calculating embeddings for {len(turns)} segments"
    )
    # Segments shorter than 0.5 s are skipped (too short for a usable embedding).
    embeddings = [
        embedding(audio.crop(input, turn)[0][None])
        for turn in turns
        if (turn.end - turn.start) > 0.5
    ]
    if not embeddings:
        # BUGFIX: the original fell through here and then crashed with
        # ZeroDivisionError when averaging an empty distance list.
        # Record a placeholder and move on to the next speaker.
        speaker_names[speaker] = "-", 1
        continue
    guesses = {}
    for speaker_name, reference_embedding in speaker_embeddings.items():
        # Drop embeddings containing NaNs before measuring distance.
        distances = [
            cdist(e, reference_embedding, metric="cosine")[0][0]
            for e in embeddings
            if not numpy.any(numpy.isnan(e))
        ]
        if not distances:
            continue  # every embedding was NaN for this comparison
        guesses[speaker_name] = _mean(distances)
    if not guesses:
        # No valid distances against any reference speaker.
        speaker_names[speaker] = "-", 1
        continue
    # Smallest cosine distance wins.
    name, dist = min(guesses.items(), key=lambda item: item[1])
    speaker_names[speaker] = name, dist
    print(f"Best guess: {name}, score: {dist:.3f}")
# (identifies speakers based on minimum distances)
print("Whispering")
# Dutch context prompt that steers whisper toward the debate's names and topics.
prompt = "Het debat gaat tussen de leiders van zes Nederlandse prominente politieke partijen. Dilan Yesilgöz (VVD), Frans Timmermans (GroenLinks/PvdA), Geert Wilders (PVV), Rob Jetten (D66), Pieter Omtzigt (NSC) en Caroline van der Plas (BBB). Deze politieke leiders zullen discussiëren over onderwerpen als immigratie, klimaatverandering, economie en sociaal-economische zekerheid."
result = model.transcribe(input, initial_prompt=prompt, language="nl")
print("Result")
def get_speaker(start, end, turns=None):
    """Return the (turn, speaker) pair whose diarized turn overlaps the
    segment [start, end] the most.

    Parameters
    ----------
    start, end : float
        Segment boundaries in seconds.
    turns : sequence of (turn, speaker) pairs, optional
        Each ``turn`` must expose ``.start`` and ``.end``; the sequence must
        be sorted by ``turn.start`` (the early ``break`` below relies on it).
        Defaults to the module-level ``spreekbeurten``.

    Returns
    -------
    tuple
        (turn, speaker) with maximal overlap, or (None, None) when no turn
        overlaps the segment.
    """
    if turns is None:
        turns = spreekbeurten
    max_overlap = 0
    best = (None, None)
    for turn, speaker in turns:
        if turn.start > end:
            # turns are sorted by start, so nothing later can overlap
            break
        if turn.end < start:
            continue
        overlap = min(turn.end, end) - max(turn.start, start)
        if overlap > max_overlap:
            max_overlap = overlap
            best = (turn, speaker)
    return best
# Write one CSV row per whisper segment, annotated with the best-overlapping
# diarization turn and the guessed speaker name.
# BUGFIX: the original opened the file without ever closing it; `with` plus
# newline="" (required by the csv module) fixes that.
with open(output, "w", newline="") as out_file:
    w = csv.writer(out_file)
    w.writerow(
        [
            "segment_start",
            "segment_end",
            "turn_start",
            "turn_end",
            "speakernum",
            "speakername",
            "name_confidence",
            "text",
        ]
    )
    for segment in result["segments"]:
        turn, speaker = get_speaker(segment["start"], segment["end"])
        if turn is None:
            # BUGFIX: the original crashed (AttributeError on turn.start) for
            # segments overlapping no diarized turn; emit a placeholder row.
            w.writerow(
                [segment["start"], segment["end"], "", "", "", "-", "", segment["text"]]
            )
            continue
        name, dist = speaker_names[speaker]
        w.writerow(
            [
                segment["start"],
                segment["end"],
                turn.start,
                turn.end,
                speaker,
                name,
                1 - dist,  # turn cosine distance into a rough confidence score
                segment["text"],
            ]
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment