Skip to content

Instantly share code, notes, and snippets.

@pfistfl
Created December 1, 2021 10:14
Show Gist options
  • Save pfistfl/6b190f0612535817bdd33fe8f8bd6548 to your computer and use it in GitHub Desktop.
Save pfistfl/6b190f0612535817bdd33fe8f8bd6548 to your computer and use it in GitHub Desktop.
library("mlr3pipelines")
library("mlr3learners")
library("mlr3misc")
library("paradox")
library("data.table")
#' PipeOpSplit: send row subsets of one Task to several output channels.
#'
#' Training: the single input Task is deep-cloned once per element of the
#' `ids` hyperparameter, and each clone is filtered (via `$filter()`) to the
#' rows given by that element. The i-th element of `ids` feeds the i-th output
#' channel.
#'
#' Prediction: the input is repeated unchanged, once per output channel.
#'
#' Hyperparameters:
#' * `ids` (untyped, tagged "train"): a collection with exactly `outnum`
#'   elements, each holding the row ids for one output channel.
#'
#' @param outnum Number of output channels (single integer >= 1). Required.
#' @param id Identifier of the PipeOp; defaults to "split".
#' @param param_vals Named list of hyperparameter values forwarded to the
#'   parent constructor.
PipeOpSplit = R6::R6Class("PipeOpSplit",
  inherit = PipeOp,
  public = list(
    # `outnum` deliberately has no default: the former `outnum = outnum`
    # default referred to itself and failed with "promise already under
    # evaluation" whenever the argument was omitted.
    initialize = function(outnum, id = "split", param_vals = list()) {
      outnum = checkmate::assert_int(outnum, lower = 1L)
      # ps()/p_uty() sugar instead of ParamSet$new(list(ParamUty$new(...))).
      param_set = paradox::ps(ids = paradox::p_uty(tags = "train"))
      super$initialize(id,
        # Forward the caller's param_vals (previously a literal list() was
        # passed here, discarding whatever the caller supplied).
        param_set = param_set, param_vals = param_vals,
        input = data.table(name = "input", train = "Task", predict = "Task"),
        output = data.table(
          # NOTE(review): rep_suffix is an unexported mlr3pipelines helper
          # (accessed with `:::`); kept so channel names stay unchanged.
          name = mlr3pipelines:::rep_suffix("output", outnum),
          train = "Task", predict = "Task"
        ),
        tags = "meta"
      )
    }
  ),
  private = list(
    # Returns one filtered deep clone of the input Task per element of `ids`.
    .train = function(inputs) {
      self$state = list()
      task = inputs[[1L]]
      ids = self$param_set$values$ids
      # `self$outnum` is not defined in this class; it is expected to come
      # from the inherited PipeOp (number of output channels).
      if (length(ids) != self$outnum) {
        stop(sprintf(
          "Parameter 'ids' must have one element per output channel: expected %i, got %i.",
          self$outnum, length(ids)
        ), call. = FALSE)
      }
      map(ids, function(idx) task$clone(deep = TRUE)$filter(idx))
    },
    # Hands the unchanged input to every output channel.
    .predict = function(inputs) {
      rep(inputs, self$outnum)
    }
  )
)
# Register the PipeOp under the key "split" so that po("split", ...) finds it.
mlr_pipeops$add("split", PipeOpSplit)

# Example task. Bound to `task` rather than `t` so base::t() is not masked.
task = tsk("penguins")

# Row ids for the two subsets: the first 200 rows and all remaining rows.
# Derived from the task itself instead of hard-coding the total row count.
row_ids = task$row_ids
ids = list(head(row_ids, 200L), tail(row_ids, -200L))

# Build the graph:
#   1. split the task into two tasks defined by `ids`,
#   2. fit one rpart learner on each of them,
#   3. average the predictions of both learners.
graph = po("split", outnum = 2L, ids = ids) %>>%
  gunion(list(
    po(id = "lcv1", "learner", lrn("classif.rpart")),
    po(id = "lcv2", "learner", lrn("classif.rpart"))
  )) %>>%
  po("classifavg")

graph$train(task)
graph$predict(task)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment