@brshallo
Last active December 30, 2022 18:25
Similar to yardstick::conf_mat(), but can handle observation weights.
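To illustrate what handling weights means here: each cell of a weighted confusion matrix holds the sum of the weights of the observations that fall into it, not a count of them. The following is a minimal base R sketch of that idea using stats::xtabs() in place of the dplyr code. The data frame toy and its columns are hypothetical and are not part of the gist.

```r
# Hypothetical example data: true class, predicted class, observation weight.
# Both class columns are factors with identical levels.
toy <- data.frame(
  truth    = factor(c("a", "a", "b", "b", "b"), levels = c("a", "b")),
  estimate = factor(c("a", "b", "b", "b", "a"), levels = c("a", "b")),
  w        = c(1, 2, 0.5, 1.5, 3)
)

# Each cell sums the weights of the rows in that (Prediction, Truth) cell;
# rows are predictions and columns are true classes.
weighted_tab <- xtabs(w ~ estimate + truth, data = toy)
names(dimnames(weighted_tab)) <- c("Prediction", "Truth")
weighted_tab
```

The cells add up to sum(toy$w) rather than to nrow(toy). The scale_weights_one argument below exists to undo that.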
library(dplyr)
#' Confusion Matrix With Observation Weights
#'
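#' @param df A data frame.
#' @param truth Column with the true classes (a factor).
#' @param estimate Column with the predicted classes (a factor with the same levels as truth).
#' @param wt Optional column with observation weights. If NULL, every row has weight 1.
#' @param scale_weights_one Whether to rescale the weighted counts so that the cells of the confusion matrix sum to the number of rows in df. Defaults to TRUE when wt is supplied.
#' @param dnn Character vector of dimnames for the table (rows, then columns).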
#'
#' @return A yardstick conf_mat object.
conf_mat_weighted <- function(df,
                              truth,
                              estimate,
                              wt = NULL,
                              scale_weights_one = !rlang::quo_is_null(enquo(wt)),
                              dnn = c("Prediction", "Truth")) {
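  # count() returns one row per (truth, estimate) pair, sorted by truth and
  # then by estimate; with factor columns and .drop = FALSE, empty pairs are
  # kept with n = 0. Column 1 therefore holds the true classes and column 2
  # holds the predictions.
  freq_df <- count(df, {{ truth }}, {{ estimate }}, wt = {{ wt }}, .drop = FALSE)
  truth_vals <- unique(freq_df[[1]])
  pred_vals <- unique(freq_df[[2]])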
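  # Fill column by column: each block of n values (one true class, all
  # predictions) becomes a column, so rows = predictions, columns = truth
  output <- matrix(freq_df$n,
                   nrow = length(pred_vals),
                   byrow = FALSE,
                   dimnames =
                     list(pred_vals, truth_vals) %>%
                     purrr::set_names(dnn)
  )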
  # Optionally rescale so the cells sum to nrow(df): multiply by n / sum(wt)
  if (scale_weights_one) {
    scale_factor <- summarise(df, scale_factor = n() / sum({{ wt }})) %>%
      pull(scale_factor)
    output <- output * scale_factor
  }
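  # Use the exported conf_mat() generic on a table instead of the internal
  # yardstick:::conf_mat.table()
  yardstick::conf_mat(as.table(output))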
}
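When scale_weights_one is TRUE (the default whenever wt is supplied), every cell is multiplied by n / sum(wt). The matrix then totals the number of rows in df, and the relative weights are unchanged. The following base R sketch works through that arithmetic. The weights and cell values are hypothetical, chosen so the cells are what the weights 1, 2, 0.5, 1.5 and 3 would produce. The sketch does not call the gist's function.

```r
# Hypothetical observation weights and the weighted cells they produce
# (Prediction x Truth); the cells sum to sum(w) = 8
w <- c(1, 2, 0.5, 1.5, 3)
weighted_cells <- matrix(
  c(1, 2, 3, 2), nrow = 2,
  dimnames = list(Prediction = c("a", "b"), Truth = c("a", "b"))
)

# Same formula as the scale_weights_one step: scale_factor = n / sum(wt)
scale_factor <- length(w) / sum(w)   # 5 / 8 = 0.625
scaled_cells <- weighted_cells * scale_factor

sum(scaled_cells)                    # 5, the number of observations

# Ratios such as accuracy are unaffected by the rescaling
accuracy_weighted <- sum(diag(weighted_cells)) / sum(weighted_cells)
accuracy_scaled   <- sum(diag(scaled_cells)) / sum(scaled_cells)
```

Because the scale factor is a single constant, proportion-based metrics are unaffected. Only metrics that depend on the absolute cell counts change when the weights are rescaled.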