Skip to content

Instantly share code, notes, and snippets.

@0187773933
Last active April 26, 2024 18:13
Show Gist options
  • Save 0187773933/0081d6303be18360b88e929e3f34877d to your computer and use it in GitHub Desktop.
Save 0187773933/0081d6303be18360b88e929e3f34877d to your computer and use it in GitHub Desktop.
Runs Google MediaPipe YAMNet audio classification on live microphone audio
import sounddevice as sd
import numpy as np
import tensorflow as tf
import queue
from collections import defaultdict , deque
import time
# https://storage.googleapis.com/mediapipe-models/audio_classifier/yamnet/float32/latest/yamnet.tflite
# https://github.com/tensorflow/models/blob/master/research/audioset/yamnet/yamnet.py
# https://research.google.com/audioset/ontology/index.html
# https://storage.googleapis.com/mediapipe-tasks/audio_classifier/yamnet_label_list.txt
# https://github.com/tensorflow/models/blob/master/research/audioset/yamnet/params.py#L25
# Model and label files, resolved relative to the current working directory.
MODEL_PATH = "./yamnet.tflite"
LABEL_PATH = "./yamnet_label_list.txt"

# --- Audio framing -----------------------------------------------------------
SAMPLE_RATE = 16000                               # input stream rate, Hz
PATCH_WINDOW_SECONDS = 0.975                      # audio fed to one model invocation
PATCH_HOP_SECONDS = PATCH_WINDOW_SECONDS * 0.5    # stream block length: half a window

# --- Reporting ---------------------------------------------------------------
WATCH_WINDOW_SECONDS = 30   # Time window over which per-label scores are aggregated
PRINT_WINDOW_TOTAL = 10     # max number of labels printed per update
MINIMUM_THRESHOLD = 0.15    # aggregated scores at or below this are not shown
# Hand-off queue between the audio callback (producer, q.put) and main (consumer, q.get).
q = queue.Queue()

# Load the TFLite model once at import time and cache its tensor metadata,
# which main uses to locate the input and output tensors.
interpreter = tf.lite.Interpreter( model_path=MODEL_PATH )
interpreter.allocate_tensors()
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)
def audio_callback( indata , frames , time , status ):
    """Input-stream callback: queue a private copy of each captured block.

    The signature matches the sounddevice InputStream callback. Only `indata`
    and `status` are used; `frames` and `time` are accepted but ignored
    (`time` here is a parameter and shadows the `time` module only inside
    this function). A truthy `status` is printed, and the block is still queued.
    """
    # Copy before queuing: the stream's buffer is not guaranteed to stay
    # valid after the callback returns (see sounddevice callback docs).
    snapshot = indata.copy()
    if status:
        print( f"Error: {status}" )
    q.put( snapshot )
def read_text( file_path , encoding="utf-8" ):
    """Return the lines of a text file, without line terminators.

    main uses this to load the label list and pairs line i with model
    output score i.

    Args:
        file_path: path of the text file to read.
        encoding: text encoding of the file. Defaults to UTF-8 explicitly,
            because open() otherwise uses the locale's encoding and can
            mis-decode the label file on non-UTF-8 systems.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid in `encoding`.
    """
    with open( file_path , encoding=encoding ) as f:
        return f.read().splitlines()
def calculate_db( indata ):
    """Return the RMS level of `indata` in decibels, i.e. 20 * log10(rms).

    Samples are promoted to float64 before squaring, so integer audio (such
    as the int16 blocks the input stream produces) cannot overflow and wrap.
    An RMS of 1.0 maps to 0 dB, so for input normalized to roughly [-1, 1]
    (as main does) the result is relative to full scale. The RMS is floored
    at 1e-10, so silence and empty input both report -200 dB instead of
    -inf or NaN.

    Args:
        indata: array-like of samples, any shape (all elements are used).

    Returns:
        The level in dB as a NumPy float64.
    """
    floor = 1e-10  # avoids log10(0)
    samples = np.asarray( indata , dtype=np.float64 )
    if samples.size == 0:
        # np.mean of an empty array is NaN (with a warning); treat as silence.
        rms = floor
    else:
        rms = max( float( np.sqrt( np.mean( np.square( samples ) ) ) ) , floor )
    return 20 * np.log10( rms )
def main():
    """Classify live microphone audio with YAMNet and print a rolling summary.

    Opens a 2-channel int16 input stream at SAMPLE_RATE. Its callback
    (audio_callback) queues blocks of PATCH_HOP_SECONDS worth of samples.
    Each loop iteration does the following:

    - assembles one PATCH_WINDOW_SECONDS patch from those blocks;
    - downmixes the patch to mono float32 scaled by the int16 maximum;
    - runs the TFLite interpreter on it;
    - keeps a running per-label SUM of the scores of every patch from the
      last WATCH_WINDOW_SECONDS. It is a sum, not a mean, so values can
      exceed 1.

    Labels whose summed score exceeds MINIMUM_THRESHOLD are printed, best
    first, up to PRINT_WINDOW_TOTAL, along with the dB level of the current
    patch. Runs until interrupted with Ctrl+C.
    """
    model_labels = read_text(LABEL_PATH)
    # Running per-label sum of scores over the watch window.
    results = defaultdict(float)
    # (timestamp, score array) for every patch currently counted in `results`.
    past_results = deque()

    block_size = int(SAMPLE_RATE * PATCH_HOP_SECONDS)
    patch_samples = int(round(SAMPLE_RATE * PATCH_WINDOW_SECONDS))
    # Ceiling division: enough whole blocks to cover one patch. The
    # concatenation is trimmed to exactly patch_samples below, so the input
    # tensor length stays correct even if the window is not a whole number
    # of blocks. This is computed once instead of on every iteration.
    blocks_per_patch = -(-patch_samples // block_size)

    try:
        with sd.InputStream(
            callback=audio_callback,
            dtype="int16",
            channels=2,
            samplerate=SAMPLE_RATE,
            blocksize=block_size,
        ):
            print("Starting audio stream...")
            while True:
                # Every patch is built from freshly captured blocks, so
                # consecutive patches do not overlap: the effective hop
                # between inferences is one full window, not PATCH_HOP_SECONDS.
                stereo = np.concatenate([q.get() for _ in range(blocks_per_patch)])[:patch_samples]
                # Stamp the patch after the blocking q.get() calls, i.e. once
                # its audio has actually arrived, rather than before waiting.
                current_time = time.time()
                # Normalize and convert to mono (average of the channels).
                data = np.mean(stereo.astype(np.float32), axis=1) / np.iinfo(np.int16).max
                interpreter.set_tensor(input_details[0]["index"], data)
                interpreter.invoke()
                probabilities = interpreter.get_tensor(output_details[0]["index"]).flatten()

                # Expire patches older than the watch window. Scores are
                # subtracted with the same zip() used to add them, so removal
                # mirrors insertion exactly even if a label appears twice.
                cutoff = current_time - WATCH_WINDOW_SECONDS
                while past_results and past_results[0][0] < cutoff:
                    _, old_probabilities = past_results.popleft()
                    for label, probability in zip(model_labels, old_probabilities):
                        results[label] -= probability

                # Add the new patch to the window.
                past_results.append((current_time, probabilities))
                for label, probability in zip(model_labels, probabilities):
                    results[label] += probability

                # Threshold applies to the display only; `results` keeps every label.
                filtered_results = {
                    label: score for label, score in results.items() if score > MINIMUM_THRESHOLD
                }
                sorted_results = sorted(filtered_results.items(), key=lambda item: item[1], reverse=True)
                db_level = calculate_db(data)  # Extras - DB Level

                print(f"\nLIVE : DB === {db_level}")
                if sorted_results:
                    print(f"LIVE : TOP === {sorted_results[0][0]} : {sorted_results[0][1]}")
                print(f"Last {WATCH_WINDOW_SECONDS} Seconds:")
                for label, probability in sorted_results[:PRINT_WINDOW_TOTAL]:
                    print(f"\t{label}: {probability}")
    except KeyboardInterrupt:
        print("\nStopping...")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment