Skip to content

Instantly share code, notes, and snippets.

@jbkunst
Created August 14, 2017 22:28
Show Gist options
  • Save jbkunst/d734b134bdce22d24d5904a23722c47b to your computer and use it in GitHub Desktop.
Save jbkunst/d734b134bdce22d24d5904a23722c47b to your computer and use it in GitHub Desktop.
# Two-class variant of iris: "setosa" rows are relabelled "versicolor", so
# Species ends up with the factor levels "versicolor" and "virginica".
library(partykit)
library(tidyverse)

iris2 <- iris %>%
  as_tibble() %>%
  mutate(
    Species = as.character(Species),
    Species = ifelse(Species == "setosa", "versicolor", Species),
    Species = as.factor(Species)
  )

# mincriterion = 0.2 is far below ctree_control()'s documented default
# (0.95), so splits are accepted much more readily.
irisct <- ctree(Species ~ ., data = iris2,
                control = ctree_control(mincriterion = 0.2))

# Print the source of partykit's internal plot method, for reference.
partykit:::plot.constparty

plot(irisct)
# Inner nodes without node ids or p-values.
plot(irisct, inner_panel = node_inner(irisct, pval = FALSE, id = FALSE))
# node_barplot() panels (with node ids) used for the inner nodes too.
plot(irisct, inner_panel = node_barplot(irisct, id = TRUE))
# ------------------------------------------------------------------------
# Build an `inner_panel` function for plot() on a fitted tree: each inner
# node gets the node_inner() panel (without node id or p-value) in a top row
# and a node_barplot() panel (without node id) in a bottom row.
#
# object:  the fitted tree that is also passed to plot().
# ...:     further arguments forwarded to node_barplot().
# heights: relative heights (npc) of the top (label) and bottom (barplot)
#          rows; the default reproduces the original 20% / 80% split.
#
# Returns a function(node) that draws into the current grid viewport.
make_inner_and_barplot <- function(object, ..., heights = c(0.2, 0.8)) {
  stopifnot(is.numeric(heights), length(heights) == 2)

  # Build both panel functions once, instead of once per drawn node.
  inner_fun <- node_inner(object, id = FALSE, pval = FALSE)
  barplot_fun <- node_barplot(object, id = FALSE, ...)

  function(node) {
    ## two-row layout: label on top, barplot below
    pushViewport(viewport(layout = grid.layout(
      nrow = 2, ncol = 1,
      heights = unit(heights, "npc")
    )))
    ## white background behind both rows
    grid.rect(gp = gpar(fill = "white", col = 0))

    ## top row: node_inner() panel
    pushViewport(viewport(layout.pos.col = 1, layout.pos.row = 1))
    inner_fun(node)
    popViewport()

    ## bottom row: node_barplot() panel; pop it and the layout viewport
    pushViewport(viewport(layout.pos.col = 1, layout.pos.row = 2))
    barplot_fun(node)
    popViewport(2)
  }
}
# Tree whose inner nodes show the node_inner() label stacked above a
# barplot (via make_inner_and_barplot), drawn with a small font size.
plot(
  irisct,
  inner_panel = make_inner_and_barplot(irisct),
  tnex = 1,
  gp = gpar(fontsize = 6)
)
# ------------------------------------------------------------------------
# Panel generator drawing a node's predictions as a barplot: class
# proportions for a factor response, predicted values otherwise.
# Adapted from the `party` package's node_barplot(). It reads
# `response(ctreeobj)`, `ctreeobj@tree` and the per-node fields `prediction`,
# `nodeID`, `weights`, `terminal`, `left` and `right`.
# NOTE(review): these look like party's S4 BinaryTree structures; confirm
# they exist on the tree you pass (e.g. a partykit::ctree fit) before use.
#
# Arguments
#   ctreeobj  fitted tree; must support response() and have a `tree` slot.
#   col       bar outline colour(s): one per bar when `beside`; for the
#             single stacked bar the first value outlines every segment.
#   fill      fill colours recycled over levels (default gray.colors()).
#   beside    side-by-side bars (TRUE) or one stacked bar (FALSE); default
#             TRUE for factor responses with 3+ levels, FALSE otherwise.
#   ymax      upper y limit; default 1.1 (beside) / 1 (stacked) for factors,
#             otherwise 1.1 * the largest prediction found in the tree.
#   ylines    widths (in lines) of the left/right margins of the panel.
#   widths    bar widths, recycled over bars.
#   gap       spacing between bars, as a multiple of sum(widths).
#   reverse   reverse the level order; default !beside.
#   id        include "Node <id>" in the panel title.
#
# Returns a function(node) that draws into the current grid viewport.
node_barplot2 <- function(ctreeobj,
                          col = "black",
                          fill = NULL,
                          beside = NULL,
                          ymax = NULL,
                          ylines = NULL,
                          widths = 1,
                          gap = NULL,
                          reverse = NULL,
                          id = TRUE) {
  # max(0, every `prediction` in the subtree rooted at x)
  getMaxPred <- function(x) {
    below <- if (x$terminal) 0 else c(getMaxPred(x$left), getMaxPred(x$right))
    max(c(max(x$prediction), below))
  }

  y <- response(ctreeobj)[[1]]

  # inherits() instead of class(y) == ...: class() can have length > 1
  # (e.g. c("matrix", "array")), which `||` rejects in R >= 4.3.
  if (is.factor(y) || inherits(y, "was_ordered")) {
    ylevels <- levels(y)
    if (is.null(beside)) beside <- length(ylevels) >= 3
    if (is.null(ymax)) ymax <- if (beside) 1.1 else 1
    if (is.null(gap)) gap <- if (beside) 0.1 else 0
  } else {
    if (is.null(beside)) beside <- FALSE
    if (is.null(ymax)) ymax <- getMaxPred(ctreeobj@tree) * 1.1
    ylevels <- seq_along(ctreeobj@tree$prediction)
    if (length(ylevels) < 2) ylevels <- ""
    if (is.null(gap)) gap <- 1
  }
  if (is.null(reverse)) reverse <- !beside
  if (is.null(fill)) fill <- gray.colors(length(ylevels))
  if (is.null(ylines)) ylines <- if (beside) c(3, 4) else c(1.5, 2.5)

  ### panel function for barplots in nodes
  rval <- function(node) {
    ## parameter setup
    pred <- node$prediction
    if (reverse) {
      pred <- rev(pred)
      ylevels <- rev(ylevels)
    }
    np <- length(pred)
    nc <- if (beside) np else 1

    fill <- rep(fill, length.out = np)
    widths <- rep(widths, length.out = nc)
    col <- rep(col, length.out = nc)
    ylines <- rep(ylines, length.out = 2)

    gap <- gap * sum(widths)
    yscale <- c(0, ymax)
    xscale <- c(0, sum(widths) + (nc + 1) * gap)

    top_vp <- viewport(
      layout = grid.layout(
        nrow = 2, ncol = 3,
        widths = unit(c(ylines[1], 1, ylines[2]), c("lines", "null", "lines")),
        heights = unit(c(1, 1), c("lines", "null"))
      ),
      width = unit(1, "npc"),
      height = unit(1, "npc") - unit(2, "lines"),
      name = paste0("node_barplot", node$nodeID)
    )
    pushViewport(top_vp)
    grid.rect(gp = gpar(fill = "white", col = 0))

    ## main title
    pushViewport(viewport(layout.pos.col = 2, layout.pos.row = 1))
    prefix <- if (id) paste("Node", node$nodeID, "(n = ") else "n = "
    suffix <- if (id) ")" else ""
    grid.text(paste0(prefix, sum(node$weights), suffix))
    popViewport()

    plot_vp <- viewport(
      layout.pos.col = 2, layout.pos.row = 2,
      xscale = xscale, yscale = yscale,
      name = paste0("node_barplot", node$nodeID, "plot")
    )
    pushViewport(plot_vp)

    if (beside) {
      xcenter <- cumsum(widths + gap) - widths / 2
      for (i in seq_len(np)) {
        grid.rect(x = xcenter[i], y = 0, height = pred[i],
                  width = widths[i],
                  just = c("center", "bottom"), default.units = "native",
                  gp = gpar(col = col[i], fill = fill[i]))
      }
      if (length(xcenter) > 1) grid.xaxis(at = xcenter, label = FALSE)
      grid.text(ylevels, x = xcenter, y = unit(-1, "lines"),
                just = c("center", "top"),
                default.units = "native", check.overlap = TRUE)
      grid.yaxis()
    } else {
      ## one stacked bar: segment i starts where segments 1..i-1 end
      ycenter <- cumsum(pred) - pred
      for (i in seq_len(np)) {
        grid.rect(x = xscale[2] / 2, y = ycenter[i],
                  height = min(pred[i], ymax - ycenter[i]),  # clip at ymax
                  width = widths[1],
                  just = c("center", "bottom"), default.units = "native",
                  # single bar: col has length nc = 1, so col[i] would be NA
                  # (no outline) for every segment after the first
                  gp = gpar(col = col[1], fill = fill[i]))
      }
      if (np > 1) {
        grid.text(ylevels[1], x = unit(-1, "lines"), y = 0,
                  just = c("left", "center"), rot = 90,
                  default.units = "native", check.overlap = TRUE)
        grid.text(ylevels[np], x = unit(-1, "lines"), y = ymax,
                  just = c("right", "center"), rot = 90,
                  default.units = "native", check.overlap = TRUE)
      }
      if (np > 2) {
        grid.text(ylevels[-c(1, np)], x = unit(-1, "lines"),
                  y = ycenter[-c(1, np)],
                  just = "center", rot = 90,
                  default.units = "native", check.overlap = TRUE)
      }
      ## axis tick at 1 - pred[np]: the bottom of the top segment when the
      ## predictions are proportions summing to 1 (previously written with
      ## the loop index `i` left over from the for loop, which equals np)
      grid.yaxis(at = round(1 - pred[np], digits = 3), main = FALSE)
    }

    grid.rect(gp = gpar(fill = "transparent"))
    upViewport(2)
  }

  rval
}
# Mark node_barplot2 as a panel *generator*: plot() is then expected to call
# it with the tree (and `tp_args`) to obtain the actual panel function.
class(node_barplot2) <- "grapcon_generator"

plot(irisct)
# NOTE(review): node_barplot2 reads `response()`, `@tree` and
# `node$prediction`, while irisct was fit with partykit::ctree; confirm these
# are available on irisct before relying on the two plots below.
plot(irisct, terminal_panel = node_barplot2)
# Same panel with 2 / 4 lines of left / right margin.
plot(irisct, terminal_panel = node_barplot2,
     tp_args = list(ylines = c(2, 4)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment