Skip to content

Instantly share code, notes, and snippets.

@shadiakiki1986
Last active April 25, 2024 00:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shadiakiki1986/76bf09894c6dc7a5cea8a9614069181a to your computer and use it in GitHub Desktop.
# Forked from the `load_model` and `predict` functions of BirdNET-Analyzer (main.py); related upstream sources:
# https://github.com/kahst/BirdNET-Analyzer/blob/main/config.py
# https://github.com/kahst/BirdNET-Analyzer/blob/main/model.py
# https://github.com/kahst/BirdNET-Analyzer/blob/main/analyze.py
try:
    # Prefer the lightweight tflite_runtime package when it is installed ...
    import tflite_runtime.interpreter as tflite
except ModuleNotFoundError:
    # ... otherwise fall back to the TFLite interpreter bundled with TensorFlow.
    # NOTE(review): this module also does `import tensorflow as tf`
    # unconditionally, so full TensorFlow is currently required regardless of
    # which branch is taken here.
    from tensorflow import lite as tflite
from multiprocessing import cpu_count
import pandas as pd
import tensorflow as tf
import numpy as np
from pathlib import Path
# Directory that `Model` reads the BirdNET model and label files from:
# by default, the folder containing this source file. On Kaggle, an
# alternative is
# Path("/kaggle/input/birdnet-analyzer/tflite/birdnet_global_6k_v2.4_model_fp32-1/2/")
BASEDIR = Path(__file__).parent
class Model:
    """Wrapper around the BirdNET V2.4 TFLite model and its label list.

    The interpreter and labels are loaded once in ``__init__``; batches of
    audio windows are then scored with :meth:`predict` (raw outputs) or
    :meth:`predict_proba` (softmax-normalised outputs).
    """

    def __init__(self, class_output=True, model_path=None, labels_path=None, num_threads=None):
        """Load the TFLite interpreter and the labels file.

        Args:
            class_output: If True, read the classification output tensor.
                If False, read the tensor one index below it, which this
                code uses as feature embeddings.
            model_path: Path to the ``.tflite`` model. Defaults to
                ``BASEDIR / "BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite"``.
            labels_path: Path to the labels text file (one label per line).
                Defaults to ``BASEDIR / "BirdNET_GLOBAL_6K_V2.4_Labels.txt"``.
            num_threads: Interpreter thread count; defaults to ``cpu_count()``.
        """
        if model_path is None:
            model_path = BASEDIR / "BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite"
        if labels_path is None:
            labels_path = BASEDIR / "BirdNET_GLOBAL_6K_V2.4_Labels.txt"
        if num_threads is None:
            num_threads = cpu_count()

        interpreter = tflite.Interpreter(model_path=str(model_path), num_threads=num_threads)
        interpreter.allocate_tensors()

        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        self.INTERPRETER = interpreter
        self.INPUT_LAYER_INDEX = input_details[0]["index"]
        self.class_output = class_output
        if class_output:
            self.OUTPUT_LAYER_INDEX = output_details[0]["index"]
        else:
            # Feature embeddings are read from the tensor just below the
            # classification output.
            self.OUTPUT_LAYER_INDEX = output_details[0]["index"] - 1

        # Explicit UTF-8 so label decoding does not depend on the platform locale.
        with open(labels_path, "r", encoding="utf-8") as f:
            self.labels = [line.strip() for line in f]

    def predict(self, data, samplerate, strict=True):
        """Run the model on a batch of audio windows.

        Args:
            data: Array-like of audio windows, expected 2-D
                ``(n_windows, n_samples)``; converted to contiguous float32.
            samplerate: Sample rate of ``data`` in Hz.
            strict: If True, require a 48 kHz sample rate, 2-D data, and
                windows of exactly 3 seconds.

        Returns:
            DataFrame with one row per window. With ``class_output=True`` the
            columns are ``self.labels``; otherwise a default integer index.

        Raises:
            ValueError: if ``strict`` is True and a check fails.
        """
        # Convert once; also makes list input usable and guarantees the
        # C-contiguous float32 buffer handed to the interpreter.
        data = np.ascontiguousarray(data, dtype=np.float32)
        if strict:
            # Explicit raises rather than `assert`, which `python -O` strips.
            if samplerate != 48_000:
                raise ValueError("birdnet assumes 48kHz?")
            if data.ndim != 2:
                raise ValueError("data should be 2d array")
            if data.shape[1] != samplerate * 3:
                raise ValueError("data should be 3 second windows")

        # Size the input tensor for this batch, then re-allocate after the resize.
        self.INTERPRETER.resize_tensor_input(self.INPUT_LAYER_INDEX, list(data.shape))
        self.INTERPRETER.allocate_tensors()
        # Audio only for now.
        self.INTERPRETER.set_tensor(self.INPUT_LAYER_INDEX, data)
        self.INTERPRETER.invoke()
        raw = self.INTERPRETER.get_tensor(self.OUTPUT_LAYER_INDEX)
        return self._to_frame(raw)

    def predict_proba(self, *args, **kwargs):
        """Return :meth:`predict` output converted from logits to probabilities.

        Takes the same arguments as :meth:`predict` and applies a row-wise
        softmax, so each row sums to 1. Only meaningful for class outputs.

        NOTE(review): upstream BirdNET-Analyzer appears to map logits per
        class with a sigmoid rather than a softmax across classes; confirm
        which normalisation is intended here.
        """
        return self._softmax_frame(self.predict(*args, **kwargs))

    def _to_frame(self, raw):
        """Wrap a raw output array in a DataFrame, labelling only class outputs."""
        if self.class_output:
            return pd.DataFrame(raw, columns=self.labels)
        # Embedding widths need not match len(self.labels), in which case
        # pd.DataFrame(..., columns=self.labels) would raise; keep default columns.
        return pd.DataFrame(raw)

    @staticmethod
    def _softmax_frame(logits):
        """Row-wise softmax of a DataFrame, preserving its index and columns."""
        values = logits.to_numpy()
        # Subtracting each row's max prevents overflow in exp() and leaves the
        # softmax result unchanged.
        exps = np.exp(values - values.max(axis=1, keepdims=True))
        probs = exps / exps.sum(axis=1, keepdims=True)
        return pd.DataFrame(probs, index=logits.index, columns=logits.columns)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment