Last active
March 5, 2020 17:28
-
-
Save darien-schettler/fd5b25626e9eb5b1330cce670bf9cc17 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# version 2.1.0 | |
import tensorflow as tf | |
# version 1.18.1 | |
import numpy as np | |
# ######## DEFINE CUSTOM FUNCTION FOR TF LAMBDA LAYER ######## #
def resize_like(input_tensor, ref_tensor,
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR):
    """Resize an image tensor to the spatial size of a reference image tensor.

    Only the height and width are taken from ``ref_tensor``; the leading
    (batch) and trailing (channel) dimensions of ``input_tensor`` are not
    changed by the resize.

    Args:
        input_tensor: Image tensor to resize, laid out as
            ``[batch, height, width, channels]`` or ``[height, width, channels]``
            (the layouts accepted by ``tf.image.resize``).
        ref_tensor: Image tensor whose height and width give the target size,
            in either of the same two layouts.
        method: A ``tf.image.ResizeMethod`` value. Defaults to nearest
            neighbour interpolation.

    Returns:
        ``input_tensor`` resized to the height and width of ``ref_tensor``.
    """
    # Index from the end so the slice selects (height, width) for both
    # rank-4 batched tensors and rank-3 single images; a ``[1:3]`` slice would
    # select (width, channels) for a rank-3 reference.
    target_size = tf.shape(ref_tensor)[-3:-1]
    return tf.image.resize(images=input_tensor,
                           size=target_size,
                           method=method,
                           # Stretch to exactly target_size instead of
                           # fitting inside it while keeping the aspect ratio.
                           preserve_aspect_ratio=False,
                           antialias=False)
# ############################################################# #
# ############ DEFINE MODEL USING TF.KERAS FN API ############ #
# INPUTS
model_input_1 = tf.keras.layers.Input(shape=(160, 160, 3))
model_input_2 = tf.keras.layers.Input(shape=(160, 160, 3))

# OUTPUTS
model_output_1 = tf.keras.layers.Conv2D(filters=64,
                                        kernel_size=(1, 1),
                                        use_bias=False,
                                        kernel_initializer='he_normal',
                                        name='conv_name_base')(model_input_1)
# Pass the reference tensor to the Lambda layer as a real layer input instead
# of through ``arguments``: Keras does not treat tensors given via
# ``arguments`` as layer inputs, so they are not wired into the model graph
# and cannot be evaluated when ``predict`` runs. With both tensors as inputs,
# ``model_output_1`` becomes a proper dependency of this layer.
model_output_2 = tf.keras.layers.Lambda(
    function=lambda tensors: resize_like(tensors[0], tensors[1]),
    name='resize_like')([model_input_2, model_output_1])

# MODEL
model = tf.keras.models.Model(inputs=[model_input_1, model_input_2],
                              outputs=model_output_2,
                              name="test_model")
# ############################################################# #
# ######### RUN PREDICT WITH DUMMY INPUT ########## #
# Build the dummy batch as float32 so it matches the Input layers' default
# dtype instead of numpy's float64 default.
dummy_input = [np.ones((1, 160, 160, 3), dtype=np.float32),
               np.zeros((1, 160, 160, 3), dtype=np.float32)]
prediction = model.predict(x=dummy_input)
# ############################################################# #
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment