Skip to content

Instantly share code, notes, and snippets.

@marvinhoxha
Last active August 29, 2022 10:43
Show Gist options
  • Save marvinhoxha/0ed7a1d7d2d4506d8451bb3f01a3ce00 to your computer and use it in GitHub Desktop.
Save marvinhoxha/0ed7a1d7d2d4506d8451bb3f01a3ce00 to your computer and use it in GitHub Desktop.
# Build a transfer-learning image classifier: an ImageNet-pretrained ResNet50V2
# backbone (frozen) followed by a Flatten + softmax head, then compile it.
# IMG_SIZE and LR are defined outside this snippet; the int()/float() casts
# suggest they arrive as config/env strings (NOTE(review): confirm).
image_shape = (int(IMG_SIZE), int(IMG_SIZE), 3)  # square, 3-channel input

resnet_body = tf.keras.applications.ResNet50V2(
    include_top=False,  # drop the ImageNet classifier; a custom head is added below
    weights="imagenet",
    input_shape=image_shape,
)
# Freeze the pretrained weights so only the new head's weights are updated.
resnet_body.trainable = False

inputs = tf.keras.layers.Input(shape=image_shape)
# training=False runs the backbone in inference mode (e.g. BatchNormalization
# keeps using its stored statistics) even while the outer model is being fit.
features = resnet_body(inputs, training=False)
flattened = tf.keras.layers.Flatten()(features)
# NOTE(review): 133 output classes is hard-coded -- presumably the number of
# dog breeds; it must match the class count produced by the data generators.
outputs = tf.keras.layers.Dense(133, activation="softmax")(flattened)
resnet_model = tf.keras.Model(inputs, outputs)

adam = tf.optimizers.Adam(learning_rate=float(LR))
resnet_model.compile(
    optimizer=adam,
    loss=tf.losses.categorical_crossentropy,  # expects one-hot encoded targets
    metrics=["accuracy"],
)
# Train the classifier, evaluating on the validation data after each epoch.
# get_train_generator / get_valid_generator and EPOCHS come from outside this
# snippet; EPOCHS is cast with int() (presumably a config/env string).
train_generator, valid_generator = get_train_generator(), get_valid_generator()
resnet_model.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=int(EPOCHS),
)
# Export the trained model and its class-name -> index mapping.
# The artifact root depends on the deployment target held in `config`
# (defined outside this snippet). Any other value used to fall through both
# branches and silently skip saving, discarding the trained model; it now
# raises instead.
labels = train_generator.class_indices
model_roots = {"LOCAL": "./models", "KUBERNETES": "/models"}
models_root = model_roots.get(config)
if models_root is None:
    raise ValueError(
        f"Unknown config {config!r}; expected one of {sorted(model_roots)}"
    )
logging.info("Dump models.")
# "<root>/dog_model/1": the trailing "1" looks like a model-version directory
# (NOTE(review): confirm against the serving setup).
resnet_model.save(f"{models_root}/dog_model/1")
with open(f"{models_root}/labels.pickle", "wb") as handle:
    pickle.dump(labels, handle)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment