
@darkcurrent
Last active January 12, 2019 16:21
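from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model

# Pretrained ResNet50 convolutional base, without its ImageNet classifier head
conv_base = ResNet50(include_top=False,
                     weights='imagenet',
                     input_shape=(224, 224, 3))

# Freeze the base so that only the new head is trained
for layer in conv_base.layers:
    layer.trainable = False

# New classification head for 2 classes
x = conv_base.output
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(512, activation='relu')(x)
x = layers.Dense(256, activation='relu')(x)  # hidden layer: relu, not softmax
predictions = layers.Dense(2, activation='softmax')(x)
model = Model(conv_base.input, predictions)

# `learning_rate` replaces the deprecated `lr` argument
optimizer = keras.optimizers.RMSprop(learning_rate=1e-4)
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=optimizer,
              metrics=['accuracy'])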
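The model is compiled with `sparse_categorical_crossentropy`, which expects integer class labels (0 or 1 for this 2-class head) rather than one-hot vectors. The following is a minimal pure-Python sketch of what that loss computes, assuming the standard definition (mean negative log-probability of the true class). The helper names `softmax` and `sparse_categorical_crossentropy` here are hypothetical illustrations, not the Keras API.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sparse_categorical_crossentropy(labels, probs):
    # labels: integer class indices, e.g. [0, 1]
    # probs: per-sample probability vectors, as produced by a softmax layer
    # Loss = batch mean of -log(probability assigned to the true class)
    losses = [-math.log(p[y]) for y, p in zip(labels, probs)]
    return sum(losses) / len(losses)

# Two samples: one confidently class 0, one undecided between the classes
probs = [softmax([2.0, 0.0]), softmax([0.0, 0.0])]
loss = sparse_categorical_crossentropy([0, 1], probs)
print(round(loss, 4))  # 0.41
```

Passing one-hot labels with `categorical_crossentropy` would give the same value; the sparse variant simply avoids building the one-hot vectors.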