Skip to content

Instantly share code, notes, and snippets.

@slopp
Created August 11, 2016 06:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save slopp/82272f00993c28249816a0024f0d60e6 to your computer and use it in GitHub Desktop.
Save slopp/82272f00993c28249816a0024f0d60e6 to your computer and use it in GitHub Desktop.
#' Spark ML - Binary Classifier Area under ROC
#'
#' Evaluates the area under the ROC curve of a binary classifier with Spark's
#' \code{BinaryClassificationEvaluator}.
#'
#' @param predicted_tbl_spark The result of running sdf_predict
#' @param label A character string specifying which column contains the true,
#'   indexed labels (0 / 1)
#' @param score A character string specifying which column contains the
#'   scored probability of a 1
#'
#' @return The area under the ROC curve, as returned by the evaluator's
#'   \code{evaluate} method.
#' @export
#'
ml_auc_roc <- function(predicted_tbl_spark, label, score) {
  if (!is.character(label) || length(label) != 1L) {
    stop("`label` must be a single column name (character string).",
         call. = FALSE)
  }
  if (!is.character(score) || length(score) != 1L) {
    stop("`score` must be a single column name (character string).",
         call. = FALSE)
  }

  df <- spark_dataframe(predicted_tbl_spark)
  sc <- spark_connection(df)

  # ml_prepare_dataframe() is expected to store the names of the response and
  # assembled-features columns it produces in `envir`; they are read back below.
  envir <- new.env(parent = emptyenv())
  # The score column is passed twice, so the assembled features column is the
  # two-element vector c(score, score).
  # NOTE(review): this presumably relies on Spark's evaluator reading element 1
  # (0-based) of a vector raw-prediction column -- confirm for the Spark
  # version in use.
  tdf <- df %>%
    ml_prepare_dataframe(
      response = label,
      features = c(score, score),
      envir = envir
    )

  invoke_new(sc, "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator") %>%
    invoke("setLabelCol", envir$response) %>%
    invoke("setRawPredictionCol", envir$features) %>%
    invoke("setMetricName", "areaUnderROC") %>%
    invoke("evaluate", tdf)
}
#' Spark ML - Classifier Accuracy
#'
#' Evaluates the accuracy of a binary or multi-class classifier with Spark's
#' \code{MulticlassClassificationEvaluator}.
#'
#' @param predicted_tbl_spark The result of running sdf_predict
#' @param label A string specifying the column that contains the true, indexed
#'   label. Supports binary and multi-class labels; the column should be of
#'   double type (use as.double)
#' @param predicted_lbl A string specifying the column that contains the
#'   predicted label, NOT the scored probability. Supports binary and
#'   multi-class labels; the column should be of double type (use as.double)
#'
#' @return The value of Spark's \code{"accuracy"} metric, as returned by the
#'   evaluator's \code{evaluate} method.
#' @export
#'
ml_accuracy <- function(predicted_tbl_spark, label, predicted_lbl) {
  if (!is.character(label) || length(label) != 1L) {
    stop("`label` must be a single column name (character string).",
         call. = FALSE)
  }
  if (!is.character(predicted_lbl) || length(predicted_lbl) != 1L) {
    stop("`predicted_lbl` must be a single column name (character string).",
         call. = FALSE)
  }

  df <- spark_dataframe(predicted_tbl_spark)
  sc <- spark_connection(df)

  # Unlike the ROC evaluator, this one reads the label and prediction columns
  # directly, so no feature assembly is needed.
  invoke_new(sc, "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator") %>%
    invoke("setLabelCol", label) %>%
    invoke("setPredictionCol", predicted_lbl) %>%
    invoke("setMetricName", "accuracy") %>%
    invoke("evaluate", df)
}
#' Spark ML - Feature Importance for Tree Models
#'
#' Extracts the feature importances Spark computed for a fitted tree-based
#' model and pairs them with the model's feature names.
#'
#' @param model An ml_model object, support for decision trees (>1.5.0),
#'   random forest (>2.0.0), GBT (>2.0.0)
#'
#' @return A data frame with a numeric \code{importance} column and a character
#'   \code{feature} column, sorted by decreasing importance (ties keep their
#'   original order; NA importances go last).
#' @export
#'
ml_tree_feature_importance <- function(model) {
  supported <- c("ml_model_gradient_boosted_trees",
                 "ml_model_decision_tree",
                 "ml_model_random_forest")
  if (!inherits(model, supported)) {
    stop("Supported models include: ", paste(supported, collapse = ", "))
  }

  # Random forest / GBT importances need Spark >= 2.0.0. The connection is taken
  # from the model's own Java object rather than from a global `sc`, and
  # inherits() avoids the length > 1 condition that class(model) != ... yields
  # for multi-class S3 objects.
  if (!inherits(model, "ml_model_decision_tree")) {
    spark_require_version(spark_connection(model$.model), "2.0.0")
  }

  importances_jobj <- invoke(model$.model, "featureImportances")
  # unlist() tolerates the array being returned as a list instead of an atomic
  # vector; as.numeric() keeps the column numeric so sorting is numeric, not
  # lexicographic (cbind() with the feature names used to coerce it to text).
  scores <- as.numeric(unlist(invoke(importances_jobj, "toArray")))
  features <- as.character(model$features)
  if (length(scores) != length(features)) {
    stop("Number of feature importances (", length(scores),
         ") does not match the number of model features (",
         length(features), ").", call. = FALSE)
  }

  importance <- data.frame(
    importance = scores,
    feature = features,
    stringsAsFactors = FALSE
  )
  # order(-x) is a stable descending sort with NA placed last.
  importance <- importance[order(-importance$importance), , drop = FALSE]
  rownames(importance) <- NULL
  importance
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment