Last active
November 27, 2019 21:59
-
-
Save vmarkovtsev/46ee236d268d9ee48c04d74f47094100 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from pathlib import Path | |
def create_motion_blur_func(images_shape: Tuple[int, ...], dim: int, angle: float):
    """Build a graph-compiled TensorFlow function that motion-blurs a batch.

    Args:
        images_shape: Static shape of the input tensor. It is baked into the
            ``tf.function`` input signature, so the returned function accepts
            only float32 tensors of exactly this shape.
        dim: Forwarded to ``create_motion_blur_kernel`` (presumably the kernel
            size -- confirm against that helper).
        angle: Forwarded to ``create_motion_blur_kernel``; the caller in this
            file passes radians.

    Returns:
        A ``tf.function`` that depthwise-convolves its input with the kernel.
    """
    # The kernel is built once here and captured by the traced function.
    kernel = create_motion_blur_kernel(dim, angle)

    # A fixed input signature lets ``get_concrete_function()`` be called with
    # no arguments, which the TFLite conversion step (generate_lite_model)
    # relies on.
    @tf.function(input_signature=[tf.TensorSpec(images_shape, tf.float32)])
    def motion_blur_func(images):
        # Unit strides + SAME padding keep the spatial size unchanged.
        return tf.nn.depthwise_conv2d(images, kernel, strides=[1] * 4, padding="SAME")

    return motion_blur_func
def create_motion_blur_func_lite(images_shape: Tuple[int, ...], dim: int, angle: float):
    """Return a TFLite-backed motion blur callable, cached on disk by its parameters.

    The model name (used for the logger and the ``<name>.tflite`` cache file)
    is derived from ``images_shape``, ``dim`` and ``angle``. If that file
    already exists, it is loaded instead of being regenerated.

    Args:
        images_shape: Static input shape; every dimension becomes part of the name.
        dim: Forwarded to ``create_motion_blur_func``.
        angle: Forwarded to ``create_motion_blur_func`` (radians at the call site
            in this file).

    Returns:
        The callable produced by ``create_func_lite``.
    """
    # repr() of a Python float round-trips exactly, so distinct angles never
    # share a cache file. The former "%.2f" formatting mapped e.g. 1.5701 and
    # 1.5708 to the same "1.57" file and silently reused the wrong model.
    # float() normalises NumPy scalars, whose repr in recent NumPy versions is
    # not a bare number (e.g. "np.float64(1.5)").
    # NOTE: cache files written by the old naming scheme will not be reused.
    angle_key = repr(float(angle))
    shape_key = "_".join(map(str, images_shape))
    name = "motion_blur_%s_%d_%s_f32" % (shape_key, dim, angle_key)

    def ctor():
        # Deferred so the TensorFlow function is only built on a cache miss.
        return create_motion_blur_func(images_shape, dim, angle)

    return create_func_lite(images_shape, ctor, name)
def create_func_lite(images_shape: Tuple[int], ctor: callable, name: str):
    """Load (building on first use) a TFLite model and wrap it as a plain callable.

    Args:
        images_shape: Input shape, forwarded to ``generate_lite_model`` when the
            model has to be built.
        ctor: Zero-argument factory for the TensorFlow function to convert; it
            is only called when ``<name>.tflite`` does not exist yet.
        name: Base name for both the logger and the ``<name>.tflite`` file.

    Returns:
        A function that copies its argument into the interpreter's first input
        tensor, runs inference, and returns a copy of the first output tensor.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    model_file = f"{name}.tflite"

    # The .tflite file doubles as a cache: only build it when it is missing.
    if not Path(model_file).exists():
        logger.info("Creating the regular TensorFlow kernel")
        generate_lite_model(logger, images_shape, ctor(), name)

    logger.info("Loading the Lite model")
    interpreter = tf.lite.Interpreter(model_path=model_file)
    interpreter.allocate_tensors()
    in_index = interpreter.get_input_details()[0]["index"]
    out_index = interpreter.get_output_details()[0]["index"]

    # tensor() hands back accessor functions; per the tf.lite.Interpreter docs
    # they are called on every use instead of holding the returned arrays.
    input_view = interpreter.tensor(in_index)
    output_view = interpreter.tensor(out_index)

    def run_model(images):
        input_view()[:] = images
        interpreter.invoke()
        # Copy out so the result is not tied to the interpreter's buffer.
        return output_view().copy()

    return run_model
def generate_lite_model(log, images_shape, func, name):
    """Convert a ``tf.function`` to TensorFlow Lite and save it as ``<name>.tflite``.

    Args:
        log: Logger used for progress messages.
        images_shape: Not used by this function; kept for call compatibility.
        func: A ``tf.function`` with a fixed input signature, so that
            ``get_concrete_function()`` can be called without arguments.
        name: Base name of the output file (``<name>.tflite``).
    """
    log.info("Generating the TensorFlow Lite model")
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func.get_concrete_function()]
    )
    tflite_model = converter.convert()
    target = Path("%s.tflite" % name)
    # Write to a sibling temp file, then rename it into place. create_func_lite
    # treats any existing <name>.tflite as a valid cache, so a write that is
    # interrupted midway must never leave a truncated file under the final name.
    tmp = target.with_name(target.name + ".tmp")
    tmp.write_bytes(tflite_model)
    tmp.replace(target)
# Build (or load from the on-disk cache) a TFLite motion blur sized for
# `images`, with dim=25 and a 90 degree angle expressed in radians (pi / 2),
# then write the blurred result to result.jpg.
motion_blur = create_motion_blur_func_lite(
    images_shape=images.shape,
    dim=25,
    angle=np.pi / 2,
)
save_image(motion_blur(images), "result.jpg")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment