Skip to content

Instantly share code, notes, and snippets.

@jimregan
Last active April 3, 2021 00:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jimregan/e2b7666c5c6397d7939b0ca97354f314 to your computer and use it in GitHub Desktop.
Save jimregan/e2b7666c5c6397d7939b0ca97354f314 to your computer and use it in GitHub Desktop.
# VAD wrapper is taken from PyTorch Speaker Verification:
# https://github.com/HarryVolek/PyTorch_Speaker_Verification
# Copyright (c) 2019, HarryVolek
# License: BSD-3-Clause
# based on https://github.com/wiseman/py-webrtcvad/blob/master/example.py
# Copyright (c) 2016 John Wiseman
# License: MIT
# Additions Copyright (c) 2021, Jim O'Regan
# License: MIT
import collections
import contextlib
import numpy as np
import sys
import librosa
import wave
import webrtcvad
#from hparam import hparam as hp
# Sample rate in Hz. VAD_chunk passes it to read_wave (the rate librosa loads
# the audio at), to frame_generator and to vad_collector.
sr = 16000
def read_wave(path, sr):
    """Read a mono, 16-bit .wav file in two representations.

    Arguments:
    path - path of the .wav file.
    sr - sample rate (Hz) passed to librosa.load; must be one of
         8000, 16000, 32000 or 48000.

    Returns: (data, pcm_data). `data` is the 1-D array produced by
    librosa.load at rate `sr`; `pcm_data` is the raw frame bytes read
    through the wave module.

    Raises AssertionError when the file is not mono, is not 16-bit, or
    when the file's own rate or `sr` is not one of the supported rates.

    NOTE(review): `pcm_data` is read at the file's own sample rate and is
    never resampled, so it only lines up with `data` when the file's rate
    equals `sr` - confirm that callers only pass such files.
    """
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
    # `sr` is passed by keyword: librosa 0.10+ only accepts it that way,
    # and older releases accept the keyword form as well.
    data, _ = librosa.load(path, sr=sr)
    assert len(data.shape) == 1
    assert sr in (8000, 16000, 32000, 48000)
    return data, pcm_data
class Frame:
    """One slice of audio together with its position in the stream.

    Attributes:
    bytes - the audio payload of the slice.
    timestamp - where the slice starts.
    duration - how long the slice lasts.
    """

    def __init__(self, bytes, timestamp, duration):
        # `bytes` shadows the builtin inside this method only; the name is
        # kept because it is part of the constructor's keyword interface.
        self.bytes, self.timestamp, self.duration = bytes, timestamp, duration
def frame_generator(frame_duration_ms, audio, sample_rate):
    """Split PCM audio into consecutive, equally sized frames.

    Arguments:
    frame_duration_ms - the desired frame duration in milliseconds.
    audio - the PCM data (a sliceable byte sequence).
    sample_rate - the audio sample rate, in Hz.

    Yields Frame objects of the requested duration, each carrying its
    start time and duration in seconds. A trailing piece shorter than
    one frame is dropped.

    Raises ValueError (when iteration starts) if the frame size works out
    to zero bytes or less.
    """
    # Frame size in bytes: the factor 2 is the sample width of 16-bit PCM.
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    if n <= 0:
        # A zero-length frame would never advance through the audio.
        raise ValueError(
            'frame size must be positive, got {} bytes'.format(n))
    duration = (float(n) / sample_rate) / 2.0
    timestamp = 0.0
    # Stop at the last offset that still leaves a whole frame, so audio
    # whose length is an exact multiple of `n` keeps its final frame.
    for offset in range(0, len(audio) - n + 1, n):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
    """Locate the voiced stretches in a stream of audio frames.

    Uses a padded, sliding window algorithm over the audio frames.
    When more than 90% of the frames in the window are voiced (as
    reported by the VAD), the collector triggers. It then waits until
    more than 90% of the frames in the window are unvoiced to detrigger.
    Because the window is only emptied on each state change, a reported
    stretch includes a little audio before and after the voiced frames.

    Arguments:
    sample_rate - The audio sample rate, in Hz.
    frame_duration_ms - The frame duration in milliseconds.
    padding_duration_ms - The amount to pad the window, in milliseconds.
    vad - An object with an is_speech(bytes, sample_rate) method, such
          as an instance of webrtcvad.Vad.
    frames - a source of audio frames (sequence or generator); each
             frame needs `bytes`, `timestamp` and `duration` attributes.

    Returns: A generator that yields (start, end) tuples. `start` is the
    timestamp of the oldest frame in the window at the moment of
    triggering; `end` is timestamp + duration of the frame that caused
    the detrigger, or of the last frame when the input runs out while
    still triggered.
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    # We use a deque for our sliding window/ring buffer.
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    # Both state changes need strictly more than this many matching
    # frames in the window.
    threshold = 0.9 * num_padding_frames
    # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
    # NOTTRIGGERED state.
    triggered = False
    start = None
    frame = None
    for frame in frames:
        is_speech = vad.is_speech(frame.bytes, sample_rate)
        ring_buffer.append((frame, is_speech))
        if not triggered:
            num_voiced = sum(1 for _, speech in ring_buffer if speech)
            if num_voiced > threshold:
                # Enter TRIGGERED; the stretch starts with the audio
                # that is already in the ring buffer.
                triggered = True
                start = ring_buffer[0][0].timestamp
                ring_buffer.clear()
        else:
            num_unvoiced = sum(1 for _, speech in ring_buffer if not speech)
            if num_unvoiced > threshold:
                # Enter NOTTRIGGERED and report the finished stretch.
                triggered = False
                yield (start, frame.timestamp + frame.duration)
                ring_buffer.clear()
    # If the input ran out in the middle of a voiced stretch, report it
    # up to the end of the last frame seen.
    if triggered:
        yield (start, frame.timestamp + frame.duration)
def VAD_chunk(aggressiveness, path):
    """Cut the voiced parts of a .wav file into chunks of at most 0.4 s.

    Arguments:
    aggressiveness - value handed to webrtcvad.Vad after int() conversion.
    path - path of the .wav file, read through read_wave.

    Returns: (speech_times, speech_segs). `speech_times` is a list of
    (start, end) pairs rounded to two decimals; `speech_segs` holds the
    matching slices of the audio array loaded at the module-level `sr`.

    NOTE(review): the VAD is run on the raw bytes returned by read_wave
    while timestamps are computed with `sr`; this presumably assumes the
    file itself is sampled at `sr` - confirm.
    """
    frame_ms = 20
    padding_ms = 200
    audio, byte_audio = read_wave(path, sr)
    vad = webrtcvad.Vad(int(aggressiveness))
    frames = list(frame_generator(frame_ms, byte_audio, sr))
    speech_times = []
    speech_segs = []
    for span in vad_collector(sr, frame_ms, padding_ms, vad, frames):
        chunk_start = np.round(span[0], decimals=2)
        span_end = np.round(span[1], decimals=2)
        # Peel off full 0.4 s chunks while more than 0.4 s remains.
        while chunk_start + .4 < span_end:
            chunk_end = np.round(chunk_start + .4, decimals=2)
            speech_times.append((chunk_start, chunk_end))
            speech_segs.append(audio[int(chunk_start * sr):int(chunk_end * sr)])
            chunk_start = chunk_end
        # Whatever is left of the stretch becomes the final chunk.
        speech_times.append((chunk_start, span_end))
        speech_segs.append(audio[int(chunk_start * sr):int(span_end * sr)])
    return speech_times, speech_segs
# also based on code from PyTorch Speaker Verification
# wav2vec2's max duration is 40 seconds, using 39 by default
# to be a little safer
def concat_both(times, segs, max_duration=39.0):
    """Concatenate continuous times and their segments.

    Neighbouring entries are merged when the next one starts exactly
    where the current run ends and the merged run would stay shorter
    than `max_duration`.

    Arguments:
    times - sequence of (start, end) pairs, in order.
    segs - sequence of audio segments, one per entry of `times`.
    max_duration - upper bound for a merged run; 0.0 means "don't
                   concatenate". May not exceed 40.0.

    Returns: (concat_times, concat_seg), two lists of equal length.
    Both are empty when `times` is empty.

    Raises ValueError when `max_duration` is larger than 40.0.
    """
    absolute_maximum = 40.0
    if max_duration > absolute_maximum:
        raise ValueError('`max_duration` {:.2f} larger than kernel size (40 seconds)'.format(max_duration))
    # Nothing to merge, e.g. when the VAD found no speech at all.
    if len(times) == 0:
        return [], []
    # we take 0.0 to mean "don't concatenate"
    do_concat = (max_duration != 0.0)
    concat_seg = []
    concat_times = []
    seg_concat = segs[0]
    time_concat = times[0]
    for i in range(1, len(times)):
        next_time = times[i]
        contiguous = time_concat[1] == next_time[0]
        fits = (next_time[1] - time_concat[0]) < max_duration
        if contiguous and do_concat and fits:
            # Extend the current run with the next segment.
            seg_concat = np.concatenate((seg_concat, segs[i]))
            time_concat = (time_concat[0], next_time[1])
        else:
            # Close the current run and start a new one.
            concat_seg.append(seg_concat)
            concat_times.append(time_concat)
            seg_concat = segs[i]
            time_concat = next_time
    # The run still open after the loop is always emitted.
    concat_seg.append(seg_concat)
    concat_times.append(time_concat)
    return concat_times, concat_seg
def make_dataset(concat_times, concat_segs):
    """Arrange times and segments as a column-oriented dict.

    Arguments:
    concat_times - sequence of items whose [0] is a start and [1] an end.
    concat_segs - the segments, stored unchanged under 'speech'.

    Returns: a dict with the keys 'start', 'end' and 'speech'.
    """
    columns = {'start': [], 'end': [], 'speech': concat_segs}
    for span in concat_times:
        columns['start'].append(span[0])
        columns['end'].append(span[1])
    return columns
from datasets import Dataset
def vad_to_dataset(path, max_duration):
    """Build a Dataset of speech chunks from one .wav file.

    The file is chunked with VAD_chunk at aggressiveness 3. When
    `max_duration` is greater than 0.0 the chunks are first merged with
    concat_both; otherwise they are used as they are.

    Returns: the result of Dataset.from_dict on the 'start', 'end' and
    'speech' columns built by make_dataset.
    """
    times, segments = VAD_chunk(3, path)
    if max_duration > 0.0:
        times, segments = concat_both(times, segments, max_duration)
    return Dataset.from_dict(make_dataset(times, segments))
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# Load the processor and the CTC model once, when the module is executed;
# `evaluate` reads both as module-level globals.
processor = Wav2Vec2Processor.from_pretrained("mbien/wav2vec2-large-xlsr-polish")
model = Wav2Vec2ForCTC.from_pretrained("mbien/wav2vec2-large-xlsr-polish")
# The model is moved to the "cuda" device, matching `evaluate`, which sends
# its inputs to "cuda" as well.
model.to("cuda")
def speech_file_to_array_fn(batch):
    """Load the audio named by batch["path"] into the batch.

    Adds three keys to `batch` and returns it:
    "speech" - row 0 of the tensor returned by torchaudio.load, as a
               NumPy array (presumably the first channel - confirm).
    "sampling_rate" - the rate reported by torchaudio.load.
    "target_text" - a copy of batch["sentence"].
    """
    import torchaudio
    waveform, rate = torchaudio.load(batch["path"])
    first_row = waveform[0]
    batch["speech"] = first_row.numpy()
    batch["sampling_rate"] = rate
    batch["target_text"] = batch["sentence"]
    return batch
def evaluate(batch):
    """Transcribe batch["speech"] with the module-level model.

    The speech is run through the module-level `processor` (16 kHz,
    padded, PyTorch tensors), the model is evaluated on "cuda" without
    gradient tracking, and the arg-max token ids are decoded.

    Returns: `batch`, with the decoded strings stored under
    "pred_strings".
    """
    import torch
    features = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
    input_values = features.input_values.to("cuda")
    attention_mask = features.attention_mask.to("cuda")
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    # Most likely token at every position along the last axis.
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch
def process_wave(filename, duration):
    """Transcribe one .wav file and write the result next to it.

    Arguments:
    filename - path of the .wav file.
    duration - passed to vad_to_dataset as its max_duration argument.

    Writes '<filename>.tlog', a JSON list with one object per segment,
    each holding 'start', 'end' and 'transcript'. Returns None.
    """
    import json
    dataset = vad_to_dataset(filename, duration)
    result = dataset.map(evaluate, batched=True, batch_size=16)
    # The audio itself is not wanted in the log, only times and text.
    columns = result.remove_columns(['speech']).to_dict()
    # Every segment is logged, including the last one.
    tlog = [
        {'start': start, 'end': end, 'transcript': transcript}
        for start, end, transcript in zip(
            columns['start'], columns['end'], columns['pred_strings'])
    ]
    with open('{}.tlog'.format(filename), 'w') as outfile:
        json.dump(tlog, outfile)
import glob
# Transcribe every .wav file in the current directory; 10.0 is the
# duration argument handed to process_wave for each file.
for wav_path in glob.glob('./*.wav'):
    print(wav_path)
    process_wave(wav_path, 10.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment