Skip to content

Instantly share code, notes, and snippets.

@erykml
Created May 1, 2022 13:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save erykml/079b43c68ae3b1fb25fc3fe3fc5efc18 to your computer and use it in GitHub Desktop.
from config import PROCESSED_IMAGES_DIR, MODELS_DIR
import os
import tensorflow.keras
import mlflow
from dagshub import dagshub_logger
# Point MLflow at the DagsHub-hosted tracking server for this project.
mlflow.set_tracking_uri("https://dagshub.com/eryk.lewinson/mario_vs_wario_v2.mlflow")

# MLflow authenticates against the tracking server using the
# MLFLOW_TRACKING_USERNAME / MLFLOW_TRACKING_PASSWORD environment variables.
# They are expected to be provided by the environment (shell, CI secret, .env
# loader) rather than hard-coded here, so no credentials end up in source.
# Fail early with an explicit message instead of an opaque auth error later.
_missing_credentials = [
    var_name
    for var_name in ("MLFLOW_TRACKING_USERNAME", "MLFLOW_TRACKING_PASSWORD")
    if not os.environ.get(var_name)
]
if _missing_credentials:
    raise RuntimeError(
        "Missing MLflow tracking credentials; set the environment variable(s): "
        + ", ".join(_missing_credentials)
    )
def main() -> None:
    """Train, evaluate, log and save the model inside one MLflow run.

    Steps:
      1. Enable MLflow TensorFlow autologging.
      2. Build the datasets and the model, then fit on the training set
         with the validation set.
      3. Evaluate on the test set.
      4. Log test loss/accuracy and the hyperparameters to both the DagsHub
         logger and MLflow.
      5. Save the model to ``MODELS_DIR``.
    """
    # Must be enabled before the run starts so that model.fit() is captured.
    mlflow.tensorflow.autolog()

    IMG_SIZE = 128
    LR = 0.001
    EPOCHS = 10

    # Single source of truth for the hyperparameters, shared by both loggers
    # below so the DagsHub and MLflow records cannot drift apart.
    hyperparams = {
        "img_size": IMG_SIZE,
        "learning_rate": LR,
        "epochs": EPOCHS,
    }

    with mlflow.start_run():
        # NOTE(review): get_datasets / get_model are not imported anywhere in
        # the visible source — confirm which project module provides them.
        training_set, valid_set, test_set = get_datasets(
            validation_ratio=0.2,
            target_img_size=IMG_SIZE,
            batch_size=32,
        )
        model = get_model(IMG_SIZE, LR)

        print("Training the model...")
        model.fit(training_set, validation_data=valid_set, epochs=EPOCHS)
        print("Training completed.")

        print("Evaluating the model...")
        # assumes the model is compiled with exactly one metric (accuracy), so
        # evaluate() returns [loss, accuracy] — TODO confirm against get_model.
        test_loss, test_accuracy = model.evaluate(test_set)
        print("Evaluating completed.")

        # dagshub logger
        with dagshub_logger() as logger:
            logger.log_metrics(loss=test_loss, accuracy=test_accuracy)
            logger.log_hyperparams(hyperparams)

        # mlflow logger (inside the active run, so it attaches to this run)
        mlflow.log_params(hyperparams)
        mlflow.log_metrics({
            "test_set_loss": test_loss,
            "test_set_accuracy": test_accuracy,
        })

        print("Saving the model...")
        model.save(MODELS_DIR)
        print("done.")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment