@wesbz
Created June 1, 2022 23:06
A way to run the harvard-edge/multilingual_kws keyword-spotting code in streaming mode, on live microphone input.
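Setup sketch (an assumption based on the imports below, not part of the original gist): the multilingual_kws package from https://github.com/harvard-edge/multilingual_kws must be importable, e.g. by cloning the repository and putting its parent directory on PYTHONPATH, and the script additionally needs the tensorflow, numpy, pyaudio, and fire packages. Something like:

$ git clone https://github.com/harvard-edge/multilingual_kws
$ pip install tensorflow numpy pyaudio fire
$ export PYTHONPATH=$PWD:$PYTHONPATH
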
import os
import fire
import time
import pyaudio
import tensorflow as tf
import numpy as np
from multilingual_kws.embedding import input_data
from multilingual_kws.embedding import batch_streaming_analysis as sa
from multilingual_kws.embedding.single_target_recognize_commands import (
    SingleTargetRecognizeCommands,
    RecognizeResult,
)

# Don't use GPU, it's much faster on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

SAMPLING_RATE = 16000
FORMAT = pyaudio.paInt16


def inference(
    keyword: str,
    modelpath: str,
    detection_threshold: float = 0.9,
):
"""
Runs inference on a streaming audio file. Example invocation:
$ python -m embedding.run_inference --keyword mask --modelpath mask_model --wav mask_radio.wav
Args
keyword: target keyword for few-shot KWS (pass in as [word1, word2, word3])
modelpath: comma-demlimited list of paths to finetuned few-shot models
detection_threshold: confidence threshold for inference (default=0.9)
"""
    assert os.path.exists(modelpath), f"{modelpath} inference model not found"
    assert 0 < detection_threshold < 1, "detection_threshold must be between 0 and 1."
    print(f"Performing inference using detection threshold {detection_threshold}")

    flags = sa.StreamFlags(
        wav=None,
        ground_truth=None,
        target_keyword=keyword,
        detection_thresholds=[detection_threshold],
        average_window_duration_ms=200,
        suppression_ms=500,
        clip_stride_ms=20,
    )

    start_load_model = time.time()
    # Load the TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=modelpath)
    interpreter.allocate_tensors()
    # Get input and output tensor details.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(f"loaded model in {time.time() - start_load_model} seconds")

    model_settings = input_data.standard_microspeech_model_settings(label_count=3)

    # The next two components are used to compute the average score.
    recognize_element = RecognizeResult()
    recognize_commands = SingleTargetRecognizeCommands(
        labels=flags.labels(),
        average_window_duration_ms=flags.average_window_duration_ms,
        detection_threshold=detection_threshold,
        suppression_ms=flags.suppression_ms,
        minimum_count=flags.minimum_count,
        target_id=2,
    )

    class AudioHandler(object):
        def __init__(self, rate=SAMPLING_RATE, buffer_size=SAMPLING_RATE // 10, format_=FORMAT, channel=1) -> None:
            super().__init__()
            self.format = format_
            self.channel = channel
            self.rate = rate
            self.buffer_size = buffer_size
            # Rolling model input: a full-length spectrogram updated on every callback.
            self.spectrogram_buffer = np.zeros((
                model_settings["spectrogram_length"],
                model_settings["fingerprint_width"]
            ))
            # One second of raw audio, used to compute the next spectrogram chunk.
            self.audio_buffer = np.zeros((self.rate, channel))
            self.win_size = model_settings["window_size_samples"]
            self.win_stride = model_settings["window_stride_samples"]
            self.p = None
            self.stream = None
            self.current_time_ms = 0
            # These are used to compute the average duration of the callback function.
            self.time = time.time()
            self.time_samples = 10
            self.time_avg = [0.0 for _ in range(self.time_samples)]

        def start(self):
            self.p = pyaudio.PyAudio()
            self.stream = self.p.open(
                rate=self.rate,
                channels=self.channel,
                format=self.format,
                input=True,
                stream_callback=self.callback,
                frames_per_buffer=self.buffer_size
            )

        def stop(self):
            self.stream.close()
            self.p.terminate()

        def callback(self, in_data, frame_count, time_info, flag):
            # Compute the callback's average duration.
            end = time.time()
            self.time_avg.pop(0)
            self.time_avg.append(end - self.time)
            print(f"t_callback = {1000 * sum(self.time_avg) / self.time_samples:.2f}ms", end="\t")
            self.time = end

            # Load the captured audio data.
            audio = np.frombuffer(in_data, dtype=np.int16)
            # The normalization factor may need adjusting, since it affects inference: use 2**15 or 2**16.
            audio = (audio / 2**16).astype(np.float32)

            # Keep only the samples not yet consumed by a complete analysis window,
            # then append the new audio data to the buffer.
            recycle_idx = (len(self.audio_buffer) - (self.win_size - self.win_stride)) // self.win_stride * self.win_stride
            self.audio_buffer = np.concatenate(
                [
                    self.audio_buffer[recycle_idx:],
                    audio.reshape(-1, 1)
                ],
                axis=0
            )

            # Compute the next spectrogram chunk.
            spectrogram = input_data.to_micro_spectrogram(model_settings, self.audio_buffer)
            # Concatenate with the previous spectrogram chunks to form a full spectrogram.
            self.spectrogram_buffer = np.concatenate(
                [self.spectrogram_buffer, spectrogram], axis=0
            )[-model_settings["spectrogram_length"]:, :]

            # Feed the spectrogram to the model as input.
            interpreter.set_tensor(
                input_details[0]["index"],
                tf.cast(self.spectrogram_buffer[np.newaxis, :, :, np.newaxis], dtype=tf.float32)
            )
            # Run inference.
            interpreter.invoke()
            inferences = interpreter.get_tensor(output_details[0]["index"])

            # How long, in ms, we have been processing.
            self.current_time_ms = self.current_time_ms + frame_count * 1000 / self.rate
            # In case you want to print the inference results:
            # print(self.current_time_ms, inferences[0], end="\t")

            recognize_commands.process_latest_result(
                inferences[0], self.current_time_ms, recognize_element
            )
            if (
                recognize_element.is_new_command
                and recognize_element.found_command != "_silence_"
            ):
                print("\033[92mDETECTED\033[0m")
            else:
                print("\033[91mNOT DETECTED\033[0m")
            return in_data, pyaudio.paContinue

    ah = AudioHandler(
        rate=SAMPLING_RATE,
        buffer_size=SAMPLING_RATE // 20,
        format_=FORMAT,
        channel=1
    )
    ah.start()
    print("\tRecording")
    # Press ENTER to stop the program.
    input("Enter to stop\n")
    print("\tStopped")
    ah.stop()
    tf.keras.backend.clear_session()
    return


if __name__ == "__main__":
    fire.Fire(dict(inference=inference))
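
Usage sketch (filenames are illustrative, not from the original gist): the entry point is exposed through python-fire as the inference command, so after saving the script as, say, streaming_inference.py and pointing it at a fine-tuned few-shot TFLite model, it can be run as:

$ python streaming_inference.py inference --keyword mask --modelpath mask_model.tflite --detection_threshold 0.9

The script then opens the default microphone through PyAudio and prints DETECTED / NOT DETECTED on every callback until ENTER is pressed.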