@jiahao87
Last active August 26, 2020 11:49
Sample code for MLflow
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier

# data_processing() is assumed to be defined elsewhere and to return a train/test split
X_train, X_test, y_train, y_test = data_processing()

#################### 1. Setup Experiment ###########################
# set path to log data, e.g., mlruns local folder
# (do this before creating the experiment so it is stored at this location)
mlflow.set_tracking_uri('./mlruns')
# set experiment name to organize runs
mlflow.set_experiment('New Experiment Name')
experiment = mlflow.get_experiment_by_name('New Experiment Name')

# launch new run under the experiment name
with mlflow.start_run(experiment_id=experiment.experiment_id):

    #################### 2. Normal Model Training ######################
    hyperparams = {'max_depth': 10,
                   'max_samples': 0.8,
                   'max_features': 'sqrt'}
    clf = RandomForestClassifier(**hyperparams,
                                 random_state=0)
    clf.fit(X_train, y_train)
    # mean accuracy on the held-out test set
    accuracy = clf.score(X_test, y_test)

    ################ 3. Log params, metrics and model #################
    # log model params
    mlflow.log_params(hyperparams)
    # log model metric
    mlflow.log_metric('accuracy', accuracy)
    # log model
    mlflow.sklearn.log_model(clf, "model")
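
The script above calls `data_processing()` without defining it, so it cannot run as posted. The following is a minimal stand-in, not the author's function: it assumes the Iris dataset bundled with scikit-learn and a stratified 75/25 split, and the `test_size` and `random_state` parameters are hypothetical. It would need to be defined (or imported) before the first line of the script.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split


def data_processing(test_size=0.25, random_state=0):
    """Hypothetical stand-in: load Iris and return X_train, X_test, y_train, y_test."""
    # Iris ships with scikit-learn, so no download is needed
    X, y = load_iris(return_X_y=True)
    # stratify keeps the class proportions the same in both splits
    return train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )


X_train, X_test, y_train, y_test = data_processing()
print(X_train.shape, X_test.shape)
```

Any function that returns the four arrays in this order would work equally well with the tracking code; only the tuple layout matters.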