This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
fun performInference(context: Context) { | |
val startTime = System.currentTimeMillis() | |
val outputs = Array(1) { | |
FloatArray(11) | |
} | |
// Get the audio file. | |
val fileInputstream: InputStream = context.assets.open("down.wav") | |
try { | |
val byteIOArray = IOUtils.toByteArray(fileInputstream) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Initializes the TensorFlow Lite interpreter used for audio classification.
 *
 * Obtains an interpreter for the bundled `speech_commands_model.tflite` model from
 * [getInterpreter] and stores it in [interpreterAudioClassification].
 * No exception handling is done here, so any [java.io.IOException] declared by
 * [getInterpreter] propagates to the caller.
 *
 * @param context context forwarded unchanged to [getInterpreter].
 */
fun initInterpreter(context: Context) {
    val modelFile = "speech_commands_model.tflite"
    interpreterAudioClassification = getInterpreter(context = context, modelName = modelFile)
}
// Get the Interpreter. | |
@Throws(IOException::class) | |
private fun getInterpreter( | |
context: Context, | |
modelName: String |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# open the .wav file | |
with open('down.wav', 'rb') as wav_file: | |
wav_data = wav_file.read() | |
# load TFLite model and set params | |
mod_path = '/content/speech_commands_model.tflite' | |
interpreter = tf.lite.Interpreter(model_path=mod_path) | |
interpreter.allocate_tensors() | |
input_index = interpreter.get_input_details()[0]["index"] | |
output_index = interpreter.get_output_details()[0]["index"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Convert the model | |
converter = tf.lite.TFLiteConverter.from_saved_model("/content/saved_model") # path to the SavedModel directory | |
converter.target_spec.supported_ops = [ | |
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. | |
tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. | |
] | |
tflite_model = converter.convert() | |
# Save the model. | |
with open('speech_commands_model.tflite', 'wb') as f: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Convert TF1 frozen Graph to TF1 SavedModel. | |
# Load the graph as a v1.GraphDef | |
import pathlib | |
gdef = tf.compat.v1.GraphDef() | |
gdef.ParseFromString(pathlib.Path(GRAPH_DEF_MODEL_PATH).read_bytes()) | |
# Convert the GraphDef to a tf.Graph | |
with tf.Graph().as_default() as g: | |
tf.graph_util.import_graph_def(gdef, name="") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# print all layers' names | |
def printTensors(pb_file): | |
# read pb into graph_def | |
with tf.gfile.GFile(pb_file, "rb") as f: | |
graph_def = tf.GraphDef() | |
graph_def.ParseFromString(f.read()) | |
# import graph_def | |
with tf.Graph().as_default() as graph: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
GRAPH_DEF_MODEL_PATH = '/content/speech_commands_files/conv_actions_frozen.pb' | |
print("TF1 frozen GraphDef path: ", GRAPH_DEF_MODEL_PATH) | |
# view input and output layers' names | |
import tensorflow.compat.v1 as tf | |
tf.disable_v2_behavior() | |
gf = tf.GraphDef() | |
m_file = open(GRAPH_DEF_MODEL_PATH,'rb') | |
gf.ParseFromString(m_file.read()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Here we load again the model as we are going to build it under strategy.scope() | |
def build_model(max_len=128): | |
input_word_ids = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_ids") | |
input_mask = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="attention_mask") | |
segment_ids = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="token_type_ids") | |
bert_results = hub.KerasLayer(TFAutoModel.from_pretrained("nlpaueb/bert-base-greek-uncased-v1"), trainable=True, name='BERT_encoder')([input_word_ids, input_mask, segment_ids]) | |
#output = bert_results["last_hidden_state"][:, 0, :] | |
output = bert_results["pooler_output"][:,:] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create 3 inputs for the model | |
def bert_encode(texts, tokenizer, max_len=128): | |
all_tokens = [] | |
all_masks = [] | |
all_segments = [] | |
for text in texts: | |
text_preprocessed = tokenizer(text) | |
tokenized_text = text_preprocessed["input_ids"] | |
#print(tokenized_text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
MIXED_PRECISION = False | |
if MIXED_PRECISION: | |
if tpu: | |
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16') | |
else: # | |
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16') | |
tf.config.optimizer.set_jit(True) # XLA compilation | |
tf.keras.mixed_precision.experimental.set_policy(policy) | |
print('Mixed precision enabled') |
NewerOlder