Last active
September 17, 2023 17:15
-
-
Save vadimkantorov/1bef26d20cafa648b665bff89f7a0fcb to your computer and use it in GitHub Desktop.
Barcode-like HTML visualization for speaker diarization / VAD outputs (supports at most two speakers + silence)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Color palette indexed by integer speaker id; index 0 ('gray') doubles as the
# silence / missing-speaker color (supports at most two real speakers).
speaker_colors = ['gray', 'violet', 'lightblue']
# Sentinel values used throughout for absent fields in transcript dicts.
ref_missing = ''
speaker_name_missing = ''
speaker_missing = 0
speaker_phrase_separator = ';'
speaker_separator = ', '
channel_missing = -1
time_missing = -1
_er_missing = -1.0
# '_' for speaker id 0 (missing), then 'A'..'Z' for speaker ids 1..26.
default_speaker_names = '_' + ''.join(chr(ord('A') + i) for i in range(26))
default_channel_names = {channel_missing : 'channel_', 0 : 'channel0', 1 : 'channel1'}
def sort_key(t):
	"""Canonical ordering key for a transcript entry: (audio_path, begin, end, channel); absent fields sort as None."""
	return tuple(t.get(field) for field in ('audio_path', 'begin', 'end', 'channel'))
def group_key(t):
	"""Grouping key for transcript entries: the source audio path (None when absent)."""
	return t['audio_path'] if 'audio_path' in t else None
def speaker_name(ref = None, hyp = None):
	"""Return a combined, sorted speaker-name string for ref + hyp entries.

	ref, hyp: lists of transcript dicts (each may carry a 'speaker_name' key);
	None is treated as an empty list. Falsy names ('' / None) are dropped.
	Returns the unique names joined by speaker_separator, or None if none remain.
	"""
	# BUGFIX: the original defaults were None but the body did `ref + hyp`,
	# raising TypeError when either argument was omitted.
	ref = ref if ref is not None else []
	hyp = hyp if hyp is not None else []
	names = set(t.get('speaker_name') for t in ref) | set(t.get('speaker_name') for t in hyp)
	return speaker_separator.join(sorted(filter(bool, names))) or None
def segment_by_time(transcript, max_segment_seconds, break_on_speaker_change = True, break_on_channel_change = True):
	"""Lazily split a time-ordered transcript into segments of roughly max_segment_seconds.

	Yields lists of consecutive transcript entries. A segment boundary is placed
	when the running span would exceed max_segment_seconds, at the final entry,
	or (optionally) when the speaker / channel changes between adjacent entries.
	Entries with unknown timing (begin or end == time_missing) are dropped first.
	"""
	# entries without timestamps cannot be placed on the time axis
	transcript = [t for t in transcript if t['begin'] != time_missing and t['end'] != time_missing]
	# index of the last entry already emitted in a previous segment (-1 = none yet)
	ind_last_taken = -1
	for j, t in enumerate(transcript):
		first, last = ind_last_taken == -1, j == len(transcript) - 1
		# break when: final entry, or span since the segment start exceeds the budget,
		# or the speaker / channel changed relative to the previous entry
		if last or (t['end'] - transcript[ind_last_taken + 1]['begin'] > max_segment_seconds) \
			or (break_on_speaker_change and j >= 1 and t['speaker'] != transcript[j - 1]['speaker']) \
			or (break_on_channel_change and j >= 1 and t['channel'] != transcript[j - 1]['channel']):
			# take_between returns the new high-water index and the entries in between
			ind_last_taken, transcript_segment = take_between(transcript, ind_last_taken, t, first, last, sort_by_time=False)
			if transcript_segment:
				yield transcript_segment
def take_between(transcript, ind_last_taken, t, first, last, sort_by_time = True, soft = True, set_speaker = False):
	"""Select the transcript entries lying between the last-taken entry and *t*.

	ind_last_taken: index of the last entry consumed by a previous call (-1 = none).
	first / last: flags relaxing the lower / upper bound at the extremes.
	sort_by_time: compare entries by (end, begin) time overlap instead of sort_key order.
	soft: use ordering comparisons; otherwise take entries time-intersecting *t*.
	set_speaker: copy t's speaker (and speaker_name, if present) onto taken entries in place.
	Returns (index of last taken entry, list of taken entries).
	"""
	if sort_by_time:
		# strictly-before / strictly-after in wall-clock terms
		lt = lambda a, b: a['end'] < b['begin']
		gt = lambda a, b: a['end'] > b['begin']
	else:
		lt = lambda a, b: sort_key(a) < sort_key(b)
		gt = lambda a, b: sort_key(a) > sort_key(b)
	if soft:
		# keep entries after the last-taken one and before t; the first/last flags
		# drop the corresponding bound at the segment extremes
		res = [(k, u) for k, u in enumerate(transcript) if (first or ind_last_taken < 0 or lt(transcript[ind_last_taken], u)) and (last or gt(t, u))]
	else:
		intersects = lambda t, begin, end: (begin <= t['end'] and t['begin'] <= end)
		res = [(k, u) for k, u in enumerate(transcript) if ind_last_taken < k and intersects(t, u['begin'], u['end'])] if t else []
	# unzip into (indices, entries); keep the previous index when nothing was taken
	ind_last_taken, transcript = zip(*res) if res else ([ind_last_taken], [])
	if set_speaker:
		# NOTE: mutates the taken entries in place
		for u in transcript:
			u['speaker'] = t.get('speaker', speaker_missing)
			if t.get('speaker_name') is not None:
				u['speaker_name'] = t['speaker_name']
	return ind_last_taken[-1], list(transcript)
# JS handler injected into the generated HTML: clicking an svg <rect> plays the
# [data-begin, data-end] interval on data-channel via a page-provided play() function.
onclick_svg_script = '''
function onclick_svg(evt)
{
	const rect = evt.target;
	const channel = rect.dataset.channel || 0;
	play(evt, channel, parseFloat(rect.dataset.begin), parseFloat(rect.dataset.end));
}
'''
def fmt_svg_speaker_barcode(transcript, begin, end, colors = speaker_colors, max_segment_seconds = 60, onclick = None):
	"""Render a transcript as one clickable SVG barcode strip per time segment.

	transcript: list of dicts with 'begin', 'end', 'speaker' keys.
	begin, end: overall time window (kept for interface compatibility).
	colors: palette indexed by integer speaker id.
	max_segment_seconds: split the transcript into strips of at most this span.
	onclick: JS onclick expression for each rect (default: onclick_svg(event)).
	Returns an HTML string of stacked <div><svg>...</svg></div> strips.
	"""
	if onclick is None:
		onclick = 'onclick_svg(event)'
	# BUGFIX: out-of-palette speaker ids now fall back to the missing-speaker
	# COLOR; the original returned the integer transcripts.speaker_missing,
	# which is not a valid CSS fill value
	color = lambda s: colors[s] if s < len(colors) else colors[transcripts.speaker_missing]
	html = ''
	segments = transcripts.segment_by_time(transcript, max_segment_seconds = max_segment_seconds, break_on_speaker_change = False, break_on_channel_change = False)
	for segment in segments:
		summary = transcripts.summary(segment)
		duration = transcripts.compute_duration(summary)
		# normalize short segments to a common time scale so strips align
		if duration <= max_segment_seconds:
			duration = max_segment_seconds
		header = '<div style="width: 100%; height: 15px; border: 1px black solid"><svg viewbox="0 0 1 1" style="width:100%; height:100%" preserveAspectRatio="none">'
		# BUGFIX: draw only this segment's entries; the original iterated the
		# whole transcript inside the per-segment loop, emitting every rect
		# (many with out-of-range x) into every strip
		body = '\n'.join('<rect data-begin="{begin}" data-end="{end}" x="{x}" width="{width}" height="1" style="fill:{color}" onclick="{onclick}"><title>speaker{speaker} | {begin:.2f} - {end:.2f} [{duration:.2f}]</title></rect>'.format(onclick = onclick, x = (t['begin'] - summary['begin']) / duration, width = (t['end'] - t['begin']) / duration, color = color(t['speaker']), duration = transcripts.compute_duration(t), **t) for t in segment)
		footer = '</svg></div>'
		html += header + body + footer
	return html
def diarization(diarization_transcript, html_path, debug_audio):
	"""Write an HTML report comparing ref vs hyp diarization per audio file.

	diarization_transcript: list of dicts with 'audio_name', 'ser', 'der', 'der_'
	and 'ref'/'hyp' transcripts. html_path: output file. debug_audio: embed
	playable audio next to each barcode. Returns html_path.
	"""
	with open(html_path, 'w') as html:
		html.write('<html><head>' + meta_charset + '<style>.nowrap{white-space:nowrap} table {border-collapse:collapse} .border-hyp {border-bottom: 2px black solid}</style></head><body>\n')
		html.write(f'<script>{play_script}</script>\n')
		html.write(f'<script>{onclick_img_script}</script>')
		html.write('<table>\n')
		html.write('<tr><th>audio_name</th><th>duration</th><th>refhyp</th><th>ser</th><th>der</th><th>der_</th><th>audio</th><th>barcode</th></tr>\n')
		avg = lambda l: sum(l) / len(l)
		# summary row: file count, total duration, and averaged error rates
		html.write('<tr class="border-hyp"><td>{num_files}</td><td>{total_duration:.02f}</td><td>avg</td><td>{avg_ser:.02f}</td><td>{avg_der:.02f}</td><td>{avg_der_:.02f}</td><td></td><td></td></tr>\n'.format(
			num_files = len(diarization_transcript),
			total_duration = sum(map(transcripts.compute_duration, diarization_transcript)),
			avg_ser = avg([t['ser'] for t in diarization_transcript]),
			avg_der = avg([t['der'] for t in diarization_transcript]),
			avg_der_ = avg([t['der_'] for t in diarization_transcript])
		))
		for i, dt in enumerate(diarization_transcript):
			# BUGFIX: the original referenced undefined names `audio_path` and
			# `channel` here (NameError whenever debug_audio was set); the entry's
			# audio_path and the file index (matching dataset channel below) look
			# intended -- TODO confirm against fmt_audio's signature
			audio_html = fmt_audio(dt['audio_path'], channel = i) if debug_audio else ''
			begin, end = 0.0, transcripts.compute_duration(dt)
			for refhyp in ['ref', 'hyp']:
				# BUGFIX: 'rospan' -> 'rowspan' so the audio cell spans both the ref and hyp rows
				html.write('<tr class="border-{refhyp}"><td class="nowrap">{audio_name}</td><td>{end:.02f}</td><td>{refhyp}</td><td>{ser:.02f}</td><td>{der:.02f}</td><td>{der_:.02f}</td><td rowspan="{rowspan}">{audio_html}</td><td>{barcode}</td></tr>\n'.format(audio_name = dt['audio_name'], audio_html = audio_html if refhyp == 'ref' else '', rowspan = 2 if refhyp == 'ref' else 1, refhyp = refhyp, end = end, ser = dt['ser'], der = dt['der'], der_ = dt['der_'], barcode = fmt_img_speaker_barcode(dt[refhyp], begin = begin, end = end, onclick = None if debug_audio else '', dataset = dict(channel = i))))
		html.write('</table></body></html>')
	return html_path
def fmt_img_speaker_barcode(transcript, begin = None, end = None, colors = speaker_colors, onclick = None, dataset = None):
	"""Render speaker turns as a base64-JPEG barcode wrapped in an <img> tag.

	transcript: list of dicts with 'begin', 'end' and optionally 'speaker'.
	begin, end: plotted time window (defaults: 0 and the max entry end time).
	colors: palette indexed by integer speaker id.
	onclick: JS onclick expression (default: onclick_img(event)).
	dataset: mapping rendered as data-* attributes on the <img> tag.
	"""
	if begin is None:
		begin = 0
	if end is None:
		# robust to an empty transcript (bare max() raised ValueError)
		end = max((t['end'] for t in transcript), default = 0)
	if onclick is None:
		onclick = 'onclick_img(event)'
	if dataset is None:
		# BUGFIX: was a mutable default argument ({}), shared across calls
		dataset = {}
	# BUGFIX: out-of-palette speaker ids fall back to the missing-speaker COLOR;
	# the original returned the integer transcripts.speaker_missing, which
	# matplotlib rejects as a color
	color = lambda s: colors[s] if s < len(colors) else colors[transcripts.speaker_missing]
	plt.figure(figsize = (8, 0.2))
	plt.xlim(begin, end)
	plt.yticks([])
	plt.axis('off')
	for t in transcript:
		plt.axvspan(t['begin'], t['end'], color = color(t.get('speaker', transcripts.speaker_missing)))
	# stretch the axes to fill the whole figure so the barcode has no margins
	plt.subplots_adjust(left = 0, right = 1, bottom = 0, top = 1)
	buf = io.BytesIO()
	plt.savefig(buf, format = 'jpg', dpi = 150, facecolor = color(transcripts.speaker_missing))
	plt.close()
	uri_speaker_barcode = base64.b64encode(buf.getvalue()).decode()
	data_attrs = ' '.join(f'data-{k}="{v}"' for k, v in dataset.items())
	# BUGFIX: data-* attributes were emitted INSIDE the style="..." quoted value,
	# so they never reached the DOM (breaking onclick_img's dataset lookups);
	# also dropped the invalid </img> closing tag (<img> is a void element)
	return f'<img onclick="{onclick}" src="data:image/jpeg;base64,{uri_speaker_barcode}" style="width:100%" {data_attrs}>'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment