Skip to content

Instantly share code, notes, and snippets.

@thkim-cochl
Last active March 21, 2019 10:37
Show Gist options
  • Save thkim-cochl/2eda391d89f105cc24b0dc83ccf9475b to your computer and use it in GitHub Desktop.
Save thkim-cochl/2eda391d89f105cc24b0dc83ccf9475b to your computer and use it in GitHub Desktop.
TF Lite bath model test
import tensorflow as tf
import numpy as np
from keras.models import load_model
import time
# Output labels, in the same order as the model's score vector
# (postprocessing() pairs them with the scores positionally).
# NOTE(review): 'bathhub' is probably a typo for 'bathtub'; it is kept as-is
# because it is the key callers see in the result dict.
event_list = ['bathhub', 'shower', 'sink', 'toiletflush', 'others']
# Shape every input clip is reshaped to before inference.
# Presumably 5 s of audio at the predict_* default fs=22050 (110250 / 22050)
# — TODO confirm against the model definition.
input_shape = (1, 1, 110250)


def postprocessing(response):
    """Turn a raw model output into a ``{label: score}`` dict.

    Args:
        response: array-like model output holding exactly ``len(event_list)``
            values (any shape with that many elements, e.g. ``(1, 5)``).

    Returns:
        dict mapping each name in ``event_list`` to its score as a Python
        float rounded to 3 decimal places, in ``event_list`` order.

    Raises:
        ValueError: if ``response`` does not have ``len(event_list)`` elements.
    """
    flat_scores = np.reshape(response, (len(event_list),)).tolist()
    return {label: round(score, 3) for label, score in zip(event_list, flat_scores)}
def predict_keras(data, fs=22050):
    """Score one input clip with the Keras model.

    On every call, the saved model is loaded from 'm-2.h5' and its weights
    are then replaced by those in 'weights_e570.h5'. Both files are resolved
    relative to the current working directory.

    Args:
        data: array-like input containing exactly ``prod(input_shape)``
            elements; it is reshaped to ``input_shape``.
        fs: accepted for API symmetry but not used by this function.

    Returns:
        The ``{label: score}`` dict produced by ``postprocessing``.
    """
    batch = np.reshape(data, input_shape)

    keras_model = load_model('m-2.h5')
    keras_model.load_weights('weights_e570.h5')

    return postprocessing(keras_model.predict(batch))
def predict_tflite(data, fs=22050):
    """Score one input clip with the TFLite model.

    On every call, a fresh interpreter is created for 'model2-7.tflite'
    (resolved relative to the current working directory) and its tensors are
    allocated. The reshaped input is fed to the first input tensor, and the
    first output tensor is read back.

    Args:
        data: array-like input containing exactly ``prod(input_shape)``
            elements; it is reshaped to ``input_shape``.
        fs: accepted for API symmetry with ``predict_keras``; not used.

    Returns:
        The ``{label: score}`` dict produced by ``postprocessing``.

    Raises:
        ValueError: if ``data`` cannot be reshaped to ``input_shape``, or if
            ``set_tensor`` rejects the input (e.g. a non-float model whose
            input dtype does not match ``data``).
    """
    data = np.reshape(data, input_shape)

    # Load the TFLite model and allocate its tensors.
    interpreter = tf.lite.Interpreter(model_path="model2-7.tflite")
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # set_tensor() rejects any dtype mismatch. For example, it fails on
    # float64 data from np.load when the model expects float32, whereas the
    # Keras path converts silently. For float models, cast to the declared
    # input dtype so both backends accept the same input. Non-float
    # (e.g. quantized) inputs are left alone, so a mismatch still fails
    # loudly instead of being mis-scaled.
    expected_dtype = input_details[0]['dtype']
    if np.issubdtype(expected_dtype, np.floating):
        data = data.astype(expected_dtype, copy=False)

    interpreter.set_tensor(input_details[0]['index'], data)
    interpreter.invoke()
    pred = interpreter.get_tensor(output_details[0]['index'])
    return postprocessing(pred)
if __name__ == "__main__":
    # Compare the Keras and TFLite backends on the same input ('input.npy').
    input_data = np.load('input.npy')

    # NOTE: each predict_* call loads its model from disk, so the reported
    # times include model loading, not just inference.
    for backend_name, predict in (("Keras", predict_keras), ("TF lite", predict_tflite)):
        # perf_counter() is monotonic and high-resolution, the intended clock
        # for measuring elapsed time (time.time() can jump with clock changes).
        start_time = time.perf_counter()
        pred = predict(input_data)
        elapsed = time.perf_counter() - start_time
        print("< {} Prediction >\n{}\nExecution Time: {:.3} second".format(backend_name, pred, elapsed))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment