Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Last active September 17, 2023 17:15
Show Gist options
  • Save vadimkantorov/1bef26d20cafa648b665bff89f7a0fcb to your computer and use it in GitHub Desktop.
Save vadimkantorov/1bef26d20cafa648b665bff89f7a0fcb to your computer and use it in GitHub Desktop.
Barcode-like HTML visualization for speaker diarization / VAD outputs (supports at most two speakers + silence)
# Barcode fill colours indexed by integer speaker id. Index 0 is also where speaker_missing (0)
# lands, and fmt_img_speaker_barcode uses it as the background colour.
speaker_colors = ['gray', 'violet', 'lightblue']

# Sentinels marking absent fields in transcript entries.
ref_missing = speaker_name_missing = ''
speaker_missing = 0
channel_missing = time_missing = -1
_er_missing = -1.0  # NOTE(review): presumably the placeholder for a missing error-rate value — confirm.

# Joiners: speaker_separator is used by speaker_name(); speaker_phrase_separator's use is not visible here.
speaker_phrase_separator = ';'
speaker_separator = ', '

# '_' followed by the letters 'A'..'Z'.
default_speaker_names = '_' + ''.join(map(chr, range(ord('A'), ord('A') + 26)))

# Display names per channel index, with a dedicated entry for channel_missing.
default_channel_names = {
    channel_missing: 'channel_',
    0: 'channel0',
    1: 'channel1',
}
def sort_key(t):
    """Return the (audio_path, begin, end, channel) ordering tuple of an utterance dict.

    Absent fields come back as None.
    """
    return tuple(t.get(field) for field in ('audio_path', 'begin', 'end', 'channel'))
def group_key(t):
    """Return the grouping key of an utterance dict: its 'audio_path', or None if absent."""
    audio_path = t.get('audio_path')
    return audio_path
def speaker_name(ref = None, hyp = None):
    """Join the distinct, non-empty 'speaker_name' values found in ref and hyp.

    Args:
        ref: utterance dicts (may be None, treated as empty).
        hyp: utterance dicts (may be None, treated as empty).

    Returns:
        The names sorted and joined with speaker_separator, or None when no non-empty name exists.
    """
    # Both parameters default to None, so normalise before concatenating (None + list raises TypeError).
    utterances = list(ref or []) + list(hyp or [])
    names = {t.get('speaker_name') for t in utterances}
    return speaker_separator.join(sorted(filter(bool, names))) or None
def segment_by_time(transcript, max_segment_seconds, break_on_speaker_change = True, break_on_channel_change = True):
    """Split a transcript into consecutive chunks, yielding each non-empty chunk as a list.

    Entries whose 'begin' or 'end' equals time_missing are dropped first. A chunk is closed just
    before utterance t when any of these holds:
      * t['end'] minus the chunk's first 'begin' exceeds max_segment_seconds;
      * break_on_speaker_change is set and t['speaker'] differs from the previous utterance's;
      * break_on_channel_change is set and t['channel'] differs from the previous utterance's.
    Everything remaining is flushed at the final utterance.

    Chunks are cut with take_between(..., sort_by_time = False), which compares entries with
    sort_key. NOTE(review): this presumably requires transcript to be pre-sorted by sort_key —
    confirm with callers.
    """
    transcript = [t for t in transcript if t['begin'] != time_missing and t['end'] != time_missing]
    ind_last_taken = -1
    for j, t in enumerate(transcript):
        last = j == len(transcript) - 1
        boundary = (t['end'] - transcript[ind_last_taken + 1]['begin'] > max_segment_seconds) \
            or (break_on_speaker_change and j >= 1 and t['speaker'] != transcript[j - 1]['speaker']) \
            or (break_on_channel_change and j >= 1 and t['channel'] != transcript[j - 1]['channel'])
        if boundary:
            # Close the running chunk strictly before t; t opens the next chunk.
            ind_last_taken, transcript_segment = take_between(transcript, ind_last_taken, t, ind_last_taken == -1, False, sort_by_time = False)
            if transcript_segment:
                yield transcript_segment
        if last:
            # Flush everything not yet taken, including t. This is separate from the boundary cut so
            # that a boundary on the final utterance still separates it from the previous chunk.
            ind_last_taken, transcript_segment = take_between(transcript, ind_last_taken, t, ind_last_taken == -1, True, sort_by_time = False)
            if transcript_segment:
                yield transcript_segment
def take_between(transcript, ind_last_taken, t, first, last, sort_by_time = True, soft = True, set_speaker = False):
    """Select the entries of transcript lying between the last taken entry and the reference entry t.

    Soft mode (default) keeps entry u when both conditions hold:
      * first is truthy, ind_last_taken < 0, or transcript[ind_last_taken] comes before u;
      * last is truthy, or t comes after u.
    With sort_by_time, "a before b" means a['end'] < b['begin'] and "a after b" means
    a['end'] > b['begin']. Otherwise sort_key(a) is compared with sort_key(b).

    Hard mode (soft = False) keeps entries with index > ind_last_taken whose closed
    [begin, end] interval intersects t's. It keeps nothing when t is falsy.

    With set_speaker, every selected entry is mutated in place. Its 'speaker' becomes
    t.get('speaker', speaker_missing), and its 'speaker_name' is copied from t when that is
    not None.

    Returns:
        (index of the last selected entry, list of selected entries). When nothing is selected,
        the index is ind_last_taken unchanged and the list is empty.
    """
    if sort_by_time:
        def earlier(a, b):
            return a['end'] < b['begin']

        def later(a, b):
            return a['end'] > b['begin']
    else:
        def earlier(a, b):
            return sort_key(a) < sort_key(b)

        def later(a, b):
            return sort_key(a) > sort_key(b)

    picked = []
    if soft:
        open_start = first or ind_last_taken < 0
        for k, u in enumerate(transcript):
            if not (open_start or earlier(transcript[ind_last_taken], u)):
                continue
            if last or later(t, u):
                picked.append((k, u))
    elif t:
        for k, u in enumerate(transcript):
            if k <= ind_last_taken:
                continue
            u_begin, u_end = u['begin'], u['end']
            if u_begin <= t['end'] and t['begin'] <= u_end:
                picked.append((k, u))

    if not picked:
        return ind_last_taken, []

    if set_speaker:
        for _, u in picked:
            u['speaker'] = t.get('speaker', speaker_missing)
            if t.get('speaker_name') is not None:
                u['speaker_name'] = t['speaker_name']

    return picked[-1][0], [u for _, u in picked]
# JavaScript click handler for the <rect> elements emitted by fmt_svg_speaker_barcode (its default
# onclick attribute is 'onclick_svg(event)'). It reads data-begin / data-end from the clicked element,
# plus data-channel (falling back to 0), and forwards them to play(). play() is not defined here.
# NOTE(review): it is presumably supplied by play_script, which diarization() embeds — confirm.
onclick_svg_script = '''
function onclick_svg(evt)
{
const rect = evt.target;
const channel = rect.dataset.channel || 0;
play(evt, channel, parseFloat(rect.dataset.begin), parseFloat(rect.dataset.end));
}
'''
def fmt_svg_speaker_barcode(transcript, begin, end, colors = speaker_colors, max_segment_seconds = 60, onclick = None):
    """Render a transcript as stacked inline-SVG speaker barcodes, one strip per time segment.

    transcripts.segment_by_time splits the transcript into chunks of at most max_segment_seconds,
    ignoring speaker and channel changes. Each chunk becomes a 15px-high <div> holding an SVG with a
    unit viewbox. Every utterance of that chunk is a <rect> whose x/width are normalised by the
    chunk's duration, and chunks shorter than max_segment_seconds are still normalised by
    max_segment_seconds so all strips share one time scale.

    Args:
        transcript: utterance dicts with 'begin', 'end' and integer 'speaker'.
        begin, end: accepted for interface compatibility; not used by this renderer.
        colors: fill colours indexed by speaker id. Ids past the end of the list use the colour at
            index transcripts.speaker_missing.
        max_segment_seconds: maximum chunk length, and the minimum normalisation duration.
        onclick: value of each rect's onclick attribute; None means 'onclick_svg(event)'.

    Returns:
        HTML string with one strip per non-empty chunk.
    """
    if onclick is None:
        onclick = 'onclick_svg(event)'
    # speaker_missing is an index into colors, not a colour, so look it up rather than returning it.
    color = lambda s: colors[s] if s < len(colors) else colors[transcripts.speaker_missing]
    header = '<div style="width: 100%; height: 15px; border: 1px black solid"><svg viewbox="0 0 1 1" style="width:100%; height:100%" preserveAspectRatio="none">'
    footer = '</svg></div>'
    rect_template = '<rect data-begin="{begin}" data-end="{end}" x="{x}" width="{width}" height="1" style="fill:{color}" onclick="{onclick}"><title>speaker{speaker} | {begin:.2f} - {end:.2f} [{duration:.2f}]</title></rect>'

    html = ''
    segments = transcripts.segment_by_time(transcript, max_segment_seconds = max_segment_seconds, break_on_speaker_change = False, break_on_channel_change = False)
    for segment in segments:
        summary = transcripts.summary(segment)
        duration = max(transcripts.compute_duration(summary), max_segment_seconds)
        # Render only this segment's utterances; the template fields are passed explicitly so that
        # extra keys on an utterance cannot collide with x/width/color/onclick/duration.
        body = '\n'.join(rect_template.format(
            onclick = onclick,
            x = (t['begin'] - summary['begin']) / duration,
            width = (t['end'] - t['begin']) / duration,
            color = color(t['speaker']),
            duration = transcripts.compute_duration(t),
            speaker = t['speaker'],
            begin = t['begin'],
            end = t['end'],
        ) for t in segment)
        html += header + body + footer
    return html
def diarization(diarization_transcript, html_path, debug_audio):
    """Write an HTML report comparing reference vs hypothesis diarization for each audio file.

    The table has a header row and then a summary row: file count, total duration, and average
    ser / der / der_. After that come two rows per file, 'ref' then 'hyp', each showing a speaker
    barcode image over [0, duration of the entry]. When debug_audio is true, an audio player from
    fmt_audio is embedded once per file, in a cell spanning both rows, and the barcodes use the
    default click handler.

    Uses module-level names not visible in this excerpt: meta_charset, play_script,
    onclick_img_script, fmt_audio and the transcripts module.

    Args:
        diarization_transcript: list of per-file dicts. Reads 'audio_name', 'ser', 'der', 'der_',
            'ref', 'hyp', and 'audio_path' when debug_audio is set.
        html_path: output file path; it is overwritten.
        debug_audio: embed audio players and keep the barcodes clickable.

    Returns:
        html_path.
    """
    def avg(values):
        # NaN instead of ZeroDivisionError when the report has no files.
        return sum(values) / len(values) if values else float('nan')

    row_template = '<tr class="border-{refhyp}"><td class="nowrap">{audio_name}</td><td>{end:.02f}</td><td>{refhyp}</td><td>{ser:.02f}</td><td>{der:.02f}</td><td>{der_:.02f}</td>{audio_cell}<td>{barcode}</td></tr>\n'

    with open(html_path, 'w') as html:
        html.write('<html><head>' + meta_charset + '<style>.nowrap{white-space:nowrap} table {border-collapse:collapse} .border-hyp {border-bottom: 2px black solid}</style></head><body>\n')
        html.write(f'<script>{play_script}</script>\n')
        html.write(f'<script>{onclick_img_script}</script>')
        html.write('<table>\n')
        html.write('<tr><th>audio_name</th><th>duration</th><th>refhyp</th><th>ser</th><th>der</th><th>der_</th><th>audio</th><th>barcode</th></tr>\n')
        html.write('<tr class="border-hyp"><td>{num_files}</td><td>{total_duration:.02f}</td><td>avg</td><td>{avg_ser:.02f}</td><td>{avg_der:.02f}</td><td>{avg_der_:.02f}</td><td></td><td></td></tr>\n'.format(
            num_files = len(diarization_transcript),
            total_duration = sum(map(transcripts.compute_duration, diarization_transcript)),
            avg_ser = avg([t['ser'] for t in diarization_transcript]),
            avg_der = avg([t['der'] for t in diarization_transcript]),
            avg_der_ = avg([t['der_'] for t in diarization_transcript])
        ))
        for i, dt in enumerate(diarization_transcript):
            # NOTE(review): assumes each entry carries 'audio_path' and that the player for file i is
            # addressed as channel i, matching data-channel=i on this file's barcodes below — confirm
            # against fmt_audio / play_script. (Previously undefined names made this a NameError.)
            audio_html = fmt_audio(dt['audio_path'], channel = i) if debug_audio else ''
            begin, end = 0.0, transcripts.compute_duration(dt)
            for refhyp in ['ref', 'hyp']:
                # The audio cell spans both rows, so only the 'ref' row emits it; an extra <td> in the
                # 'hyp' row would push the barcode into a ninth column.
                audio_cell = f'<td rowspan="2">{audio_html}</td>' if refhyp == 'ref' else ''
                barcode = fmt_img_speaker_barcode(dt[refhyp], begin = begin, end = end, onclick = None if debug_audio else '', dataset = dict(channel = i))
                html.write(row_template.format(audio_name = dt['audio_name'], audio_cell = audio_cell, refhyp = refhyp, end = end, ser = dt['ser'], der = dt['der'], der_ = dt['der_'], barcode = barcode))
        html.write('</table></body></html>')
    return html_path
def fmt_img_speaker_barcode(transcript, begin = None, end = None, colors = speaker_colors, onclick = None, dataset = None):
    """Render a speaker barcode as a matplotlib-drawn JPEG embedded in an <img> data URI.

    Each utterance is drawn as a full-height band from t['begin'] to t['end'], coloured by
    t['speaker']; utterances without 'speaker' use transcripts.speaker_missing. The background is
    the colour of the missing-speaker slot.

    Args:
        transcript: list of utterance dicts with 'begin' and 'end' (iterated twice when end is None).
        begin: left edge of the x axis; defaults to 0.
        end: right edge of the x axis; defaults to the largest t['end'], or to begin when transcript is empty.
        colors: colours indexed by integer speaker id. Ids past the end of the list fall back to
            the colour at index transcripts.speaker_missing.
        onclick: value of the onclick attribute; None means 'onclick_img(event)'.
        dataset: mapping rendered as data-<key>="<value>" attributes on the <img>; None means none.

    Returns:
        An HTML <img> element string.
    """
    if begin is None:
        begin = 0
    if end is None:
        end = max((t['end'] for t in transcript), default = begin)
    if onclick is None:
        onclick = 'onclick_img(event)'
    if dataset is None:
        dataset = {}
    # speaker_missing is an index into colors, not a colour, so look it up rather than returning it.
    color = lambda s: colors[s] if s < len(colors) else colors[transcripts.speaker_missing]

    fig = plt.figure(figsize = (8, 0.2))
    try:
        plt.xlim(begin, end)
        plt.yticks([])
        plt.axis('off')
        for t in transcript:
            plt.axvspan(t['begin'], t['end'], color = color(t.get('speaker', transcripts.speaker_missing)))
        plt.subplots_adjust(left = 0, right = 1, bottom = 0, top = 1)
        buf = io.BytesIO()
        plt.savefig(buf, format = 'jpg', dpi = 150, facecolor = color(transcripts.speaker_missing))
    finally:
        # Always release the figure, even if drawing or saving fails, so figures don't accumulate.
        plt.close(fig)

    uri_speaker_barcode = base64.b64encode(buf.getvalue()).decode()
    # data-* attributes must sit outside the style attribute for element.dataset to see them.
    data_attrs = ''.join(f' data-{k}="{v}"' for k, v in dataset.items())
    return f'<img onclick="{onclick}" src="data:image/jpeg;base64,{uri_speaker_barcode}" style="width:100%"{data_attrs}></img>'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment