Skip to content

Instantly share code, notes, and snippets.

@petered
Created January 9, 2023 00:49
Show Gist options
  • Save petered/6d085852f5393c69f48893fa0c2f5220 to your computer and use it in GitHub Desktop.
Save petered/6d085852f5393c69f48893fa0c2f5220 to your computer and use it in GitHub Desktop.
How do I save a stateful TFLite model where the shape of the state depends on the input?
import shutil
import tempfile
from dataclasses import dataclass
from typing import Optional, Callable, Any, Mapping
import os
import numpy as np
import tensorflow as tf
def save_signatures_to_tflite_model(
    concrete_function_dict: Mapping[str, Callable],
    path: str,
    parent_object: Any,
    allow_custom_ops=False,
):
    """Export ``parent_object`` as a SavedModel and convert it to a TFLite file.

    The SavedModel is written to a temporary directory, which is always removed
    afterwards. It is then converted with ``tf.lite.TFLiteConverter.from_saved_model``,
    and the serialized flatbuffer is written to ``path``.

    Args:
        concrete_function_dict: Mapping from signature name to the callable
            (e.g. a ``tf.function`` with an ``input_signature``) exported under
            that name.
        path: Destination file for the serialized TFLite model; ``~`` is expanded.
        parent_object: Object passed to ``tf.saved_model.save`` as ``obj``. It
            owns the variables the signatures use.
        allow_custom_ops: Forwarded to ``converter.allow_custom_ops``.
    """
    with tempfile.TemporaryDirectory() as saved_model_dir:
        tf.saved_model.save(obj=parent_object, export_dir=saved_model_dir, signatures=concrete_function_dict)
        converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
        # Allow both builtin TFLite ops and select TensorFlow ops, so ops that
        # lack a builtin TFLite kernel can still be converted.
        converter.target_spec.supported_ops = [tf.lite.OpsSet.SELECT_TF_OPS, tf.lite.OpsSet.TFLITE_BUILTINS]
        # Lets the converter handle resource variables (the model's tf.Variable
        # state); this flag requires the from_saved_model path used above.
        converter.experimental_enable_resource_variables = True
        converter.allow_custom_ops = allow_custom_ops
        serialized_model = converter.convert()
    # The temporary SavedModel is no longer needed once converted; write the result.
    with open(os.path.expanduser(path), 'wb') as f:
        f.write(serialized_model)
def load_tflite_model_func(path: str) -> Callable:
    """Load a TFLite model from ``path`` and wrap it as a plain Python function.

    The returned function takes one array per model input, in the order given
    by ``interpreter.get_input_details()``. It runs the interpreter and returns
    the output tensor(s): a single output is returned bare, several outputs as
    a list.

    The same interpreter is reused for every call. Any state the model keeps in
    its variables is therefore expected to carry over between calls.

    Tensors are allocated once here and no ``resize_tensor_input`` is
    performed. Arguments must therefore match the shape and dtype listed in the
    model's input details.

    Raises:
        ValueError: (raised by the returned function) if the number of
            arguments differs from the number of model inputs.
    """
    interpreter = tf.lite.Interpreter(model_path=os.path.expanduser(path))
    input_details = interpreter.get_input_details()
    interpreter.allocate_tensors()
    # Output tensor indices are fixed by the model, so look them up once
    # instead of on every call.
    output_indices = [o['index'] for o in interpreter.get_output_details()]

    def model_func(*args):
        # Explicit check (not assert): under -O an assert vanishes, and zip()
        # would then silently leave some inputs unset.
        if len(args) != len(input_details):
            raise ValueError(f"Model expects {len(input_details)} inputs, got {len(args)}")
        for detail, arr in zip(input_details, args):
            interpreter.set_tensor(detail['index'], arr)
        interpreter.invoke()
        # A 1-tuple output cannot be told apart from a scalar output here, so a
        # lone output is returned unwrapped.
        if len(output_indices) == 1:
            return interpreter.get_tensor(output_indices[0])
        return [interpreter.get_tensor(i) for i in output_indices]

    return model_func
@dataclass
class TimeDelta(tf.Module):
    """Stateful module that returns the difference between successive inputs.

    The first call compares against zeros, so it returns its input unchanged.
    Every later call returns ``arr - previous_arr``.
    """

    # Holds the previous input. It is None until the first call, which creates
    # it lazily so its shape and dtype can be taken from that first input.
    _last_val: Optional[tf.Variable] = None

    def compute_delta(self, arr: tf.Tensor) -> tf.Tensor:
        """Return ``arr`` minus the previously seen input, then remember ``arr``.

        The state variable is created from the first input, so its shape is
        fixed by that input. Later inputs are expected to have the same shape;
        assigning a different shape to the variable should fail.
        """
        if self._last_val is None:
            # zeros_like keeps arr's dtype (tf.zeros would always give float32),
            # so non-float32 inputs also work; for float32 input it is equivalent.
            self._last_val = tf.Variable(tf.zeros_like(arr))
        delta = arr - self._last_val
        self._last_val.assign(arr)
        return delta
def test_save_delta():
    """Export ``TimeDelta`` at one input shape, then run it at another.

    The model is converted with a fixed input signature of
    ``compile_time_shape`` but fed arrays of ``runtime_shape``. Per the author's
    note below, it works when the two shapes are equal; the differing shape is
    the case the question is about.
    """
    compile_time_shape = 30, 40
    delta = TimeDelta()
    # A fresh directory avoids tempfile.mktemp's race and is removed afterwards;
    # the old mktemp path was never cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tflite_model_file_path = os.path.join(tmp_dir, 'delta.tflite')
        save_signatures_to_tflite_model(
            {'delta': tf.function(delta.compute_delta, input_signature=[tf.TensorSpec(shape=compile_time_shape)])},
            path=tflite_model_file_path,
            parent_object=delta,
        )
        # Load the model and run test inputs
        # HOW CAN I RESIZE THE STATE VARIABLE TO MATCH THE RUNTIME SHAPE?
        func = load_tflite_model_func(tflite_model_file_path)
        # runtime_shape = compile_time_shape  # If I do this, it works fine
        runtime_shape = 60, 80
        rng = np.random.RandomState(1234)
        ims = [rng.randn(*runtime_shape).astype(np.float32) for _ in range(3)]
        assert np.allclose(func(ims[0]), ims[0])
        assert np.allclose(func(ims[1]), ims[1] - ims[0])
        assert np.allclose(func(ims[2]), ims[2] - ims[1])
if __name__ == "__main__":
test_save_delta()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment