# Tidymodels MLflow R
# GitHub gist by @mdneuzerling, created December 19, 2022
# (gist mdneuzerling/363aabdf011524b29995f35016b98bea)
# This file contains the minimal code needed to set up a tidymodels flavour for
# mlflow, along with unit tests. There are three issues that need to be addressed:
#
# * we currently use workflows:::predict.workflow, which will NOT be accepted by
# CRAN. We should ask that the Tidymodels team make this function available in
# the workflows NAMESPACE.
# * we need to ensure that the Python CLI can support the new R flavour. There is
# a commented-out unit test for this.
# * the unit tests call on packages through `library`, and this could be cleaned up.
# Save a fitted tidymodels workflow as an MLflow model: the workflow is
# serialised to `workflow.rds` inside `path`, and a `workflow` flavour
# recording the parsnip model type, engine and mode is appended to
# `model_spec` and written as the model's spec. Returns the updated
# `model_spec`.
#' @param workflow A fitted tidymodels workflow (see `workflows::workflow()`).
#' @rdname mlflow_save_model
#' @export
mlflow_save_model.workflow <- function(workflow,
                                       path,
                                       model_spec = list(),
                                       ...) {
  if (is.null(workflow$fit$fit)) {
    stop("The workflow does not have a model fit. Have you called `fit()` yet?")
  }

  # Inspect and validate the model specification BEFORE touching the file
  # system: an existing `path` is deleted below, so failing afterwards would
  # destroy a previously saved model without writing a replacement.
  spec <- workflows::extract_spec_parsnip(workflow)
  model <- class(spec)[[1]] # adapted from workflows:::print_header
  engine <- spec$engine
  mode <- spec$mode
  if (!is.character(engine) || length(engine) != 1L) {
    stop("The workflow's model specification does not have an engine set.")
  }
  if (identical(engine, "spark")) {
    stop("Spark models are not supported through the workflow flavor")
  }

  if (dir.exists(path)) unlink(path, recursive = TRUE)
  dir.create(path)
  saveRDS(workflow, file.path(path, "workflow.rds"))

  # The model type and engine are recorded so that the loader can check the
  # engine's package dependencies before deserialising the fit.
  model_spec$flavors <- append(model_spec$flavors, list(
    workflow = list(
      data = "workflow.rds",
      model = model,
      engine = engine,
      mode = mode
    )
  ))
  mlflow_write_model_spec(path, model_spec)
  model_spec
}
# Load a workflow saved with the `workflow` flavour. The model type and engine
# recorded in the MLmodel spec are used to check that the packages needed by
# the fitted model are installed before `workflow.rds` is deserialised.
# `flavor` is only used for S3 dispatch.
#' @rdname mlflow_load_flavor
#' @export
mlflow_load_flavor.mlflow_flavor_workflow <- function(flavor, model_path) {
  require_package("workflows")
  # parsnip is needed to look up the engine's dependencies.
  require_package("parsnip")
  # Read the spec of the model actually being loaded (`model_path`), not a
  # fixed "model" directory relative to the working directory.
  model_spec <- mlflow_read_model_spec(model_path)
  model <- model_spec$flavors$workflow$model
  engine <- model_spec$flavors$workflow$engine
  require_parsnip_dependencies(model, engine)
  readRDS(file.path(model_path, "workflow.rds"))
}
# Predict with a fitted workflow. workflows registers `predict.workflow()` as
# an S3 method of `stats::predict()`, so dispatching through the generic
# reaches it without accessing the unexported function via `:::`, which CRAN
# does not accept. Extra arguments in `...` (e.g. `type`) are forwarded to
# the method.
#' @rdname mlflow_predict
#' @export
mlflow_predict.workflow <- function(model, data, ...) {
  stats::predict(model, new_data = data, ...)
}
#' Check that the package dependencies for a parsnip model are satisfied
#'
#' Each package reported by `get_parsnip_dependencies()` is checked in order
#' with `require_package()`, so the first missing package raises an error.
#'
#' @inheritParams get_parsnip_dependencies
#'
#' @return Invisibly returns a character vector of required packages
#'
#' @keywords internal
require_parsnip_dependencies <- function(model, engine) {
  pkgs <- get_parsnip_dependencies(model, engine)
  for (pkg in pkgs) require_package(pkg)
  invisible(pkgs)
}
#' Determine the package dependencies for a parsnip model
#'
#' @param model The type of model, eg. "linear_reg". This is stored as a class
#' of the model object, and is also recorded in the saved mlflow artefact.
#' @param engine The engine for the model, as understood by parsnip, eg. "lm".
#'   Must be a single, non-missing string.
#'
#' @return A character vector of required packages
#'
#' @keywords internal
get_parsnip_dependencies <- function(model, engine) {
  # Validate first: a NULL/empty engine (e.g. read from an incomplete MLmodel
  # file) would otherwise fail inside `if ()` with "argument is of length
  # zero", which says nothing about the actual problem.
  if (!is.character(engine) || length(engine) != 1L || is.na(engine)) {
    stop("A single engine name is required for ", model, call. = FALSE)
  }
  dependencies <- parsnip::get_dependency(model)
  if (!(engine %in% dependencies$engine)) {
    stop(engine, " is not a valid engine for ", model, call. = FALSE)
  }
  # NOTE(review): if parsnip returns several rows for this engine (presumably
  # one per mode -- verify against the installed parsnip version), only the
  # first row's packages are used.
  engine_dependencies <- dependencies[which(dependencies$engine == engine), ]
  engine_dependencies$pkg[[1]]
}
#' Ensure that a package is installed
#'
#' @param package Name of the package whose namespace must be loadable.
#'
#' @return `NULL`, invisibly, when the package is available; otherwise an
#'   error is raised.
#'
#' @keywords internal
require_package <- function(package) {
  installed <- requireNamespace(package, quietly = TRUE)
  if (installed) {
    return(invisible(NULL))
  }
  stop("The '", package, "' package must be installed.")
}
##############
# Unit tests #
##############
# TODO: get rid of these library calls and refer to the package namespaces directly
library(parsnip)
library(workflows)
library(recipes)
test_that("mlflow can save, load and predict linear model workflow", {
  idx <- withr::with_seed(3809, sample(nrow(mtcars)))
  train <- mtcars[idx[1:25], ]
  test <- mtcars[idx[26:32], ]
  # Defined here so the test does not depend on objects created elsewhere.
  lm_parsnip <- linear_reg() %>% set_engine("lm")
  # The workflow must exist before it can be fitted.
  lm_workflow <- workflow() %>%
    add_model(lm_parsnip) %>%
    add_recipe(recipe(mpg ~ ., mtcars) %>% step_log(wt))
  lm_workflow_fit <- lm_workflow %>% fit(train)
  mlflow_clear_test_dir("model")
  mlflow_save_model(lm_workflow_fit, "model")
  expect_true(dir.exists("model"))
  loaded_back_model <- mlflow_load_model("model")
  prediction <- mlflow_predict(loaded_back_model, test)
  expect_equal(
    prediction,
    predict(lm_workflow_fit, test)
  )
})
test_that("mlflow can save, load and predict random forest workflow", {
  skip_if_not_installed("ranger")
  idx <- withr::with_seed(3809, sample(nrow(mtcars)))
  train <- mtcars[idx[1:25], ]
  test <- mtcars[idx[26:32], ]
  # Defined here so the test is self-contained. NOTE(review): the "ranger"
  # engine is an assumption -- adjust if a different engine was intended.
  rf_parsnip <- rand_forest(mode = "regression") %>% set_engine("ranger")
  rf_workflow <- workflow() %>%
    add_model(rf_parsnip) %>%
    add_recipe(recipe(mpg ~ ., mtcars) %>% step_log(wt))
  rf_workflow_fit <- rf_workflow %>% fit(train)
  mlflow_clear_test_dir("model")
  mlflow_save_model(rf_workflow_fit, "model")
  expect_true(dir.exists("model"))
  loaded_back_model <- mlflow_load_model("model")
  prediction <- mlflow_predict(loaded_back_model, test)
  expect_equal(
    prediction,
    predict(rf_workflow_fit, test)
  )
})
test_that("can predict with the mlflow_rfunc_serve", {
  # Starts a separate R process and waits for it, so keep it off CRAN.
  skip_on_cran()
  skip_if_not_installed("ranger")

  # Build and save the model inside this test: objects created in other
  # `test_that()` blocks are not in scope here, and relying on files left
  # behind by other tests makes the result order-dependent.
  idx <- withr::with_seed(3809, sample(nrow(mtcars)))
  train <- mtcars[idx[1:25], ]
  test <- mtcars[idx[26:32], ]
  rf_workflow_fit <- workflow() %>%
    add_model(rand_forest(mode = "regression") %>% set_engine("ranger")) %>%
    add_recipe(recipe(mpg ~ ., mtcars) %>% step_log(wt)) %>%
    fit(train)
  mlflow_clear_test_dir("model")
  mlflow_save_model(rf_workflow_fit, "model")

  model_server <- processx::process$new(
    "Rscript",
    c(
      "-e",
      "mlflow::mlflow_rfunc_serve('model', browse = FALSE)"
    ),
    supervise = TRUE,
    stdout = "|",
    stderr = "|"
  )
  # `teardown()` is deprecated in testthat 3e; defer the cleanup to the end
  # of this test instead.
  withr::defer(model_server$kill())
  # Give the server time to start before sending a request.
  Sys.sleep(10)
  expect_true(model_server$is_alive())
  http_prediction <- httr::content(
    httr::POST(
      "http://127.0.0.1:8090/predict/",
      body = jsonlite::toJSON(as.list(test))
    )
  )
  expect_equal(
    purrr::flatten_dbl(purrr::flatten(http_prediction)),
    predict(rf_workflow_fit, test)$.pred
  )
})
# Disabled until the Python side of MLflow (used by the CLI) supports the R workflow flavour
# test_that("can predict with CLI", {
# temp_in_csv <- tempfile(fileext = ".csv")
# temp_out <- tempfile(fileext = ".json")
# write.csv(test$data, temp_in_csv, row.names = FALSE)
# mlflow_cli(
# "models", "predict", "-m", "model", "-i", temp_in_csv,
# "-o", temp_out, "-t", "csv"
# )
# prediction <- unlist(jsonlite::read_json(temp_out))
# expect_true(!is.null(prediction))
# expect_equal(prediction, predict(rf_workflow_fit, test))
#
# temp_in_json <- tempfile(fileext = ".json")
# jsonlite::write_json(test$data, temp_in_json)
# mlflow_cli(
# "models", "predict", "-m", "model", "-i", temp_in_json, "-o", temp_out,
# "-t", "json",
# "--json-format", "records"
# )
# prediction <- unlist(jsonlite::read_json(temp_out))
# expect_true(!is.null(prediction))
# expect_equal(prediction, unname(predict(rf_workflow_fit, test)))
# })
test_that("model and engine dependencies are detected", {
  # Each case: model type, engine, and the packages parsnip should report.
  cases <- list(
    list(model = "linear_reg", engine = "lm", pkgs = "stats"),
    list(model = "linear_reg", engine = "stan", pkgs = "rstanarm"),
    list(model = "linear_reg", engine = "keras", pkgs = c("magrittr", "keras")),
    list(model = "rand_forest", engine = "ranger", pkgs = "ranger"),
    list(model = "rand_forest", engine = "randomForest", pkgs = "randomForest")
  )
  for (case in cases) {
    expect_equal(
      get_parsnip_dependencies(case$model, case$engine),
      case$pkgs
    )
  }

  # An unknown engine is rejected with an informative message.
  expect_error(
    get_parsnip_dependencies("linear_reg", "not_an_engine"),
    "not_an_engine is not a valid engine for linear_reg"
  )
})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment