Skip to content

Instantly share code, notes, and snippets.

@thinhdanggroup
Created March 31, 2018 10:31
Show Gist options
  • Save thinhdanggroup/5e659d8f473b52f27d0e48cd288d24a5 to your computer and use it in GitHub Desktop.
Save thinhdanggroup/5e659d8f473b52f27d0e48cd288d24a5 to your computer and use it in GitHub Desktop.
LCard (label-cardinality) thresholding algorithm
library(ggplot2)
library(utiml)
# Prints a fixed greeting to stdout. It is unrelated to the thresholding code
# below (NOTE(review): looks like leftover template boilerplate — confirm it
# can be removed).
print("Hello, world!")
# Fixed threshold for a matrix or data.frame of label scores.
#
# Turns scores into a 0/1 bipartition by comparing every column (label)
# against its own threshold: a cell becomes 1L when score >= threshold.
#
# prediction:  matrix or data.frame of scores, one column per label.
# threshold:   a single value applied to every label, or a numeric vector
#              with exactly one threshold per column. Defaults to 0.5.
# probability: forwarded unchanged to multilabel_prediction().
#
# Returns the result of multilabel_prediction() built from the integer
# bipartition matrix (carrying prediction's dimnames) and the original scores.
# Stops when `threshold` has neither length 1 nor length ncol(prediction).
fixed_threshold.default <- function(prediction, threshold = 0.5,
                                    probability = FALSE) {
  n_labels <- ncol(prediction)
  if (length(threshold) == 1) {
    threshold <- rep(threshold, n_labels)
  } else if (length(threshold) != n_labels) {
    stop(paste("The threshold values must be a single value or the same",
               "number of labels"))
  }

  # Compare the whole score matrix in one vectorized step. Repeating each
  # threshold nrow() times aligns threshold[j] with every cell of column j
  # (R matrices are column-major). Unlike a seq(ncol()) loop, this also
  # handles zero-column input, where seq(0) would yield c(1, 0).
  scores <- as.matrix(prediction)
  bipartition <- matrix(
    as.integer(scores >= rep(threshold, each = nrow(scores))),
    nrow = nrow(scores),
    ncol = n_labels
  )
  dimnames(bipartition) <- dimnames(prediction)
  multilabel_prediction(bipartition, prediction, probability)
}
#' @describeIn fixed_threshold Fixed Threshold for mlresult
#' @export
fixed_threshold.mlresult <- function(prediction, threshold = 0.5,
                                     probability = FALSE) {
  # Unwrap the mlresult into its score matrix, then let the matrix/data.frame
  # method do the actual thresholding with the same arguments.
  scores <- as.probability(prediction)
  fixed_threshold.default(prediction = scores, threshold = threshold,
                          probability = probability)
}
# S3 generic for the label-cardinality threshold. Dispatch is on the class of
# `prediction`; this file defines the .default (matrix/data.frame) and
# .mlresult methods.
lcard_thresholdk <- function(prediction, cardinality, probability = FALSE) {
  UseMethod("lcard_thresholdk", prediction)
}
#' @describeIn lcard_thresholdk Cardinality Threshold for matrix or data.frame
#' @export
lcard_thresholdk.default <- function(prediction, cardinality,
                                     probability = FALSE) {
  # Chooses the single global threshold whose resulting mean number of
  # selected labels per row is closest to `cardinality`. It then applies that
  # threshold to every label via fixed_threshold.default().
  #
  # prediction:  matrix or data.frame of scores, one column per label.
  # cardinality: target mean number of labels per row.
  # probability: forwarded unchanged to fixed_threshold.default().
  scores <- as.matrix(prediction)

  # Every distinct score is a candidate threshold. as.vector(as.matrix())
  # flattens data.frames too: c() on a data.frame returns a list, which
  # sort() rejects. sort() also drops NA.
  thresholds <- sort(unique(as.vector(scores)))
  if (length(thresholds) == 0) {
    stop("The prediction has no non-missing scores to choose a threshold")
  }

  # Mean number of labels per row selected by each candidate threshold.
  # vapply() guarantees a numeric vector, even for unusual input.
  label_cardinality <- vapply(thresholds, function(ts) {
    mean(rowSums(scores >= ts))
  }, numeric(1))

  # which.min() returns the first minimum. Because thresholds are ascending,
  # ties resolve to the smaller threshold (more labels). NA entries, which
  # come from rows that contain NA scores, are skipped by which.min().
  best <- which.min(abs(cardinality - label_cardinality))
  fixed_threshold.default(prediction, thresholds[best], probability)
}
#' @describeIn lcard_thresholdk Cardinality Threshold for mlresult
#' @export
lcard_thresholdk.mlresult <- function(prediction, cardinality,
                                      probability = FALSE) {
  # Unwrap the mlresult into its score matrix and reuse the
  # matrix/data.frame method with the same target cardinality.
  scores <- as.probability(prediction)
  lcard_thresholdk.default(prediction = scores, cardinality = cardinality,
                           probability = probability)
}
# Demo: threshold a random 4 x 4 score matrix so that, on average, about 2.1
# labels per row are selected. set.seed() makes the run reproducible.
# print() is explicit because results are not auto-printed under source().
set.seed(2018)
prediction <- matrix(runif(16), ncol = 4)
print(prediction)
print(lcard_thresholdk(prediction, 2.1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment