Created February 4, 2021 09:07
-
-
Save jakob-r/13155dfeb6e2a9eeb1536b475e794f10 to your computer and use it in GitHub Desktop.
Partial Least Squares Regression mlr3 learner
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(paradox) | |
library(R6) | |
library(mlr3) | |
# Partial Least Squares Regression
#
# mlr3 regression learner wrapping `pls::plsr()` (formula interface).
# Only the "response" predict type is supported.
#
# Hyperparameters:
# * `scale`  (logical, train)   - forwarded to `pls::plsr()`.
# * `center` (logical, train)   - forwarded to `pls::plsr()`.
# * `ncomp`  (integer >= 1, train + predict) - number of components to fit
#   and to predict with. If unset at predict time, the number of components
#   stored in the fitted model is used.
LearnerRegrPls = R6Class("LearnerRegrPls",
  inherit = LearnerRegr,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      # named `param_set` (not `ps`) to avoid shadowing paradox::ps()
      param_set = ps(
        scale = p_lgl(tags = "train"),
        center = p_lgl(tags = "train"),
        ncomp = p_int(lower = 1L, tags = c("train", "predict"))
      )
      super$initialize(
        id = "regr.pls",
        packages = "pls",
        # model.matrix() inside plsr() handles integer columns as numeric
        feature_types = c("integer", "numeric"),
        predict_types = "response",
        param_set = param_set,
        properties = character()
      )
    }
  ),
  private = list(
    # Fits pls::plsr() on the full task data using the task's formula.
    # Returns the fitted model object, which mlr3 stores as the learner model.
    .train = function(task) {
      pars = self$param_set$get_values(tags = "train")
      # mlr3misc::invoke() is similar to do.call(), with `.args` merged in
      mlr3misc::invoke(pls::plsr,
        formula = task$formula(),
        data = task$data(),
        .args = pars
      )
    },

    # Predicts the response for the task's rows with the fitted model.
    # Returns a list with element `response`, one value per row.
    .predict = function(task) {
      pars = self$param_set$get_values(tags = "predict")
      # predict.mvr() defaults to ncomp = 1:object$ncomp and then returns one
      # slice per model size; pin it to the full fitted model so the single
      # returned slice is the model that was actually trained.
      if (is.null(pars[["ncomp"]])) {
        pars[["ncomp"]] = self$model[["ncomp"]]
      }
      # formula-based prediction matches columns by name via model.frame(),
      # so column order does not need to be tracked from training
      newdata = task$data(cols = task$feature_names)
      pred = mlr3misc::invoke(predict, self$model,
        newdata = newdata,
        .args = pars
      )
      # pred is an array (rows x responses x model sizes); single response,
      # single model size
      list(response = pred[, 1L, 1L])
    }
  )
)
# Register the learner class in mlr3's learner dictionary under the key
# "regr.pls" so it can be retrieved by that key.
mlr3::mlr_learners$add("regr.pls", LearnerRegrPls)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment