# VAD wrapper is taken from PyTorch Speaker Verification:
# https://github.com/HarryVolek/PyTorch_Speaker_Verification
# Copyright (c) 2019, HarryVolek
# License: BSD-3-Clause
# based on https://github.com/wiseman/py-webrtcvad/blob/master/example.py
# Copyright (c) 2016 John Wiseman
# License: MIT
# Additions Copyright (c) 2021, Jim O'Regan
# License: MIT
import collections
import contextlib
import numpy as np
import sys
import librosa
import wave
import webrtcvad
#from hparam import hparam as hp
sr = 16000
def read_wave(path, sr):
    """Reads a .wav file.
    Takes the path, and returns (float audio data, PCM audio data).
    Assumes sample width == 2.
    """
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
    data, _ = librosa.load(path, sr=sr)
    assert len(data.shape) == 1
    assert sr in (8000, 16000, 32000, 48000)
    return data, pcm_data
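
# Note: read_wave returns both a float array (from librosa, used later to
# slice out the voiced segments) and the raw 16-bit PCM bytes, because
# webrtcvad.Vad.is_speech() operates on PCM bytes rather than float samples.
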
class Frame(object):
    """Represents a "frame" of audio data."""
    def __init__(self, bytes, timestamp, duration):
        self.bytes = bytes
        self.timestamp = timestamp
        self.duration = duration
def frame_generator(frame_duration_ms, audio, sample_rate):
    """Generates audio frames from PCM audio data.
    Takes the desired frame duration in milliseconds, the PCM data, and
    the sample rate.
    Yields Frames of the requested duration.
    """
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n
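
# Note on frame_generator: n is the frame size in *bytes* (samples per frame
# times 2, since read_wave asserts a 2-byte sample width), and duration
# divides back by 2 to give the frame length in seconds.
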
def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
    """Filters out non-voiced audio frames.
    Given a webrtcvad.Vad and a source of audio frames, yields only
    the voiced audio.
    Uses a padded, sliding window algorithm over the audio frames.
    When more than 90% of the frames in the window are voiced (as
    reported by the VAD), the collector triggers and begins yielding
    audio frames. Then the collector waits until 90% of the frames in
    the window are unvoiced to detrigger.
    The window is padded at the front and back to provide a small
    amount of silence or the beginnings/endings of speech around the
    voiced frames.
    Arguments:
    sample_rate - The audio sample rate, in Hz.
    frame_duration_ms - The frame duration in milliseconds.
    padding_duration_ms - The amount to pad the window, in milliseconds.
    vad - An instance of webrtcvad.Vad.
    frames - a source of audio frames (sequence or generator).
    Returns: A generator that yields (start, end) timestamps, in seconds,
    for each voiced region.
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    # We use a deque for our sliding window/ring buffer.
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
    # NOTTRIGGERED state.
    triggered = False
    voiced_frames = []
    for frame in frames:
        is_speech = vad.is_speech(frame.bytes, sample_rate)
        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            # If we're NOTTRIGGERED and more than 90% of the frames in
            # the ring buffer are voiced frames, then enter the
            # TRIGGERED state.
            if num_voiced > 0.9 * ring_buffer.maxlen:
                triggered = True
                start = ring_buffer[0][0].timestamp
                # We want to yield all the audio we see from now until
                # we are NOTTRIGGERED, but we have to start with the
                # audio that's already in the ring buffer.
                for f, s in ring_buffer:
                    voiced_frames.append(f)
                ring_buffer.clear()
        else:
            # We're in the TRIGGERED state, so collect the audio data
            # and add it to the ring buffer.
            voiced_frames.append(frame)
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            # If more than 90% of the frames in the ring buffer are
            # unvoiced, then enter NOTTRIGGERED and yield whatever
            # audio we've collected.
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                triggered = False
                yield (start, frame.timestamp + frame.duration)
                ring_buffer.clear()
                voiced_frames = []
    # If we have any leftover voiced audio when we run out of input,
    # yield it.
    if voiced_frames:
        yield (start, frame.timestamp + frame.duration)
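
# With the parameters used below (20 ms frames, 200 ms padding), the ring
# buffer holds 10 frames, so the ">90% voiced" trigger only fires once all
# 10 frames in the window are voiced, and likewise for detriggering.
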
def VAD_chunk(aggressiveness, path):
    """Runs webrtcvad over a .wav file and splits the voiced regions into
    chunks of at most 0.4 seconds.
    Returns (speech_times, speech_segs): a list of (start, end) pairs in
    seconds and the matching slices of the float audio.
    """
    audio, byte_audio = read_wave(path, sr)
    vad = webrtcvad.Vad(int(aggressiveness))
    frames = frame_generator(20, byte_audio, sr)
    frames = list(frames)
    times = vad_collector(sr, 20, 200, vad, frames)
    speech_times = []
    speech_segs = []
    for i, time in enumerate(times):
        start = np.round(time[0], decimals=2)
        end = np.round(time[1], decimals=2)
        j = start
        while j + .4 < end:
            end_j = np.round(j + .4, decimals=2)
            speech_times.append((j, end_j))
            speech_segs.append(audio[int(j * sr):int(end_j * sr)])
            j = end_j
        else:
            # while/else: runs when the loop finishes, catching the final
            # (shorter) piece of the voiced region.
            speech_times.append((j, end))
            speech_segs.append(audio[int(j * sr):int(end * sr)])
    return speech_times, speech_segs
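
# Example usage of VAD_chunk (hypothetical file name and illustrative output):
#   times, segs = VAD_chunk(3, 'sample.wav')
#   # times -> [(0.0, 0.4), (0.4, 0.8), ...]; segs -> matching float audio slices
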
# also based on code from PyTorch Speaker Verification
# wav2vec2's max duration is 40 seconds, using 39 by default
# to be a little safer
def concat_both(times, segs, max_duration=39.0):
    """Concatenate continuous times and their segments"""
    import numpy as np
    absolute_maximum = 40.0
    if max_duration > absolute_maximum:
        raise Exception('`max_duration` {:.2f} larger than kernel size (40 seconds)'.format(max_duration))
    # we take 0.0 to mean "don't concatenate"
    do_concat = (max_duration != 0.0)
    concat_seg = []
    concat_times = []
    seg_concat = segs[0]
    time_concat = times[0]
    for i in range(0, len(times) - 1):
        can_concat = (times[i + 1][1] - time_concat[0]) < max_duration
        if time_concat[1] == times[i + 1][0] and do_concat and can_concat:
            seg_concat = np.concatenate((seg_concat, segs[i + 1]))
            time_concat = (time_concat[0], times[i + 1][1])
        else:
            concat_seg.append(seg_concat)
            seg_concat = segs[i + 1]
            concat_times.append(time_concat)
            time_concat = times[i + 1]
    else:
        # for/else: always runs after the loop, appending the last
        # accumulated segment and its time span.
        concat_seg.append(seg_concat)
        concat_times.append(time_concat)
    return concat_times, concat_seg
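
# Merging behaviour of concat_both: chunks (0.0, 0.4) and (0.4, 0.8) share a
# boundary, so they are joined into (0.0, 0.8); a chunk that does not start
# exactly where the previous one ended, or that would push the combined span
# past max_duration, starts a new concatenated segment.
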
def make_dataset(concat_times, concat_segs):
    starts = [s[0] for s in concat_times]
    ends = [s[1] for s in concat_times]
    return {'start': starts,
            'end': ends,
            'speech': concat_segs}
from datasets import Dataset
def vad_to_dataset(path, max_duration):
    t, s = VAD_chunk(3, path)
    if max_duration > 0.0:
        ct, cs = concat_both(t, s, max_duration)
        dset = make_dataset(ct, cs)
    else:
        dset = make_dataset(t, s)
    return Dataset.from_dict(dset)
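
# A minimal usage sketch (hypothetical file name): vad_to_dataset('sample.wav', 39.0)
# returns a datasets.Dataset with 'start', 'end' and 'speech' columns;
# passing max_duration=0.0 keeps the 0.4 s chunks unmerged.
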
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# load model and tokenizer
processor = Wav2Vec2Processor.from_pretrained("mbien/wav2vec2-large-xlsr-polish")
model = Wav2Vec2ForCTC.from_pretrained("mbien/wav2vec2-large-xlsr-polish")
model.to("cuda")
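
# Assumes a CUDA-capable GPU is available; to run on CPU instead, drop the
# .to("cuda") calls here and in evaluate() below (much slower).
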
def speech_file_to_array_fn(batch):
    import torchaudio
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = speech_array[0].numpy()
    batch["sampling_rate"] = sampling_rate
    batch["target_text"] = batch["sentence"]
    return batch
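
# Note: speech_file_to_array_fn appears unused in this script; it looks like a
# leftover from the standard wav2vec2 evaluation recipe, which loads audio and
# reference sentences from a dataset with 'path' and 'sentence' columns.
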
def evaluate(batch):
    import torch
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch
def process_wave(filename, duration):
    import json
    dataset = vad_to_dataset(filename, duration)
    result = dataset.map(evaluate, batched=True, batch_size=16)
    speechless = result.remove_columns(['speech'])
    d = speechless.to_dict()
    tlog = list()
    # iterate over every segment, including the last one
    for i in range(0, len(d['end'])):
        out = dict()
        out['start'] = d['start'][i]
        out['end'] = d['end'][i]
        out['transcript'] = d['pred_strings'][i]
        tlog.append(out)
    with open('{}.tlog'.format(filename), 'w') as outfile:
        json.dump(tlog, outfile)
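
# The .tlog written above is a JSON list of {"start", "end", "transcript"}
# entries, one per (possibly concatenated) voiced segment.
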
import glob
for f in glob.glob('./*.wav'):
    print(f)
    process_wave(f, 10.0)
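
# Running the script transcribes every .wav in the current directory, merging
# voiced chunks into segments of up to 10 seconds, and writes <name>.wav.tlog
# alongside each input file.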