Skip to content

Instantly share code, notes, and snippets.

@trinker
Last active July 28, 2021 09:25
Show Gist options
  • Save trinker/594bd132b180a43945f7 to your computer and use it in GitHub Desktop.
Save trinker/594bd132b180a43945f7 to your computer and use it in GitHub Desktop.
Find the optimal number of topics in a topic model using the harmonic mean of the log likelihood
#' Find Optimal Number of Topics
#'
#' Iteratively fits models and then compares the harmonic means of their log likelihoods in a graphical output.
#'
#' @param x A \code{\link[tm]{DocumentTermMatrix}}.
#' @param max.k Maximum number of topics to fit (start small [i.e., default of 30] and add as necessary).
#' @param burnin Object of class \code{"integer"}; number of omitted Gibbs iterations at beginning, by default equals 1000.
#' @param iter Object of class \code{"integer"}; number of Gibbs iterations, by default equals 1000.
#' @param keep Object of class \code{"integer"}; if a positive integer, the log-likelihood is saved every keep iterations.
#' @param method The method to be used for fitting; currently \code{method = "VEM"} or \code{method= "Gibbs"} are supported.
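#' @param verbose logical.  If \code{TRUE} the optimal number of topics is printed to the console.
#' @param \ldots Other arguments passed to \code{??LDAcontrol}.
#' @return Returns an integer of class \code{optimal_k} (the number of topics with the largest harmonic mean) carrying a \code{k_dataframe} attribute: a \code{\link[base]{data.frame}} of k (number of topics) and the associated harmonic mean of the log likelihood.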
#' @references \url{http://stackoverflow.com/a/21394092/1000343} \cr
#' Ponweiser, M. (2012). Latent Dirichlet Allocation in R (Diploma Thesis). Vienna University of
#' Economics and Business, Vienna. http://cran.r-project.org/web/packages/topicmodels/vignettes/topicmodels.pdf
#' @keywords k topicmodel
#' @export
#' @author Ben Marwick and Tyler Rinker <tyler.rinker@@gmail.com>.
#' @examples
#' ## Install/Load Tools & Data
#' if (!require("pacman")) install.packages("pacman")
#' pacman::p_load_gh("trinker/gofastr")
#' pacman::p_load(tm, topicmodels, dplyr, tidyr, devtools, LDAvis, ggplot2)
#'
#'
#' ## Source topicmodels2LDAvis function
#' devtools::source_url("https://gist.githubusercontent.com/trinker/477d7ae65ff6ca73cace/raw/79dbc9d64b17c3c8befde2436fdeb8ec2124b07b/topicmodels2LDAvis")
#'
#' data(presidential_debates_2012)
#'
#'
#' ## Generate Stopwords
#' stops <- c(
#' tm::stopwords("english"),
#' "governor", "president", "mister", "obama","romney"
#' ) %>%
#' gofastr::prep_stopwords()
#'
#'
#' ## Create the DocumentTermMatrix
#' doc_term_mat <- presidential_debates_2012 %>%
#' with(gofastr::q_dtm_stem(dialogue, paste(person, time, sep = "_"))) %>%
#' gofastr::remove_stopwords(stops) %>%
#' gofastr::filter_tf_idf() %>%
#' gofastr::filter_documents()
#'
#'
#' opti_k <- optimal_k(doc_term_mat)
#' opti_k
optimal_k <- function(x, max.k = 30, burnin = 1000, iter = 1000, keep = 50,
                      method = "Gibbs", verbose = TRUE, ...) {
  # Fits one LDA model for every k in 2:max.k, scores each model by the
  # harmonic mean of its retained log likelihoods, and returns the k with the
  # largest score (class "optimal_k", with the per-k scores attached as the
  # "k_dataframe" attribute).

  # 2:max.k would silently count downwards for max.k < 2, so reject it early.
  stopifnot(is.numeric(max.k), length(max.k) == 1L, !is.na(max.k), max.k >= 2)

  if (max.k > 20) {
    message("\nGrab a cup of coffee this is gonna take a while...\n")
    utils::flush.console()
  }

  ks <- 2:max.k

  # Extra settings supplied through `...` are forwarded in the control list.
  # NOTE(review): burnin/iter/keep are passed regardless of `method`; whether
  # method = "VEM" accepts these control entries is not visible here - confirm.
  control <- c(list(burnin = burnin, iter = iter, keep = keep), list(...))

  # Number of leading saved log likelihoods that belong to the burn-in phase
  # (one value is saved every `keep` iterations).
  n_burnin_saved <- floor(burnin / keep)

  hm_many <- vapply(ks, function(k) {
    fitted <- topicmodels::LDA(x, k = k, method = method, control = control)
    logLiks <- fitted@logLiks
    # A logical mask keeps every value when there is nothing to drop; the
    # negative-index form x[-(1:0)] would wrongly discard the first value.
    logLiks <- logLiks[seq_along(logLiks) > n_burnin_saved]
    harmonicMean(logLiks)
  }, numeric(1))

  out <- ks[which.max(hm_many)]
  class(out) <- c("optimal_k", class(out))
  attr(out, "k_dataframe") <- data.frame(
    k = ks,
    harmonic_mean = hm_many
  )

  if (isTRUE(verbose)) {
    cat(sprintf("Optimal number of topics = %s\n", as.numeric(out)))
  }
  out
}
harmonicMean <- function(logLikelihoods, precision = 2000L) {
  # Harmonic mean of a vector of log likelihoods, evaluated with Rmpfr
  # multiple-precision numbers (`precision` is handed to mpfr() as `prec`)
  # and returned as an ordinary double.

  # The median is used as an offset: it is added inside exp() and added back
  # outside the log (presumably to keep the exponentials in a workable range).
  ll_median <- Rmpfr::median(logLikelihoods)
  ll_hp <- Rmpfr::mpfr(logLikelihoods, prec = precision)
  mean_inverse <- Rmpfr::mean(exp(-ll_hp + ll_median))
  as.double(ll_median - log(mean_inverse))
}
#' Plots an optimal_k Object
#'
#' Plots an optimal_k object
#'
#' @param x A \code{optimal_k} object.
#' @param \ldots Ignored.
#' @method plot optimal_k
#' @export
plot.optimal_k <- function(x, ...) {
  # Draws the harmonic mean of the log likelihood against the number of
  # topics as a line, and circles the row whose k equals the value of `x`.
  # Returns the ggplot object; `...` is ignored.
  k_df <- attr(x, "k_dataframe", exact = TRUE)
  best <- k_df[k_df[["k"]] == as.numeric(x), ]

  # Every layer is namespace-qualified so the method also works when ggplot2
  # is installed but not attached.
  ggplot2::ggplot(k_df, ggplot2::aes_string(x = "k", y = "harmonic_mean")) +
    ggplot2::xlab("Number of Topics") +
    ggplot2::ylab("Harmonic Mean of Log Likelihood") +
    ggplot2::geom_point(
      data = best, color = "blue", fill = NA, size = 6, shape = 21
    ) +
    ggplot2::geom_line(size = 1) +
    ggplot2::theme_bw() +
    ggplot2::theme(
      axis.title.x = ggplot2::element_text(vjust = -0.25, size = 14),
      axis.title.y = ggplot2::element_text(size = 14, angle = 90)
    )
}
#' Prints an optimal_k Object
#'
#' Prints an optimal_k object by drawing its plot.
#'
#' @param x A \code{optimal_k} object.
#' @param \ldots Ignored.
#' @method print optimal_k
#' @export
print.optimal_k <- function(x, ...) {
  # Printing an optimal_k object shows its plot: build the plot through the
  # plot() generic, then print the result. `...` is ignored.
  k_plot <- graphics::plot(x)
  print(k_plot)
}
@skyocean
Copy link

Hi I have a question. If the method is 'VEM', how do you calculate the harmonic mean? Because the iteration is different. Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment