After sourcing this patch, you can use "wss" as a new measure when plotting performance curves, just like fp, tp, rec, auc, etc. ROCR must be installed and loaded first.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Patch for the ROCR R package to support plotting of WSS (Work Saved over Sampling)
# Author: Hossam Hammady (hhammady@qf.org.qa)
# Organization: Qatar Computing Research Institute; Data Analytics Group (http://da.qcri.qa)
# Date: 17-Feb-2014
# License: MIT
# Description: After sourcing this patch, you can use "wss" as a new measure
# .. when plotting performance curves, just like fp, tp, rec, auc, etc.
# .. ROCR must be installed and loaded first.
# Keep a handle on ROCR's own .define.environments() so that the patched
# version defined below can extend it.
# The function is fetched explicitly from the ROCR namespace for two reasons:
# - It is a dot-prefixed internal that may not be exported.
# - A bare `.define.environments` would resolve to the patched global copy
#   if this file were sourced a second time. The patch would then call
#   itself forever.
original.define.environments <- ROCR:::.define.environments
# Patched replacement for ROCR's .define.environments().
#
# Calls the saved original implementation (original.define.environments),
# then registers the "wss" measure in the returned lookup environments:
#   - long.unit.names: "wss" -> "Work Saved over Random Sampling"
#   - function.names:  "wss" -> the function computing WSS per cutoff
#
# Returns:
#   The list produced by original.define.environments(). Its
#   long.unit.names and function.names environments now also hold "wss".
#   Every other element of that list is passed through unchanged.
.define.environments <- function() {
  envir.list <- original.define.environments()

  # WSS at each cutoff:
  #   (TN + FN) / N - 1 + TP / (TP + FN)
  # That is, the fraction of items predicted negative minus (1 - recall).
  # The full argument list must be accepted because performance() passes
  # every count slot of the prediction object to each measure function.
  .performance.wss <- function(predictions, labels, cutoffs,
                               fp, tp, fn, tn,
                               n.pos, n.neg, n.pos.pred, n.neg.pred) {
    list(cutoffs, (tn + fn) / (n.pos + n.neg) - 1 + tp / (tp + fn))
  }

  # The list entries are environments, so assign() updates them in place.
  # Returning envir.list itself (instead of rebuilding a list from
  # hand-picked fields) keeps any extra elements the original returns.
  assign("wss", "Work Saved over Random Sampling",
         envir = envir.list$long.unit.names)
  assign("wss", .performance.wss, envir = envir.list$function.names)

  envir.list
}
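# Copied from ROCR's original performance(). It is redefined here so that its
# call to .define.environments() picks up the patched version defined above.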
# Patched copy of ROCR's performance().
#
# The logic is the same as the ROCR original. This copy lives in the global
# environment so that its unqualified call to .define.environments()
# resolves to the patched version, which registers "wss".
# ROCR's internal helpers are referenced through ROCR::: because a function
# defined in the global environment cannot see non-exported package objects.
#
# Args:
#   prediction.obj: an object of class "prediction".
#   measure:        name of the performance measure (y axis).
#   x.measure:      name of the measure for the x axis; defaults to "cutoff".
#                   It is overridden when `measure` has an obligatory x axis.
#   ...:            optional measure-specific arguments. Missing ones are
#                   filled from the registered default values.
#
# Returns:
#   An S4 "performance" object. For a non-cutoff x.measure, it is the
#   combination of two single-measure performance objects.
#
# Errors:
#   - Stops if prediction.obj is not a "prediction" object.
#   - Stops if either measure name is unknown.
#   - Stops if x.measure may only be used as `measure`.
performance <- function(prediction.obj, measure, x.measure = "cutoff", ...) {
  envir.list <- .define.environments()
  long.unit.names <- envir.list$long.unit.names
  function.names <- envir.list$function.names
  obligatory.x.axis <- envir.list$obligatory.x.axis
  optional.arguments <- envir.list$optional.arguments
  default.values <- envir.list$default.values

  # inherits() also copes with multi-element class vectors. With those,
  # class(x) != "prediction" would have length > 1 and error inside ||.
  if (!inherits(prediction.obj, "prediction") ||
      !exists(measure, where = long.unit.names, inherits = FALSE) ||
      !exists(x.measure, where = long.unit.names, inherits = FALSE)) {
    stop(paste("Wrong argument types: First argument must be of type",
               "'prediction'; second and optional third argument must",
               "be available performance measures!"))
  }

  # Measures with an obligatory x axis cannot themselves serve as x.measure.
  if (exists(x.measure, where = obligatory.x.axis, inherits = FALSE)) {
    msg <- paste("The performance measure", x.measure,
                 "can only be used as 'measure', because it has",
                 "the following obligatory 'x.measure':\n",
                 get(x.measure, envir = obligatory.x.axis))
    stop(msg)
  }
  if (exists(measure, where = obligatory.x.axis, inherits = FALSE)) {
    x.measure <- get(measure, envir = obligatory.x.axis)
  }

  if (x.measure == "cutoff" ||
      exists(measure, where = obligatory.x.axis, inherits = FALSE)) {
    optional.args <- list(...)
    argnames <- c()
    if (exists(measure, where = optional.arguments, inherits = FALSE)) {
      argnames <- get(measure, envir = optional.arguments)
      # Defaults are stored under keys of the form "<measure>:<argname>".
      default.arglist <- list()
      for (i in seq_along(argnames)) {
        default.arglist <- c(default.arglist,
                             get(paste0(measure, ":", argnames[i]),
                                 envir = default.values, inherits = FALSE))
      }
      names(default.arglist) <- argnames
      # Fill in each optional argument the caller did not supply.
      for (i in seq_along(argnames)) {
        templist <- list(optional.args, default.arglist[[i]])
        names(templist) <- c("arglist", argnames[i])
        optional.args <- do.call(ROCR:::.farg, templist)
      }
    }
    optional.args <- ROCR:::.select.args(optional.args, argnames)
    function.name <- get(measure, envir = function.names)

    x.values <- list()
    y.values <- list()
    # One evaluation per cross-validation run / replicate in prediction.obj.
    for (i in seq_along(prediction.obj@predictions)) {
      argumentlist <- ROCR:::.sarg(
        optional.args,
        predictions = prediction.obj@predictions[[i]],
        labels = prediction.obj@labels[[i]],
        cutoffs = prediction.obj@cutoffs[[i]],
        fp = prediction.obj@fp[[i]], tp = prediction.obj@tp[[i]],
        fn = prediction.obj@fn[[i]], tn = prediction.obj@tn[[i]],
        n.pos = prediction.obj@n.pos[[i]], n.neg = prediction.obj@n.neg[[i]],
        n.pos.pred = prediction.obj@n.pos.pred[[i]],
        n.neg.pred = prediction.obj@n.neg.pred[[i]]
      )
      ans <- do.call(function.name, argumentlist)
      # A NULL x component is skipped, so x.values may stay empty.
      # y values are always collected.
      if (!is.null(ans[[1]])) {
        x.values <- c(x.values, list(ans[[1]]))
      }
      y.values <- c(y.values, list(ans[[2]]))
    }
    if (!(length(x.values) == 0 || length(x.values) == length(y.values))) {
      stop("Consistency error.")
    }
    new("performance",
        x.name = get(x.measure, envir = long.unit.names),
        y.name = get(measure, envir = long.unit.names),
        alpha.name = "none",
        x.values = x.values, y.values = y.values, alpha.values = list())
  } else {
    # Non-cutoff x axis: compute both measures against cutoff, then combine.
    perf.obj.1 <- performance(prediction.obj, measure = x.measure, ...)
    perf.obj.2 <- performance(prediction.obj, measure = measure, ...)
    ROCR:::.combine.performance.objects(perf.obj.1, perf.obj.2)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment