Skip to content

Instantly share code, notes, and snippets.

@vmarkovtsev
Last active November 27, 2019 21:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vmarkovtsev/46ee236d268d9ee48c04d74f47094100 to your computer and use it in GitHub Desktop.
Save vmarkovtsev/46ee236d268d9ee48c04d74f47094100 to your computer and use it in GitHub Desktop.
import logging
from pathlib import Path
def create_motion_blur_func(images_shape: Tuple[int], dim: int, angle: float):
    """Build a TensorFlow function that blurs images with a motion-blur kernel.

    The kernel is obtained from ``create_motion_blur_kernel(dim, angle)`` once,
    when this factory is called, and is captured by the returned function.

    :param images_shape: Shape used for the fixed input signature of the \
                         returned function.
    :param dim: Forwarded to ``create_motion_blur_kernel``.
    :param angle: Forwarded to ``create_motion_blur_kernel``.
    :return: A ``tf.function`` taking a single float32 tensor of shape \
             ``images_shape`` and returning the depthwise convolution of it \
             with the kernel.
    """
    blur_kernel = create_motion_blur_kernel(dim, angle)
    # Pin the input signature up front so the function has exactly one
    # concrete function that can be retrieved without passing sample inputs.
    signature = [tf.TensorSpec(images_shape, tf.float32)]

    @tf.function(input_signature=signature)
    def motion_blur_func(images):
        return tf.nn.depthwise_conv2d(
            images,
            blur_kernel,
            strides=[1, 1, 1, 1],
            padding="SAME",
        )

    return motion_blur_func
def create_motion_blur_func_lite(images_shape: Tuple[int], dim: int, angle: float):
    """Create the TensorFlow Lite counterpart of ``create_motion_blur_func``.

    A model name is derived from the shape, ``dim`` and ``angle`` (the angle is
    rendered with two decimals), and the work is delegated to
    ``create_func_lite`` together with a deferred constructor of the regular
    TensorFlow function.

    :param images_shape: Shape of the images batch; each dimension becomes \
                         part of the model name.
    :param dim: Forwarded to ``create_motion_blur_func``.
    :param angle: Forwarded to ``create_motion_blur_func``.
    :return: Whatever ``create_func_lite`` returns for this configuration.
    """
    shape_tag = "_".join(str(size) for size in images_shape)
    name = "motion_blur_%s_%d_%.2f_f32" % (shape_tag, dim, angle)

    def build_regular_func():
        # Deferred on purpose: the callee decides whether it is needed at all.
        return create_motion_blur_func(images_shape, dim, angle)

    return create_func_lite(images_shape, build_regular_func, name)
def create_func_lite(images_shape: Tuple[int], ctor: callable, name: str):
    """Load (generating first if missing) a TensorFlow Lite model as a callable.

    The model is looked up at ``<name>.tflite`` relative to the working
    directory. When that file does not exist, ``ctor`` is called to build the
    regular TensorFlow function and ``generate_lite_model`` is asked to
    produce the file.

    :param images_shape: Passed through to ``generate_lite_model``.
    :param ctor: Zero-argument callable returning the TensorFlow function to \
                 convert; only invoked when the model file is missing.
    :param name: Logger name and base name of the ``.tflite`` file.
    :return: Function ``invoke(images)`` that copies ``images`` into the \
             model input, runs the interpreter and returns a copy of the \
             first output tensor.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    model_path = "%s.tflite" % name
    model_is_cached = Path(model_path).exists()
    if not model_is_cached:
        logger.info("Creating the regular TensorFlow kernel")
        generate_lite_model(logger, images_shape, ctor(), name)

    logger.info("Loading the Lite model")
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Only the first input and the first output of the model are used.
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    tensor_in = interpreter.tensor(input_index)
    tensor_out = interpreter.tensor(output_index)

    def invoke(images):
        # The buffer views are deliberately never bound to a name: the write
        # and the copy each use a temporary, so no view is alive while the
        # interpreter runs.
        tensor_in()[:] = images
        interpreter.invoke()
        result = tensor_out().copy()
        return result

    return invoke
def generate_lite_model(log, images_shape, func, name):
    """Convert ``func`` to a TensorFlow Lite model stored at ``<name>.tflite``.

    The serialized model is first written to ``<name>.tflite.tmp`` and then
    renamed onto the final path. Callers in this file treat the mere existence
    of ``<name>.tflite`` as "the model is ready", so the final path must never
    hold a partially written file, e.g. after an interrupted run.

    :param log: Logger used to report progress.
    :param images_shape: Not used by this function; kept so the signature \
                         stays compatible with existing callers.
    :param func: Object exposing ``get_concrete_function()`` (a \
                 ``tf.function`` with a fixed input signature in this file).
    :param name: Base name of the output file, without the extension.
    """
    log.info("Generating the TensorFlow Lite model")
    concrete_func = func.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    tflite_model = converter.convert()

    target = Path("%s.tflite" % name)
    staging = target.with_name(target.name + ".tmp")
    # Stage next to the target so the final step is a same-directory rename,
    # which replaces the destination in one operation instead of exposing a
    # half-written model file.
    staging.write_bytes(tflite_model)
    staging.replace(target)
# Example usage: build (or load) the Lite motion blur for the shape of
# `images`, with dim 25 and an angle of half pi, then save the blurred result.
motion_blur = create_motion_blur_func_lite(
    images_shape=images.shape,
    dim=25,
    angle=(90 / 180) * np.pi,
)
save_image(motion_blur(images), "result.jpg")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment