Skip to content

Instantly share code, notes, and snippets.

@jakob-r
Created February 4, 2021 09:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jakob-r/13155dfeb6e2a9eeb1536b475e794f10 to your computer and use it in GitHub Desktop.
Save jakob-r/13155dfeb6e2a9eeb1536b475e794f10 to your computer and use it in GitHub Desktop.
Partial Least Squares Regression mlr3 learner
library(paradox)
library(R6)
library(mlr3)
#' @title Partial Least Squares Regression Learner
#'
#' @description
#' mlr3 regression learner that fits its model with `pls::plsr()` and
#' predicts with the model's `predict()` method.
#'
#' Hyperparameters:
#' * `scale`  (logical, train): forwarded to `pls::plsr()`.
#' * `center` (logical, train): forwarded to `pls::plsr()`.
#' * `ncomp`  (integer >= 1, train + predict): number of components. It is
#'   forwarded to both the fit and the prediction call. If unset, `pls`
#'   chooses the number of components for the fit and the prediction uses
#'   all of them.
LearnerRegrPls = R6Class("LearnerRegrPls",
  inherit = LearnerRegr,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      param_set = ps(
        scale = p_lgl(tags = "train"),
        center = p_lgl(tags = "train"),
        # A PLS model needs at least one component, so reject ncomp < 1 early.
        ncomp = p_int(lower = 1L, tags = c("train", "predict"))
      )
      super$initialize(
        id = "regr.pls",
        packages = "pls",
        feature_types = "numeric",
        predict_types = "response",
        param_set = param_set,
        # This learner declares no special properties.
        properties = character()
      )
    }
  ),
  private = list(
    # Fit the PLS model on `task`.
    # Returns the object produced by pls::plsr().
    .train = function(task) {
      # get parameters for training
      pars = self$param_set$get_values(tags = "train")
      # set column names to ensure consistency in fit and predict
      self$state$feature_names = task$feature_names
      # use the mlr3misc::invoke function (it's similar to do.call())
      mlr3misc::invoke(pls::plsr,
        formula = task$formula(),
        data = task$data(),
        .args = pars
      )
    },

    # Predict the response for the observations in `task`.
    # Returns a list with a single element `response`.
    .predict = function(task) {
      # get parameters with tag "predict"
      pars = self$param_set$get_values(tags = "predict")
      # get newdata and ensure same ordering in train and predict
      newdata = task$data(cols = self$state$feature_names)
      pred = mlr3misc::invoke(predict, self$model, newdata = newdata, .args = pars)
      # `pred` is an (observations x responses x models) array. Without an
      # explicit `ncomp`, pls returns one slice per model using 1, 2, ...,
      # ncomp components, so the first slice would be the one-component model.
      # Take the last slice instead: it is the model with all fitted
      # components, and it is the only slice when `ncomp` is set.
      list(response = pred[, 1L, dim(pred)[3L]])
    }
  )
)
# Register the learner class in the mlr_learners dictionary under the key
# "regr.pls".
mlr_learners$add(key = "regr.pls", value = LearnerRegrPls)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment