Skip to content

Instantly share code, notes, and snippets.

@hammady
Created February 17, 2014 07:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hammady/9046168 to your computer and use it in GitHub Desktop.
Save hammady/9046168 to your computer and use it in GitHub Desktop.
After sourcing this patch, you can use "wss" as a new measure when plotting performance curves, just like fp, tp, rec, auc, ... ROCR must be installed and loaded (library(ROCR)) first.
# Patch ROCR R package to support plotting of WSS (Work Saved over Sampling)
# Author: Hossam Hammady (hhammady@qf.org.qa)
# Organization: Qatar Computing Research Institute; Data Analytics Group (http://da.qcri.qa)
# Date: 17-Feb-2014
# License: MIT
# Description: After sourcing this patch, you can use "wss" as a new measure
# .. when plotting performance curves, just like fp, tp, rec, auc, ...
# .. ROCR must be installed and loaded (library(ROCR)) before sourcing.
# Keep a handle on ROCR's own .define.environments() so the patched wrapper
# defined next can delegate to it. The guard matters when this file is sourced
# more than once: without it, the second run would capture the already-patched
# wrapper, which then calls itself through this name and recurses until R
# reports "evaluation nested too deeply".
if (!exists("original.define.environments", inherits = FALSE)) {
  original.define.environments <- .define.environments
}
# Patched replacement for ROCR's .define.environments().
#
# Builds ROCR's measure registries via original.define.environments(), then
# registers the extra "wss" measure in them. The registries
# (long.unit.names, function.names, ...) are environments, so assign() below
# mutates them in place and the list returned by the original function can be
# handed back unchanged (no entries are dropped by re-packing).
#
# Arguments:
#   ...  Forwarded verbatim to original.define.environments(); empty when
#        called from performance().
# Returns:
#   The list returned by original.define.environments(), with "wss" registered
#   in its long.unit.names and function.names environments.
.define.environments <- function(...) {
  envir.list <- original.define.environments(...)

  # Work Saved over Sampling at each cutoff:
  #   WSS = (tn + fn) / N - (1 - recall),
  #   where N = n.pos + n.neg and recall = tp / (tp + fn).
  # The argument list mirrors the one performance() passes to every measure.
  .performance.wss <- function(predictions, labels, cutoffs,
                               fp, tp, fn, tn,
                               n.pos, n.neg, n.pos.pred, n.neg.pred) {
    list(cutoffs, (tn + fn) / (n.pos + n.neg) - 1 + tp / (tp + fn))
  }

  assign("wss", "Work Saved over Random Sampling",
         envir = envir.list$long.unit.names)
  assign("wss", .performance.wss, envir = envir.list$function.names)

  envir.list
}
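# Copied from ROCR's original performance(); redefined here so that it calls the
# patched .define.environments() and therefore recognizes the "wss" measure.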
# Patched copy of ROCR's performance().
#
# The body follows ROCR's original implementation. It is redefined here so that
# its .define.environments() call resolves to the patched wrapper (which
# registers "wss") rather than to ROCR's own version.
#
# Arguments:
#   prediction.obj  Object of class "prediction" (presumably built with ROCR's
#                   prediction()); its slots predictions, labels, cutoffs, fp,
#                   tp, fn, tn, n.pos, n.neg, n.pos.pred, n.neg.pred are read.
#   measure         Name of the measure for the y axis (e.g. "wss").
#   x.measure       Name of the measure for the x axis; defaults to "cutoff".
#   ...             Optional measure-specific arguments.
# Returns:
#   An object of class "performance".
performance <- function(prediction.obj, measure, x.measure = "cutoff", ...) {
  envir.list <- .define.environments()
  long.unit.names <- envir.list$long.unit.names
  function.names <- envir.list$function.names
  obligatory.x.axis <- envir.list$obligatory.x.axis
  optional.arguments <- envir.list$optional.arguments
  default.values <- envir.list$default.values

  # inherits() rather than class() != "prediction": class() can have length > 1
  # (e.g. c("matrix", "array")), which makes `||` error in R >= 4.3 instead of
  # reaching the intended message below.
  if (!inherits(prediction.obj, "prediction") ||
      !exists(measure, where = long.unit.names, inherits = FALSE) ||
      !exists(x.measure, where = long.unit.names, inherits = FALSE)) {
    stop(paste("Wrong argument types: First argument must be of type",
               "'prediction'; second and optional third argument must",
               "be available performance measures!"))
  }

  # A measure that has an obligatory x axis may not itself be used as x axis.
  if (exists(x.measure, where = obligatory.x.axis, inherits = FALSE)) {
    msg <- paste("The performance measure", x.measure,
                 "can only be used as 'measure', because it has",
                 "the following obligatory 'x.measure':\n",
                 get(x.measure, envir = obligatory.x.axis))
    stop(msg)
  }

  # An obligatory x axis for the y measure overrides the caller's x.measure.
  if (exists(measure, where = obligatory.x.axis, inherits = FALSE)) {
    x.measure <- get(measure, envir = obligatory.x.axis)
  }

  if (x.measure == "cutoff" ||
      exists(measure, where = obligatory.x.axis, inherits = FALSE)) {
    optional.args <- list(...)
    argnames <- c()
    if (exists(measure, where = optional.arguments, inherits = FALSE)) {
      argnames <- get(measure, envir = optional.arguments)
      # Defaults are stored under "<measure>:<argname>" keys.
      default.arglist <- list()
      for (i in seq_along(argnames)) {
        default.arglist <- c(default.arglist,
                             get(paste0(measure, ":", argnames[i]),
                                 envir = default.values, inherits = FALSE))
      }
      names(default.arglist) <- argnames
      # Combine each default with the caller's arguments via ROCR's .farg()
      # (presumably keeping the caller's value when one was supplied).
      for (i in seq_along(argnames)) {
        templist <- list(optional.args, default.arglist[[i]])
        names(templist) <- c("arglist", argnames[i])
        optional.args <- do.call(".farg", templist)
      }
    }
    optional.args <- .select.args(optional.args, argnames)
    function.name <- get(measure, envir = function.names)

    # Evaluate the measure once per run stored in the prediction object;
    # seq_along() keeps this a no-op for an empty object (1:0 would not).
    x.values <- list()
    y.values <- list()
    for (i in seq_along(prediction.obj@predictions)) {
      argumentlist <- .sarg(optional.args,
                            predictions = prediction.obj@predictions[[i]],
                            labels = prediction.obj@labels[[i]],
                            cutoffs = prediction.obj@cutoffs[[i]],
                            fp = prediction.obj@fp[[i]],
                            tp = prediction.obj@tp[[i]],
                            fn = prediction.obj@fn[[i]],
                            tn = prediction.obj@tn[[i]],
                            n.pos = prediction.obj@n.pos[[i]],
                            n.neg = prediction.obj@n.neg[[i]],
                            n.pos.pred = prediction.obj@n.pos.pred[[i]],
                            n.neg.pred = prediction.obj@n.neg.pred[[i]])
      ans <- do.call(function.name, argumentlist)
      # A measure may return NULL as its x values; only y values are kept then.
      if (!is.null(ans[[1]])) {
        x.values <- c(x.values, list(ans[[1]]))
      }
      y.values <- c(y.values, list(ans[[2]]))
    }

    if (!(length(x.values) == 0 || length(x.values) == length(y.values))) {
      stop("Consistency error.")
    }
    return(new("performance",
               x.name = get(x.measure, envir = long.unit.names),
               y.name = get(measure, envir = long.unit.names),
               alpha.name = "none",
               x.values = x.values, y.values = y.values,
               alpha.values = list()))
  } else {
    # Two independent measures: compute each against cutoffs, then merge.
    perf.obj.1 <- performance(prediction.obj, measure = x.measure, ...)
    perf.obj.2 <- performance(prediction.obj, measure = measure, ...)
    return(.combine.performance.objects(perf.obj.1, perf.obj.2))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment