@sayakpaul
Created May 2, 2020 02:54
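import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model


def create_model(img_size=(224, 224), num_class=5, train_base=True):
    # Accept float16 image inputs (intended for mixed-precision training)
    input_layer = Input(shape=(img_size[0], img_size[1], 3), dtype=tf.float16)
    # ImageNet-pretrained ResNet50 backbone without its classification head
    base = ResNet50(input_tensor=input_layer, include_top=False,
                    weights="imagenet")
    base.trainable = train_base  # False freezes the backbone weights
    x = base.output
    x = GlobalAveragePooling2D()(x)
    # Keep the softmax output in float32: float16 probabilities are not
    # numerically stable enough for computing the loss
    preds = Dense(num_class, activation="softmax", dtype=tf.float32)(x)
    return Model(inputs=input_layer, outputs=preds)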
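The final Dense layer is pinned to float32 even though the rest of the network takes float16 inputs. TensorFlow's mixed precision guide recommends keeping the model's output softmax in float32 for numeric stability. The snippet below is a rough NumPy illustration of why, not TensorFlow code: it computes a max-subtracted softmax entirely in float16 and entirely in float32 on hypothetical logits, then takes the cross-entropy of a low-probability class. The logits and the helper names `stable_softmax` and `cross_entropy` are made up for this sketch.

```python
import numpy as np


def stable_softmax(logits, dtype):
    """Softmax with max-subtraction, computed entirely in `dtype` (illustrative helper)."""
    x = np.asarray(logits, dtype=dtype)
    e = np.exp(x - x.max())  # largest term becomes exp(0) == 1, so no overflow
    return e / e.sum()


def cross_entropy(probs, label):
    """Negative log-likelihood of the true class (illustrative helper)."""
    with np.errstate(divide="ignore"):  # log(0) -> -inf without a warning
        return -np.log(probs[label])


# Hypothetical logits for a 5-class problem (num_class=5 in create_model).
logits = [20.0, 0.0, 0.0, 0.0, 0.0]
true_label = 1  # a confident but wrong prediction

p16 = stable_softmax(logits, np.float16)
p32 = stable_softmax(logits, np.float32)

print(p16[true_label], cross_entropy(p16, true_label))  # 0.0 inf
print(p32[true_label], cross_entropy(p32, true_label))  # roughly 2.06e-09 and 20.0
```

In float16 the probability exp(-20), about 2.06e-9, is below the smallest positive half-precision value (about 6e-8), so it rounds to zero and the loss becomes infinite; float32 can represent it and yields a finite loss of about 20.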