Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Last active November 10, 2021 14:25
Show Gist options
  • Save llandsmeer/c244b37a6a3a1ce0a237fc7282ebe957 to your computer and use it in GitHub Desktop.
Save llandsmeer/c244b37a6a3a1ce0a237fc7282ebe957 to your computer and use it in GitHub Desktop.
TensorFlow Lite example: computing Fibonacci numbers
import numpy as np
import tensorflow as tf
import tflite_runtime.interpreter as tflite
# Fib function
@tf.function
def fibonacci(n):
    """Advance the Fibonacci recurrence ``n`` times and return the result.

    The pair ``(a, b)`` starts at ``(0, 1)`` and is stepped with
    ``tf.while_loop``; the value of ``b`` after the loop is returned, so
    ``n = 0`` yields 1, ``n = 1`` yields 1, ``n = 2`` yields 2, and so on.

    Args:
        n: Number of loop iterations to run. The conversion code traces this
            function with a scalar ``tf.int32`` spec.

    Returns:
        The value of ``b`` once the loop counter reaches ``n``.
    """
    a = 0
    b = 1
    i = 0

    def cond(i, a, b):
        # Stop once the counter reaches the requested step count. This used
        # to compare against a hard-coded 10, which ignored ``n`` entirely.
        return i < n

    def body(i, a, b):
        return (i + 1, b, a + b)

    i, a, b = tf.while_loop(
        cond=cond,
        body=body,
        loop_vars=(i, a, b))
    return b
# Trace fibonacci for a scalar int32 input, convert the traced function to a
# TFLite model and store the converted result on disk.
concrete_fn = fibonacci.get_concrete_function(n=tf.TensorSpec((), tf.int32))
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
tflite_model = converter.convert()
with open('converted_model.tflite', 'wb') as f:
    f.write(tflite_model)
# Load the TFLite model from disk and allocate its tensors. The names bound
# here (interpreter, input_details, output_details) are read by tflfib below.
interpreter = tf.lite.Interpreter(model_path='converted_model.tflite')
interpreter.allocate_tensors()

# Description of the model inputs; tflfib uses entry 0's 'index'.
# Example value recorded by the author:
input_details = interpreter.get_input_details()
# [{'name': 'n',
#   'index': 0,
#   'shape': array([], dtype=int32),
#   'shape_signature': array([], dtype=int32),
#   'dtype': numpy.int32,
#   'quantization': (0.0, 0),
#   'quantization_parameters': {'scales': array([], dtype=float32),
#                               'zero_points': array([], dtype=int32),
#                               'quantized_dimension': 0},
#   'sparsity_parameters': {}}]

# Description of the model outputs; tflfib uses entry 0's 'index'.
# Example value recorded by the author:
output_details = interpreter.get_output_details()
# [{'name': 'Identity',
#   'index': 6,
#   'shape': array([], dtype=int32),
#   'shape_signature': array([], dtype=int32),
#   'dtype': numpy.int32,
#   'quantization': (0.0, 0),
#   'quantization_parameters': {'scales': array([], dtype=float32),
#                               'zero_points': array([], dtype=int32),
#                               'quantized_dimension': 0},
#   'sparsity_parameters': {}}]
def tflfib(n):
    """Evaluate the converted model for ``n`` using the TFLite interpreter.

    ``n`` is wrapped in ``np.int32`` and written to the first input tensor of
    the module-level ``interpreter``; after invoking it, the contents of the
    first output tensor are returned.
    """
    in_index = input_details[0]['index']
    out_index = output_details[0]['index']
    interpreter.set_tensor(in_index, np.int32(n))
    interpreter.invoke()
    return interpreter.get_tensor(out_index)
import numpy as np
import tensorflow as tf
@tf.function(input_signature=(tf.TensorSpec(shape=(), dtype=tf.int32),))
def fibonacci(n: np.int32):
    """Step the Fibonacci recurrence ``n`` times and return the result.

    The running pair starts at ``(0, 1)``; each iteration replaces it with
    ``(second, first + second)``. The second element is returned, so
    ``n = 0`` gives 1 and ``n = 1`` gives 1.

    Args:
        n: Iteration count; the input signature declares a scalar int32.

    Returns:
        The second element of the pair after ``n`` iterations.
    """
    prev = 0
    curr = 1
    for step in range(n):
        prev, curr = curr, prev + curr
    return curr
# Convert the traced fibonacci function to a serialized TFLite model.
concrete_fn = fibonacci.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
tflite_model = converter.convert()

# Build the interpreter directly from the converted model held in memory.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# Look up the tensor indices once so tflfib does not repeat the lookup.
input_idx = interpreter.get_input_details()[0]['index']
output_idx = interpreter.get_output_details()[0]['index']
def tflfib(n):
    """Evaluate the converted model for ``n`` using the TFLite interpreter.

    Writes ``np.int32(n)`` to the tensor at ``input_idx``, invokes the
    module-level ``interpreter`` and returns the tensor at ``output_idx``.
    """
    scalar = np.int32(n)
    interpreter.set_tensor(input_idx, scalar)
    interpreter.invoke()
    return interpreter.get_tensor(output_idx)
# Print the model's result for each input from 0 through 9.
for i in range(10):
    result = tflfib(i)
    print(result)
@llandsmeer
Copy link
Author

In [*]: for i in range(10):
    ...:     print(f'{i: 4d} {tflfib(i)}')
    ...:
   0 1
   1 1
   2 2
   3 3
   4 5
   5 8
   6 13
   7 21
   8 34
   9 55

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment