@bubbobne
Last active August 10, 2020 09:26
library(tidyverse)
library(caret)
library(brms)  # provides the predict() method for fitted brm models
#' Predict values from an ordinal brm regression model.
#'
#' @param model the fitted model.
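#' @param data a data frame with all explanatory variables and the true labels.
#' @param labeled_column the name of the column in data that holds the true labels.
#' @return a data frame with, for each item, the most probable label (predict_label),
#'   its probability (value), the true label and whether the two match (is_ok).
#'   The caret confusion matrix is printed as a side effect.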
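#' @examples
#' \dontrun{
#' # hypothetical model, data sets and column name
#' fit <- brm(rating ~ x, data = train, family = cumulative())
#' preds <- getPredictions(fit, test, "rating")
#' }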
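getPredictions <- function(model, data, labeled_column) {
  # Class probabilities for each row of data, ignoring group-level effects.
  output <- predict(model, newdata = data, re_formula = NA)
  # Columns are named "P(Y = <label>)"; keep only the label.
  labels <- trimws(gsub('P\\(Y =', "", gsub('\\)', '', colnames(output))))
  output <- as.data.frame(output)
  colnames(output) <- labels
  # Keep the most probable label per item; ties.method = "first" keeps one
  # label even when several probabilities tie for the maximum.
  output <- output %>%
    rowid_to_column("rowname") %>%
    pivot_longer(-rowname, names_to = "predict_label", values_to = "value") %>%
    group_by(rowname) %>%
    filter(rank(-value, ties.method = "first") == 1) %>%
    ungroup()
  # predict() returns rows in the order of newdata, so match rows by position.
  data <- data %>% rowid_to_column("rowname") %>% select(rowname, all_of(labeled_column))
  output <- output %>% inner_join(data, by = "rowname", na_matches = "never")
  truth <- as.character(output[[labeled_column]])
  output$is_ok <- output$predict_label == truth
  # confusionMatrix() needs two factors with identical levels.
  confusion_matrix <- confusionMatrix(
    factor(output$predict_label, levels = labels),
    factor(truth, levels = labels)
  )
  print(confusion_matrix)
  return(output)
}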
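The core steps of getPredictions can be checked without fitting a model. The sketch below is base R only and rests on assumptions: probs is a hypothetical matrix standing in for the brms predict() output (assumed, as the gist's parsing code assumes, to have columns named "P(Y = <label>)"), truth holds made-up true labels, and max.col() plus table() replace the tidyverse pipeline and caret::confusionMatrix().

```r
# Hypothetical stand-in for predict(model, newdata = data, re_formula = NA):
# one row per item, one probability column per label.
probs <- matrix(
  c(0.7, 0.2, 0.1,
    0.1, 0.3, 0.6,
    0.2, 0.5, 0.3),
  nrow = 3, byrow = TRUE,
  dimnames = list(NULL, c("P(Y = 1)", "P(Y = 2)", "P(Y = 3)"))
)
truth <- c("1", "3", "3")  # hypothetical true labels

# Strip "P(Y =" and ")" to recover the bare labels, as getPredictions does.
labels <- trimws(gsub("P\\(Y =", "", gsub("\\)", "", colnames(probs))))

# Column of the largest probability in each row; "first" resolves ties by column order.
predict_label <- labels[max.col(probs, ties.method = "first")]
is_ok <- predict_label == truth

# Confusion table with shared levels, so labels that never occur still get a row/column.
confusion <- table(
  predicted = factor(predict_label, levels = labels),
  reference = factor(truth, levels = labels)
)
accuracy <- sum(diag(confusion)) / sum(confusion)
print(confusion)
print(accuracy)
```

caret::confusionMatrix() reports the same table plus accuracy, kappa and per-class statistics; the base version only makes each step easy to inspect.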
@bubbobne
Author

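  • filter(rank(-value) == 1) returns no row when several labels share the maximum value: by default rank() gives tied values their average rank (e.g. 1.5), so none of them equals 1.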
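A small base-R illustration of the tie problem, using a hypothetical probability vector for a single item in which labels "1" and "2" tie at 0.4:

```r
# Hypothetical probabilities for one item: labels "1" and "2" tie for the maximum.
value <- c(`1` = 0.4, `2` = 0.4, `3` = 0.2)

# Default ties.method = "average": both maxima get rank 1.5, so nothing equals 1.
avg_rank <- rank(-value)
n_kept <- sum(avg_rank == 1)

# ties.method = "first" breaks the tie by position, so exactly one label has rank 1.
first_rank <- rank(-value, ties.method = "first")
top_label <- names(value)[first_rank == 1]

print(avg_rank)   # values 1.5 1.5 3.0
print(n_kept)     # 0
print(top_label)  # "1"
```

Inside the grouped pipeline the same argument works as filter(rank(-value, ties.method = "first") == 1); with dplyr 1.0.0 or later, slice_max(value, n = 1, with_ties = FALSE) does the same per group. Either way the first label in column order wins, which hides the tie, so count ties beforehand if they matter.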
