Skip to content

Instantly share code, notes, and snippets.

@antonyharfield
Created June 13, 2020 13:38
Show Gist options
  • Save antonyharfield/8398ad6fd2400e25dc8976fb2918bf1e to your computer and use it in GitHub Desktop.
Save antonyharfield/8398ad6fd2400e25dc8976fb2918bf1e to your computer and use it in GitHub Desktop.
YAMNet to TFLite conversion
import tensorflow as tf
from tensorflow.keras import Model, layers
import features as features_lib
import features_tflite as features_tflite_lib
import params
from yamnet import yamnet
def yamnet_frames_tflite_model(feature_params):
    """Define the YAMNet frames model with a fixed-size input for TFLite conversion.

    The model accepts a waveform batch of shape ``(1, num_samples)``, where
    ``num_samples`` is 0.975 seconds of audio at the configured sample rate.
    The waveform is turned into a log-mel spectrogram, split into patches and
    passed through ``yamnet``.

    Args:
        feature_params: Parameters forwarded to
            ``features_tflite_lib.waveform_to_log_mel_spectrogram`` and
            ``features_lib.spectrogram_to_patches``. Its ``SAMPLE_RATE``
            attribute, when present, determines the input length; otherwise
            the module-level ``params.SAMPLE_RATE`` is used.

    Returns:
        A Keras ``Model`` named ``'yamnet_frames'`` with outputs
        ``[predictions, spectrogram]``.
    """
    # NOTE(review): 0.975 s presumably is the audio needed for a single
    # YAMNet patch -- confirm against the constants in params.
    input_seconds = 0.975

    # Size the input from the same parameters that drive feature extraction,
    # so the two cannot disagree; fall back to the global params otherwise.
    sample_rate = getattr(feature_params, 'SAMPLE_RATE', params.SAMPLE_RATE)
    num_samples = int(round(sample_rate * input_seconds))

    # Fixed batch of one: the input shape is fully static.
    waveform = layers.Input(batch_shape=(1, num_samples))

    # Drop the batch axis before feature extraction.
    spectrogram = features_tflite_lib.waveform_to_log_mel_spectrogram(
        tf.squeeze(waveform, axis=0), feature_params)
    patches = features_lib.spectrogram_to_patches(spectrogram, feature_params)
    predictions = yamnet(patches)

    return Model(
        name='yamnet_frames',
        inputs=waveform,
        outputs=[predictions, spectrogram],
    )
def main():
    """Convert the YAMNet Keras model to TFLite and save it to disk.

    Builds the model with ``yamnet_frames_tflite_model(params)``, loads
    weights from ``yamnet.h5``, converts it with ``tf.lite.TFLiteConverter``
    and writes the converted model bytes to ``yamnet.tflite``. Both paths
    are relative to the current working directory.
    """
    # Load the model and weights
    model = yamnet_frames_tflite_model(params)
    model.load_weights('yamnet.h5')

    # Convert the model
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    # Use a context manager so the output file is always flushed and closed,
    # instead of relying on garbage collection of an anonymous file object.
    with open("yamnet.tflite", "wb") as output_file:
        output_file.write(tflite_model)
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment