Created
May 27, 2024 17:19
-
-
Save mj-ml/7d82ae9a281b82cec10e8c8f0b7ebb9c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pickle | |
import click | |
import mlflow | |
from sklearn.ensemble import RandomForestRegressor | |
from sklearn.metrics import mean_squared_error | |
def load_pickle(filename: str):
    """Load and return the object stored in the pickle file *filename*.

    The file is opened in binary mode and closed again before returning.
    """
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing; only point this at files produced by a trusted step.
    with open(filename, "rb") as f_in:
        return pickle.load(f_in)
@click.command()
@click.option(
    "--data_path",
    default="./output",
    help="Location where the processed NYC taxi trip data was saved"
)
def run_train(data_path: str):
    """Train a random forest on the pickled data and log its RMSE to MLflow."""
    # Each pickle is unpacked into a (features, target) pair.
    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_val, y_val = load_pickle(os.path.join(data_path, "val.pkl"))

    # Runs are recorded in a local SQLite store under the "homework" experiment.
    mlflow.set_tracking_uri("sqlite:///mlflow.db")
    mlflow.set_experiment("homework")
    mlflow.autolog(disable=False)

    with mlflow.start_run():
        rf = RandomForestRegressor(max_depth=10, random_state=0)
        rf.fit(X_train, y_train)
        y_pred = rf.predict(X_val)

        # The `squared=False` argument was deprecated in scikit-learn 1.4 and
        # removed in 1.6; taking the square root of the MSE works on every
        # version. NOTE(review): assumes a single-output target -- for
        # multi-output y this is the root of the averaged MSE; confirm.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)
# Run the click command only when the file is executed as a script.
if __name__ == '__main__':
    run_train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment