Skip to content

Instantly share code, notes, and snippets.

@priyanlc
Created July 3, 2020 08:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save priyanlc/0fd8c2b2607de6ccac22a6bf6479e5b5 to your computer and use it in GitHub Desktop.
Save priyanlc/0fd8c2b2607de6ccac22a6bf6479e5b5 to your computer and use it in GitHub Desktop.
def mlflow_run(self, name='n-york-taxi-test-run'):
    """Train and evaluate the model inside an MLflow run and log the results.

    Inside a new MLflow run this calls, in order,
    ``load_scale_and_preprocess_data()``, ``create_model()`` and
    ``train_and_evaluate()`` on ``self``. It then logs:

    * the first entry of ``self.params`` as run parameters;
    * MAE, MSE and RMSE as run metrics;
    * ``self.model`` as a Keras model artifact.

    It also prints the run identifiers and the metric values.

    :param name: Name of the run to be logged by MLflow
    :return: Tuple (ExperimentID, runID)
    """
    with mlflow.start_run(run_name=name) as run:
        # RunInfo.run_uuid is deprecated in MLflow; run_id carries the same value.
        run_id = run.info.run_id
        experiment_id = run.info.experiment_id

        # Preprocess the data, build the model, then train/evaluate it with
        # the training and validation data.
        self.load_scale_and_preprocess_data()
        self.create_model()
        self.train_and_evaluate()

        # NOTE(review): assumes train_and_evaluate() fills self.params and
        # self._metrics with indexable sequences whose first element is a
        # dict — confirm against the rest of the class.
        run_params = self.params[0]
        metrics = self._metrics[0]
        mae = metrics['mean_absolute_error']
        # NOTE(review): 'loss' is treated as the MSE, which presumes the model
        # was compiled with a mean-squared-error loss — confirm in create_model().
        mse = metrics['loss']
        rmse = np.sqrt(mse)

        # ***************************** MLflow Tracking Start
        mlflow.log_params(run_params)
        mlflow.log_metric("mae", mae)
        mlflow.log_metric("mse", mse)
        mlflow.log_metric("rmse", rmse)
        # The artifact path below is the name the model is stored under in
        # the run; keep it stable so existing loaders still find the model.
        mlflow.keras.log_model(self.model, "Keras_model for NY-Taxi dataset")
        # ***************************** MLflow Tracking End

        print("MLflow Run with run_id {} and experiment_id {}".format(run_id, experiment_id))
        print('Mean Absolute Error :', mae)
        print('Mean Squared Error :', mse)
        print('Root Mean Squared Error:', rmse)
        return (experiment_id, run_id)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.