Created December 1, 2021 10:14
Save pfistfl/6b190f0612535817bdd33fe8f8bd6548 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library("mlr3")
library("mlr3pipelines")
library("mlr3learners")
library("mlr3misc")
library("paradox")
library("data.table")
# PipeOpSplit: splits one Task into several row subsets during training.
#
# Training: output channel i receives a deep clone of the input Task filtered
# to the row ids in `ids[[i]]`, so each downstream PipeOp is fit on its own
# subset. Prediction: the unchanged input Task is forwarded to every output
# channel, so every downstream model predicts on the same Task.
#
# Constructor arguments:
#   outnum     - integer(1) >= 1, number of output channels
#                ("output1", ..., "output<outnum>"); must equal length(ids).
#   id         - character(1), PipeOp id (default "split").
#   param_vals - list of initial hyperparameter values, e.g.
#                list(ids = list(1:200, 201:344)).
#
# Hyperparameters:
#   ids (tag "train") - list of exactly `outnum` integer-valued row-id vectors.
PipeOpSplit = R6::R6Class("PipeOpSplit",
  inherit = PipeOp,
  public = list(
    # `outnum` is required: there is no sensible default number of subsets.
    initialize = function(outnum, id = "split", param_vals = list()) {
      outnum = checkmate::assert_int(outnum, lower = 1L)
      super$initialize(id,
        param_set = ps(ids = p_uty(tags = "train")),
        # Forward the caller's initial hyperparameter values.
        param_vals = param_vals,
        input = data.table(name = "input", train = "Task", predict = "Task"),
        # Channel names "output1", ..., "output<outnum>", built without
        # relying on unexported mlr3pipelines internals.
        output = data.table(
          name = paste0("output", seq_len(outnum)),
          train = "Task", predict = "Task"
        ),
        tags = "meta"
      )
    }
  ),
  private = list(
    .train = function(inputs) {
      task = inputs[[1L]]
      ids = self$param_set$values$ids
      # Fail early with a clear message if `ids` is unset or does not provide
      # exactly one row-id vector per output channel.
      checkmate::assert_list(ids, types = "integerish", len = self$outnum,
        .var.name = "ids")
      self$state = list()
      # Task$filter() modifies the task it is called on, hence one deep clone
      # per subset so the input Task is left untouched.
      map(ids, function(rows) task$clone(deep = TRUE)$filter(rows))
    },
    .predict = function(inputs) {
      # No splitting at prediction time: every channel gets the input Task.
      rep(inputs, self$outnum)
    }
  )
)
# Register the PipeOp so it can be constructed via po("split", ...).
mlr_pipeops$add("split", PipeOpSplit)

# We use an example task here:
task = tsk("penguins")

# ... and define the row ids of the two subsets: the first 200 rows and the
# remaining ones. Deriving them from the task avoids hard-coding its size.
row_ids = task$row_ids
ids = list(head(row_ids, 200L), tail(row_ids, -200L))

# One rpart learner per subset. Predicting probabilities lets classifavg
# average them; with "response" predictions it can only take a (weighted)
# majority vote, which is not meaningful for two disagreeing learners.
learners = lapply(seq_along(ids), function(i) {
  po("learner",
    lrn("classif.rpart", predict_type = "prob"),
    id = paste0("lcv", i)
  )
})

graph = po("split", length(ids), ids = ids) %>>% # one task per subset
  gunion(learners) %>>%                          # fit one learner per task
  po("classifavg")                               # average the predictions

graph$train(task)
prediction = graph$predict(task)
print(prediction)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.