A way to run the harvard-edge/multilingual_kws code in streaming mode, on live microphone input.
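To try it (a minimal setup sketch: the model path is a placeholder, and multilingual_kws is assumed to be installed from the harvard-edge/multilingual_kws repository rather than from PyPI):

$ pip install pyaudio tensorflow fire numpy
$ python <this script> inference --keyword mask --modelpath mask_model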
import os
import time

import fire
import numpy as np
import pyaudio
import tensorflow as tf

from multilingual_kws.embedding import input_data
from multilingual_kws.embedding import batch_streaming_analysis as sa
from multilingual_kws.embedding.single_target_recognize_commands import (
    SingleTargetRecognizeCommands,
    RecognizeResult,
)

# Don't use the GPU; it's much faster on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

SAMPLING_RATE = 16000
FORMAT = pyaudio.paInt16
def inference(
    keyword: str,
    modelpath: str,
    detection_threshold: float = 0.9,
):
    """
    Runs streaming keyword-spotting inference on live microphone audio.
    Example invocation:
    $ python <this script> inference --keyword mask --modelpath mask_model

    Args:
        keyword: target keyword for few-shot KWS
        modelpath: path to a finetuned few-shot model
        detection_threshold: confidence threshold for detections (default=0.9)
    """
    assert os.path.exists(modelpath), f"{modelpath} inference model not found"
    assert 0 < detection_threshold < 1, "detection_threshold must be between 0 and 1."
    print(f"Performing inference using detection threshold {detection_threshold}")
    flags = sa.StreamFlags(
        wav=None,
        ground_truth=None,
        target_keyword=keyword,
        detection_thresholds=[detection_threshold],
        average_window_duration_ms=200,
        suppression_ms=500,
        clip_stride_ms=20,
    )
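    # (A sketch of the flag semantics, assuming they match the TF speech_commands
    # streaming-accuracy tooling that multilingual_kws builds on:
    # average_window_duration_ms smooths scores over a sliding window,
    # suppression_ms enforces a minimum gap between reported detections, and
    # clip_stride_ms is the hop between scored clips.)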
    start_load_model = time.time()
    # Load the TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=modelpath)
    interpreter.allocate_tensors()
    # Get input and output tensor details.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(f"Loaded model in {time.time() - start_load_model:.2f} seconds")

    model_settings = input_data.standard_microspeech_model_settings(label_count=3)
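    # (Assumption: the standard micro-speech settings describe a 49x40
    # micro-spectrogram built from 30 ms windows with a 20 ms stride at 16 kHz;
    # the buffer shapes below are all derived from these settings.)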
    # The next two components are used to compute the average score.
    recognize_element = RecognizeResult()
    recognize_commands = SingleTargetRecognizeCommands(
        labels=flags.labels(),
        average_window_duration_ms=flags.average_window_duration_ms,
        detection_threshold=detection_threshold,
        suppression_ms=flags.suppression_ms,
        minimum_count=flags.minimum_count,
        target_id=2,
    )
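    # (With label_count=3, flags.labels() should be [_silence_, _unknown_, keyword],
    # which is why the target class id is 2.)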
    class AudioHandler(object):
        def __init__(self, rate=SAMPLING_RATE, buffer_size=SAMPLING_RATE // 10, format_=FORMAT, channel=1) -> None:
            super().__init__()
            self.format = format_
            self.channel = channel
            self.rate = rate
            self.buffer_size = buffer_size
            self.spectrogram_buffer = np.zeros((
                model_settings["spectrogram_length"],
                model_settings["fingerprint_width"],
            ))
            self.audio_buffer = np.zeros((self.rate, channel))
            self.win_size = model_settings["window_size_samples"]
            self.win_stride = model_settings["window_stride_samples"]
            self.p = None
            self.stream = None
            self.current_time_ms = 0
            # These are used to compute the average duration of the callback function.
            self.time = time.time()
            self.time_samples = 10
            self.time_avg = [0.0 for _ in range(self.time_samples)]

        def start(self):
            self.p = pyaudio.PyAudio()
            self.stream = self.p.open(
                rate=self.rate,
                channels=self.channel,
                format=self.format,
                input=True,
                stream_callback=self.callback,
                frames_per_buffer=self.buffer_size,
            )

        def stop(self):
            self.stream.close()
            self.p.terminate()
        def callback(self, in_data, frame_count, time_info, flag):
            # Track the callback's average duration over the last few calls.
            end = time.time()
            self.time_avg.pop(0)
            self.time_avg.append(end - self.time)
            print(f"t_callback = {1000 * sum(self.time_avg) / self.time_samples:.2f}ms", end="\t")
            self.time = end
            # Load the captured audio data.
            audio = np.frombuffer(in_data, dtype=np.int16)
            # The normalization factor needs tuning, as it impacts inference;
            # it should be 2**15 or 2**16.
            audio = (audio / 2**16).astype(np.float32)
            # Drop the samples already fully consumed by previous spectrogram
            # frames, keep the leftover partial window, and append the new audio.
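            # (Worked example, assuming 30 ms windows / 20 ms stride at 16 kHz,
            # i.e. win_size=480 and win_stride=320: with a 16000-sample buffer,
            # recycle_idx = (16000 - 160) // 320 * 320 = 15680, so the trailing
            # 320 samples that are not yet fully windowed are carried over.)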
            recycle_idx = (len(self.audio_buffer) - (self.win_size - self.win_stride)) // self.win_stride * self.win_stride
            self.audio_buffer = np.concatenate(
                [
                    self.audio_buffer[recycle_idx:],
                    audio.reshape(-1, 1),
                ],
                axis=0,
            )
            # Compute the next spectrogram chunk.
            spectrogram = input_data.to_micro_spectrogram(model_settings, self.audio_buffer)
            # Concatenate with the previous spectrogram chunks to form a full spectrogram.
            self.spectrogram_buffer = np.concatenate(
                [self.spectrogram_buffer, spectrogram], axis=0
            )[-model_settings["spectrogram_length"]:, :]
            # Feed the spectrogram to the model as input.
            interpreter.set_tensor(
                input_details[0]["index"],
                tf.cast(self.spectrogram_buffer[np.newaxis, :, :, np.newaxis], dtype=tf.float32),
            )
            # Run inference.
            interpreter.invoke()
            inferences = interpreter.get_tensor(output_details[0]["index"])
            # How long, in ms, we have been processing.
            self.current_time_ms = self.current_time_ms + frame_count * 1000 / self.rate
            # In case you want to print the inference results:
            # print(self.current_time_ms, inferences[0], end="\t")
            recognize_commands.process_latest_result(
                inferences[0], self.current_time_ms, recognize_element
            )
            if (
                recognize_element.is_new_command
                and recognize_element.found_command != "_silence_"
            ):
                print("\033[92mDETECTED\033[0m")
            else:
                print("\033[91mNOT DETECTED\033[0m")
            return in_data, pyaudio.paContinue
    ah = AudioHandler(
        rate=SAMPLING_RATE,
        buffer_size=SAMPLING_RATE // 20,
        format_=FORMAT,
        channel=1,
    )
    ah.start()
    print("\tRecording")
    # Press ENTER to stop the program.
    input("Enter to stop\n")
    print("\tStopped")
    ah.stop()
    tf.keras.backend.clear_session()
    return


if __name__ == "__main__":
    fire.Fire(dict(inference=inference))