Skip to content

Instantly share code, notes, and snippets.

@mj-ml
Created May 27, 2024 17:19
Show Gist options
  • Save mj-ml/7d82ae9a281b82cec10e8c8f0b7ebb9c to your computer and use it in GitHub Desktop.
Save mj-ml/7d82ae9a281b82cec10e8c8f0b7ebb9c to your computer and use it in GitHub Desktop.
import os
import pickle
import click
import mlflow
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
def load_pickle(filename: str):
    """Open *filename* in binary mode and return the object unpickled from it.

    Args:
        filename: Path to a file written with ``pickle.dump``.

    Returns:
        Whatever object was stored in the file.
    """
    # NOTE(review): pickle.load can execute arbitrary code while
    # deserializing. Only point this at files this project produced itself.
    with open(filename, "rb") as handle:
        obj = pickle.load(handle)
    return obj
@click.command()
@click.option(
    "--data_path",
    default="./output",
    help="Location where the processed NYC taxi trip data was saved"
)
def run_train(data_path: str):
    """Train a random forest and log its validation RMSE to MLflow.

    Reads ``train.pkl`` and ``val.pkl`` from *data_path*. Each file must hold
    an ``(X, y)`` pair, because it is unpacked into two names below. The run
    is recorded in the ``homework`` experiment of the SQLite tracking store
    ``sqlite:///mlflow.db``.

    Args:
        data_path: Directory containing ``train.pkl`` and ``val.pkl``.
    """
    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_val, y_val = load_pickle(os.path.join(data_path, "val.pkl"))

    mlflow.set_tracking_uri("sqlite:///mlflow.db")
    mlflow.set_experiment("homework")
    # Turn on MLflow autologging for supported libraries (here, the sklearn fit).
    mlflow.autolog(disable=False)

    with mlflow.start_run():
        rf = RandomForestRegressor(max_depth=10, random_state=0)
        rf.fit(X_train, y_train)
        y_pred = rf.predict(X_val)

        # `squared=False` was deprecated in scikit-learn 1.4 and removed in
        # 1.6 (it now raises TypeError). Taking the square root of the MSE
        # ourselves gives the same RMSE on every scikit-learn version.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)


if __name__ == '__main__':
    run_train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment