Skip to content

Instantly share code, notes, and snippets.

@doleron
Last active April 9, 2023 00:37
Show Gist options
  • Save doleron/e3fb91f5865d4f0d3673091a0a734600 to your computer and use it in GitHub Desktop.
Save doleron/e3fb91f5865d4f0d3673091a0a734600 to your computer and use it in GitHub Desktop.
# Shape fed to both the Keras Input layer and MobileNetV2 below: a square
# INPUT_SIZE x INPUT_SIZE grid with a last dimension of 3 (presumably
# height, width, RGB channels -- confirm against the data pipeline).
# INPUT_SIZE is defined elsewhere in the file.
INPUT_SHAPE = (INPUT_SIZE, INPUT_SIZE, 3)
def initialize_base_network(embedding_dim=64, trainable_fraction=0.10):
    """Build a MobileNetV2-backed Keras model ending in a Dense layer.

    The model is: Input(INPUT_SHAPE) -> MobileNetV2 (ImageNet weights, no
    classification top) -> GlobalAveragePooling2D -> Dense(embedding_dim).
    The Dense layer is created without an explicit activation.

    Only the last ``trainable_fraction`` of the backbone's layers are left
    trainable; every earlier layer is frozen.

    Args:
        embedding_dim: Number of units in the final Dense layer.
            Defaults to 64, the value previously hard-coded.
        trainable_fraction: Fraction (0.0-1.0) of the backbone's layers,
            counted from the end, that stay trainable. Defaults to 0.10,
            the value previously hard-coded.

    Returns:
        A ``tf.keras.Model`` mapping ``inputs`` to the Dense layer output.

    Raises:
        ValueError: If ``trainable_fraction`` is outside ``[0.0, 1.0]``.
    """
    if not 0.0 <= trainable_fraction <= 1.0:
        raise ValueError(
            "trainable_fraction must be between 0.0 and 1.0, "
            f"got {trainable_fraction!r}"
        )

    inputs = tf.keras.layers.Input(INPUT_SHAPE)
    base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(
        input_shape=INPUT_SHAPE,
        include_top=False,
        weights='imagenet',
    )

    # Mark the whole backbone trainable first, then freeze everything
    # before the cut-off index so only the tail of the network is tuned.
    base_model.trainable = True
    layer_count = len(base_model.layers)
    fine_tune_at = layer_count - int(layer_count * trainable_fraction)
    for layer in base_model.layers[:fine_tune_at]:
        layer.trainable = False

    # NOTE(review): the backbone is called without `training=False`, so any
    # BatchNormalization layers inside it follow the outer training flag.
    # Confirm this is intended for the fine-tuning setup in use.
    x = base_model(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(embedding_dim)(x)

    return tf.keras.Model(inputs, outputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment