Skip to content

Instantly share code, notes, and snippets.

@ayulockin
Created August 30, 2022 17:29
Show Gist options
  • Save ayulockin/ffc276cdcdbb6f9f95d3affe72a82f8c to your computer and use it in GitHub Desktop.
Save ayulockin/ffc276cdcdbb6f9f95d3affe72a82f8c to your computer and use it in GitHub Desktop.
import tensorflow as tf
import wandb
from wandb.keras import (
# WandBMetricsLogger,
WandbModelCheckpoint,
# WandbGradientLogger,
# ModelLogger,
# FLOPsLogger,
)
# Train a small dense classifier on Fashion-MNIST inside a W&B run, saving a
# full-model checkpoint at the end of every epoch via WandbModelCheckpoint.
EPOCHS = 10
# "{epoch}" is filled in by Keras with the epoch number when each checkpoint
# is written.
CHECKPOINT_PATH = "../model/model_{epoch}"

with wandb.init(project="mnist", job_type="dev-wandb-metrics-logger"):
    # NOTE(review): job_type names the metrics logger, but only
    # WandbModelCheckpoint is active below, and the project is "mnist" while
    # the data is Fashion-MNIST — confirm these run labels are intended.
    fashion_mnist = tf.keras.datasets.fashion_mnist
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

    # Scale pixel values into [0, 1]. Cast to float32 before dividing:
    # dividing the raw uint8 arrays by 255.0 would produce float64 arrays,
    # doubling memory even though Keras computes in float32 anyway.
    train_images = train_images.astype("float32") / 255.0
    test_images = test_images.astype("float32") / 255.0

    model = tf.keras.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation="relu"),
            # No softmax here: the final layer emits raw logits, which is why
            # the loss below is built with from_logits=True.
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    callbacks = [
        # Optional: enable W&B metrics logging alongside checkpointing, e.g.
        # WandbMetricsLogger(log_freq=10), once it is imported above.
        WandbModelCheckpoint(
            filepath=CHECKPOINT_PATH,
            save_best_only=False,  # keep a checkpoint for every epoch
            save_weights_only=False,  # save the full model, not just weights
            save_freq="epoch",  # write at the end of each epoch
        ),
    ]

    # NOTE(review): the test split doubles as validation data here, so the
    # reported val metrics are not from a held-out set independent of model
    # selection — acceptable for a dev run, confirm before reporting results.
    model.fit(
        train_images,
        train_labels,
        validation_data=(test_images, test_labels),
        epochs=EPOCHS,
        callbacks=callbacks,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment