Created
December 19, 2022 21:08
-
-
Save mdneuzerling/363aabdf011524b29995f35016b98bea to your computer and use it in GitHub Desktop.
Tidymodels MLflow R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This file contains the minimal code needed to set up a tidymodels flavour for
# mlflow, along with unit tests. There are three issues that need to be addressed:
#
# * We currently use workflows:::predict.workflow, which will NOT be accepted by
#   CRAN. We should ask that the Tidymodels team make this function available in
#   the workflows NAMESPACE.
# * We need to ensure that the Python CLI can support the new R flavour. There is
#   a commented-out unit test for this.
# * The unit tests load packages through `library()`, and this could be cleaned up.
# Save a fitted tidymodels workflow as an mlflow model directory.
#
# The workflow is serialised to `workflow.rds` inside `path`, and a `workflow`
# flavor entry (data file, model class, engine and mode) is appended to
# `model_spec$flavors` before the spec is written with
# `mlflow_write_model_spec()`. The updated `model_spec` is returned.
#
# Stops if the workflow has no model fit, or if its engine is "spark".
#' @rdname mlflow_save_model
#' @export
mlflow_save_model.workflow <- function(workflow,
                                       path,
                                       model_spec = list(),
                                       ...) {
  if (is.null(workflow$fit$fit)) {
    stop("The workflow does not have a model fit. Have you called `fit()` yet?")
  }

  # `extract_spec_parsnip()` supersedes the deprecated `pull_workflow_spec()`.
  spec <- workflows::extract_spec_parsnip(workflow)
  model <- class(spec)[[1]] # adapted from workflows:::print_header
  engine <- spec$engine
  mode <- spec$mode

  # Validate before touching the file system: rejecting the model after the
  # target directory has been wiped would destroy existing content and leave a
  # partially written artefact behind.
  if (identical(engine, "spark")) {
    stop("Spark models are not supported through the workflow flavor")
  }

  # Start from a clean directory so stale files never leak into the artefact.
  if (dir.exists(path)) unlink(path, recursive = TRUE)
  dir.create(path)
  saveRDS(workflow, file.path(path, "workflow.rds"))

  model_spec$flavors <- append(model_spec$flavors, list(
    workflow = list(
      data = "workflow.rds",
      model = model,
      engine = engine,
      mode = mode
    )
  ))
  mlflow_write_model_spec(path, model_spec)
  model_spec
}
# Load a workflow previously saved with the `workflow` flavor.
#
# Checks that the workflows package and the packages needed by the recorded
# model/engine combination are installed, then reads `workflow.rds` from
# `model_path` and returns the deserialised object.
#' @rdname mlflow_load_flavor
#' @export
mlflow_load_flavor.mlflow_flavor_workflow <- function(flavor, model_path) {
  require_package("workflows")
  # Read the spec from the directory being loaded; a hard-coded "model" path
  # only works when the artefact happens to live in ./model.
  model_spec <- mlflow_read_model_spec(model_path)
  workflow_flavor <- model_spec$flavors$workflow
  require_parsnip_dependencies(workflow_flavor$model, workflow_flavor$engine)
  readRDS(file.path(model_path, "workflow.rds"))
}
# Predict with a fitted workflow.
#
# CRAN will not accept a package that reaches into another package's internals
# with `:::`, so instead of calling workflows:::predict.workflow directly we
# make sure the workflows namespace is loaded (which registers its S3 methods)
# and dispatch through the `predict()` generic. `data` and `...` are forwarded
# unchanged to the method.
#' @rdname mlflow_predict
#' @export
mlflow_predict.workflow <- function(model, data, ...) {
  require_package("workflows")
  stats::predict(model, data, ...)
}
#' Ensure every package needed by a parsnip model/engine pair is installed
#'
#' Looks up the dependencies with `get_parsnip_dependencies()` and passes each
#' one to `require_package()`, which raises an error for a missing package.
#'
#' @inheritParams get_parsnip_dependencies
#'
#' @return Invisibly returns a character vector of required packages
#'
#' @keywords internal
require_parsnip_dependencies <- function(model, engine) {
  pkgs <- get_parsnip_dependencies(model, engine)
  # Called for the side effect only: the first missing package stops execution.
  lapply(pkgs, require_package)
  invisible(pkgs)
}
#' Look up the packages required by a parsnip model and engine
#'
#' Queries `parsnip::get_dependency()` for the model and returns the `pkg`
#' entry of the first row whose engine matches.
#'
#' @param model The type of model, eg. "linear_reg". This is stored as a class
#'   of the model object, and is also recorded in the saved mlflow artefact.
#' @param engine The engine for the model, as understood by parsnip, eg. "lm".
#'
#' @return A character vector of required packages
#'
#' @keywords internal
get_parsnip_dependencies <- function(model, engine) {
  known <- parsnip::get_dependency(model)
  engine_is_known <- engine %in% known$engine
  if (!engine_is_known) {
    stop(engine, " is not a valid engine for ", model)
  }
  matching_rows <- which(known$engine == engine)
  known$pkg[matching_rows][[1]]
}
# Stop with an informative error unless `package` can be loaded as a namespace.
# Returns NULL invisibly when the package is available.
#' @keywords internal
require_package <- function(package) {
  is_available <- requireNamespace(package, quietly = TRUE)
  if (is_available) {
    return(invisible(NULL))
  }
  stop("The '", package, "' package must be installed.")
}
############## | |
# Unit tests # | |
############## | |
# TODO: get rid of these library calls and refer to the package namespaces directly | |
library(parsnip) | |
library(workflows) | |
library(recipes) | |
# Round-trip a fitted linear-model workflow through mlflow and check that the
# reloaded model predicts exactly like the in-memory fit.
test_that("mlflow can save, load and predict linear model workflow", {
  idx <- withr::with_seed(3809, sample(nrow(mtcars)))
  train <- mtcars[idx[1:25], ]
  test <- mtcars[idx[26:32], ]

  # The workflow has to be defined before it can be fitted.
  # NOTE(review): `lm_parsnip` is not defined in the visible part of this file;
  # presumably it comes from a test helper -- confirm.
  lm_workflow <- workflow() %>%
    add_model(lm_parsnip) %>%
    add_recipe(recipe(mpg ~ ., mtcars) %>% step_log(wt))
  lm_workflow_fit <- lm_workflow %>% fit(train)

  mlflow_clear_test_dir("model")
  mlflow_save_model(lm_workflow_fit, "model")
  expect_true(dir.exists("model"))

  loaded_back_model <- mlflow_load_model("model")
  prediction <- mlflow_predict(loaded_back_model, test)
  expect_equal(
    prediction,
    predict(lm_workflow_fit, test)
  )
})
# Round-trip a fitted random-forest workflow through mlflow and check that the
# reloaded model predicts exactly like the in-memory fit.
test_that("mlflow can save, load and predict random forest workflow", {
  shuffled_rows <- withr::with_seed(3809, sample(nrow(mtcars)))
  training_rows <- mtcars[shuffled_rows[1:25], ]
  holdout_rows <- mtcars[shuffled_rows[26:32], ]

  # NOTE(review): `rf_parsnip` is not defined in the visible part of this file;
  # presumably it comes from a test helper -- confirm.
  log_wt_recipe <- step_log(recipe(mpg ~ ., mtcars), wt)
  unfitted_workflow <- add_recipe(
    add_model(workflow(), rf_parsnip),
    log_wt_recipe
  )
  fitted_workflow <- fit(unfitted_workflow, training_rows)

  mlflow_clear_test_dir("model")
  mlflow_save_model(fitted_workflow, "model")
  expect_true(dir.exists("model"))

  reloaded <- mlflow_load_model("model")
  reloaded_prediction <- mlflow_predict(reloaded, holdout_rows)
  expect_equal(
    reloaded_prediction,
    predict(fitted_workflow, holdout_rows)
  )
})
# Serve a saved workflow over HTTP with mlflow_rfunc_serve() and check that the
# predictions returned by the endpoint match the in-memory fit.
test_that("can predict with the mlflow_rfunc_serve", {
  # Objects created inside other test_that() blocks are not visible here, so
  # the data split, the fitted workflow and the saved artefact are rebuilt
  # locally instead of relying on leftovers from an earlier test.
  # NOTE(review): `rf_parsnip` is not defined in the visible part of this file;
  # presumably it comes from a test helper -- confirm.
  idx <- withr::with_seed(3809, sample(nrow(mtcars)))
  train <- mtcars[idx[1:25], ]
  test <- mtcars[idx[26:32], ]
  rf_workflow_fit <- workflow() %>%
    add_model(rf_parsnip) %>%
    add_recipe(recipe(mpg ~ ., mtcars) %>% step_log(wt)) %>%
    fit(train)
  mlflow_clear_test_dir("model")
  mlflow_save_model(rf_workflow_fit, "model")

  model_server <- processx::process$new(
    "Rscript",
    c(
      "-e",
      "mlflow::mlflow_rfunc_serve('model', browse = FALSE)"
    ),
    supervise = TRUE,
    stdout = "|",
    stderr = "|"
  )
  # Kill the server when this test finishes, whether it passes or fails.
  withr::defer(model_server$kill())

  # Give the server process time to start listening before sending a request.
  Sys.sleep(10)
  expect_true(model_server$is_alive())

  http_prediction <- httr::content(
    httr::POST(
      "http://127.0.0.1:8090/predict/",
      body = jsonlite::toJSON(as.list(test))
    )
  )
  expect_equal(
    # `flatten()` is exported by purrr, so `::` is enough.
    purrr::flatten_dbl(purrr::flatten(http_prediction)),
    predict(rf_workflow_fit, test)$.pred
  )
})
# Disabled until I can work out how to update the Python side of things | |
# test_that("can predict with CLI", { | |
# temp_in_csv <- tempfile(fileext = ".csv") | |
# temp_out <- tempfile(fileext = ".json") | |
# write.csv(test$data, temp_in_csv, row.names = FALSE) | |
# mlflow_cli( | |
# "models", "predict", "-m", "model", "-i", temp_in_csv, | |
# "-o", temp_out, "-t", "csv" | |
# ) | |
# prediction <- unlist(jsonlite::read_json(temp_out)) | |
# expect_true(!is.null(prediction)) | |
# expect_equal(prediction, predict(rf_workflow_fit, test)) | |
# | |
# temp_in_json <- tempfile(fileext = ".json") | |
# jsonlite::write_json(test$data, temp_in_json) | |
# mlflow_cli( | |
# "models", "predict", "-m", "model", "-i", temp_in_json, "-o", temp_out, | |
# "-t", "json", | |
# "--json-format", "records" | |
# ) | |
# prediction <- unlist(jsonlite::read_json(temp_out)) | |
# expect_true(!is.null(prediction)) | |
# expect_equal(prediction, unname(predict(rf_workflow_fit, test))) | |
# }) | |
# Check the package lookup for several model/engine pairs, and that an unknown
# engine is rejected with an informative error.
test_that("model and engine dependencies are detected", {
  expected_dependencies <- list(
    list(model = "linear_reg", engine = "lm", packages = "stats"),
    list(model = "linear_reg", engine = "stan", packages = "rstanarm"),
    list(
      model = "linear_reg",
      engine = "keras",
      packages = c("magrittr", "keras")
    ),
    list(model = "rand_forest", engine = "ranger", packages = "ranger"),
    list(
      model = "rand_forest",
      engine = "randomForest",
      packages = "randomForest"
    )
  )

  for (case in expected_dependencies) {
    expect_equal(
      get_parsnip_dependencies(case$model, case$engine),
      case$packages
    )
  }

  expect_error(
    get_parsnip_dependencies("linear_reg", "not_an_engine"),
    "not_an_engine is not a valid engine for linear_reg"
  )
})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment