Skip to content

Instantly share code, notes, and snippets.

@jmclawson
Created May 4, 2023 13:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jmclawson/640042f2d679bcef1d20cf8056a66acd to your computer and use it in GitHub Desktop.
Save jmclawson/640042f2d679bcef1d20cf8056a66acd to your computer and use it in GitHub Desktop.
Functions for building a topic model and exploring it. Visualizations include document-level distributions (static and interactive), word distributions per topic, and topic word clouds.
library(wordcloud)
library(topicmodels)
library(plotly)
# Moves a table of texts through the steps of preparation needed to build a
# topic model, then builds the model. The steps, in order:
# 1. unnests the text table into one word per row with
#    `unnest_without_caps()` (defined elsewhere; per the original notes it
#    also drops proper nouns, i.e. words that only ever appear capitalized)
# 2. divides each text, identified by the `doc_id` column, into consecutive
#    chunks of `sample_size` words (default 1000) and renames each chunk
#    "<doc_id>_<chunk number>"
# 3. removes stop words (`get_stopwords()`); this happens after chunking, so
#    each chunk keeps at most `sample_size` words
# 4. counts word frequencies for each chunk
# 5. converts the table of frequencies into a document term matrix
# 6. builds a Gibbs-sampled LDA topic model with `k` topics
#
# Arguments:
#   df           table of texts with a `doc_id` column
#   doc_id       column (unquoted) identifying each text; default `title`
#   sample_size  number of words per chunk
#   k            number of topics
#   seed         optional random seed passed to `LDA()` for reproducible
#                models; NULL (default) leaves the seed unset as before
# Returns the fitted `LDA()` model.
make_topic_model <- function(
    df,
    doc_id = title,
    sample_size = 1000,
    k = 15,
    seed = NULL) {
  # Numbers consecutive runs of `size` rows (words) within each document
  # and, when `collapse_cols` is TRUE, renames each run "<doc_id>_<n>".
  set_doc_samples <- function(
      df,
      size = 1000,
      doc_id = title,
      set_min = NULL,
      collapse_cols = TRUE) {
    df <- df |>
      group_by({{ doc_id }}) |>
      mutate(set_id = ceiling(row_number() / size)) |>
      ungroup()
    if (!is.null(set_min)) {
      # NOTE(review): `n()` counts rows (words) per document, not chunks,
      # so this keeps documents with more than `set_min` words.
      df <- df |>
        group_by({{ doc_id }}) |>
        mutate(set_count = n()) |>
        filter(set_count > set_min) |>
        ungroup()
    }
    if (collapse_cols) {
      df <- df |>
        unite({{ doc_id }}, {{ doc_id }}, set_id)
    }
    df
  }

  lda_control <- list(best = TRUE, initialize = "random")
  if (!is.null(seed)) {
    lda_control$seed <- seed
  }

  df |>
    unnest_without_caps() |>
    set_doc_samples(doc_id = {{ doc_id }}, size = sample_size) |>
    anti_join(get_stopwords(), by = "word") |>
    count({{ doc_id }}, word, sort = TRUE) |>
    rename(
      document = {{ doc_id }},
      term = word,
      value = n) |>
    cast_dtm(document, term, value) |>
    LDA(k = k,
        method = "Gibbs",
        control = lda_control)
}
# Plots each document's most prominent topics as stacked areas across its
# chunks (one facet per document), optionally titling and saving the plot.
#
# Arguments:
#   lda           fitted topic model (e.g. from `make_topic_model()`)
#   top_n         number of topics to keep per document
#   direct_label  label topics at the right edge instead of showing a legend
#   title         use the deparsed `lda` expression as the plot title
#   save          write the plot to `savedir`
#   saveas        file extension handed to `ggsave()` ("png", "pdf", ...)
#   savedir       output directory; created if missing
#   omit          topic numbers to leave out
#   smooth        smooth each topic over neighbouring chunks
# Returns the ggplot object.
visualize_document_topics <- function(
    lda,
    top_n = 4,
    direct_label = TRUE,
    title = TRUE,
    save = TRUE,
    saveas = "png",
    savedir = "plots",
    omit = NULL,
    smooth = TRUE) {
  # deparse1() always yields a single string, even for long expressions,
  # so the title and file name cannot split into several pieces.
  df_string <- deparse1(substitute(lda))

  plot_topic_parts <- function(df, direct_label = TRUE) {
    plot <- df |>
      ggplot(aes(x = set, y = n))
    if (direct_label) {
      # Labels sit to the right of each document's last chunk, at the
      # cumulative height of their topic, staggered horizontally.
      label_df <- df |>
        group_by(doc_id) |>
        filter(set == max(set)) |>
        arrange(desc(topic)) |>
        mutate(n = cumsum(n),
               set = set + (3000 * (row_number() - 1)))
      plot <- plot +
        geom_area(aes(fill = as.factor(topic)),
                  show.legend = FALSE) +
        geom_text(
          data = label_df,
          aes(x = set + 800,
              label = topic,
              color = as.factor(topic)),
          show.legend = FALSE,
          hjust = 0
        )
    } else {
      plot <- plot +
        geom_area(aes(fill = as.factor(topic)),
                  show.legend = TRUE)
    }
    plot +
      facet_wrap(~ doc_id,
                 strip.position = "top",
                 ncol = 1) +
      scale_x_continuous(expand = expansion(c(0, 0.1)),
                         labels = scales::label_comma()) +
      theme_minimal() +
      labs(y = element_blank(),
           x = "words",
           fill = "topic") +
      theme(plot.title.position = "plot",
            strip.background = element_rect(fill = NA, color = NA),
            strip.text = element_text(colour = "black",
                                      hjust = 0),
            panel.grid.major.x = element_blank(),
            panel.grid.minor = element_blank()) +
      scale_fill_viridis_d(option = "turbo") +
      scale_color_viridis_d(option = "turbo") +
      scale_y_continuous(
        position = "right",
        labels = scales::label_percent()) +
      coord_cartesian(clip = "off")
  }

  plot <- lda |>
    prep_document_topics(top_n = top_n,
                         omit = omit,
                         smooth = smooth) |>
    plot_topic_parts(direct_label = direct_label)

  if (title) {
    plot <- plot + ggtitle(df_string)
  }

  if (save) {
    if (!dir.exists(savedir)) {
      dir.create(savedir, recursive = TRUE)
    }
    filename <- file.path(
      savedir,
      paste0(df_string, " - document topics", ".", saveas)
    )
    # theme_minimal() leaves the plot background blank, so force a white
    # background for every format rather than risk transparent output.
    ggsave(filename,
           plot = plot,
           dpi = 300,
           bg = "white")
  }
  plot
}
# Turns a fitted topic model into a long table of per-chunk topic weights
# for plotting.
#
# Document ids are expected to look like "<title>_<chunk number>" (as built
# by `make_topic_model()`). For every title, the `top_n` topics with the
# highest mean gamma across its chunks are kept, after dropping any topics
# listed in `omit`.
#
# Columns of the result (one row per document chunk x kept topic):
#   doc_id      document title
#   set         chunk end position in words (chunk number * `sample_size`)
#   topic       topic number
#   n           gamma of the topic in that chunk; with `smooth = TRUE`, the
#               mean over this chunk and the next two chunks of the same
#               document
#   percentage  share of `n` among the kept topics of the same chunk
#               (computed before smoothing)
#   words       the topic's ten highest-beta terms
#   display     "topic <k>: <words>", used for tooltips
# `sample_size` should match the value used to build the model.
# Returns an ungrouped tibble.
prep_document_topics <- function(
    lda,
    top_n = 4,
    omit = NULL,
    smooth = FALSE,
    sample_size = 1000) {
  doc_tops <- lda |>
    tidy(matrix = "gamma") |>
    # Split on the LAST underscore only, so titles containing "_" stay whole.
    mutate(
      title = sub("_[^_]*$", "", document),
      set = as.integer(sub("^.*_", "", document))
    ) |>
    select(-document) |>
    group_by(title, topic) |>
    mutate(topic_mean = mean(gamma, na.rm = TRUE)) |>
    ungroup()

  if (!is.null(omit)) {
    doc_tops <- doc_tops |>
      filter(!topic %in% omit)
  }

  doc_tops <- doc_tops |>
    group_by(title) |>
    # dense_rank() gives tied topics the same rank, so ties may keep more
    # than `top_n` topics.
    mutate(topic_rank = dense_rank(-topic_mean)) |>
    ungroup() |>
    filter(topic_rank <= top_n) |>
    mutate(doc_id = title) |>
    group_by(doc_id, set, topic) |>
    summarise(n = sum(gamma, na.rm = TRUE), .groups = "drop_last") |>
    # Still grouped by doc_id and set here, so this is each topic's share
    # of the kept topics within a single chunk.
    mutate(percentage = n / sum(n)) |>
    ungroup() |>
    mutate(set = set * sample_size)

  top_terms <- lda |>
    tidy() |>
    group_by(topic) |>
    arrange(desc(beta)) |>
    slice_head(n = 10) |>
    summarize(words = paste0(term, collapse = ", "))

  result <- doc_tops |>
    left_join(top_terms, by = "topic") |>
    mutate(display = paste0("topic ", topic, ": ", words))

  if (smooth) {
    # Forward rolling mean over this chunk and the next two. It is computed
    # within each document and topic, so chunks of different documents
    # never mix (arrange() needs .by_group to sort inside groups).
    result <- result |>
      group_by(doc_id, topic) |>
      arrange(set, .by_group = TRUE) |>
      mutate(n = rowMeans(cbind(n, lead(n), lead(n, n = 2L)),
                          na.rm = TRUE)) |>
      ungroup()
  }

  result
}
# Interactive (plotly) version of the document-topic area plot: one facet
# row per document, with hover tooltips showing each topic's top words.
#
# Arguments:
#   df      fitted topic model
#   top_n   number of topics to keep per document
#   title   use the deparsed `df` expression as the plot title
#   height  plot height passed to `ggplotly()`
#   omit    topic numbers to leave out
#   smooth  smooth each topic over neighbouring chunks
# Returns a plotly object.
interactive_document_topics <- function(
    df,
    top_n = 4,
    title = FALSE,
    height = NULL,
    omit = NULL,
    smooth = TRUE) {
  df_string <- deparse1(substitute(df))

  # Shortens each title to its first word, after dropping a leading
  # article or preposition. A title whose short form is missing, or is
  # shared with another title, keeps its full name, so two documents can
  # never end up in the same facet.
  shorten_titles <- function(titles) {
    full <- unique(titles)
    short <- full |>
      str_remove_all("^The\\b|^A\\b|^In the\\b|^In\\b|^To\\b") |>
      str_remove_all("^ ") |>
      strsplit(split = " ") |>
      vapply(`[`, character(1), 1) |>
      str_extract("[A-Za-z]+")
    clash <- is.na(short) | short %in% short[duplicated(short)]
    short[clash] <- full[clash]
    short[match(titles, full)]
  }

  plot <- df |>
    prep_document_topics(top_n = top_n, omit = omit, smooth = smooth) |>
    # Shorten across all documents at once, not per group.
    ungroup() |>
    mutate(
      doc_id = shorten_titles(doc_id),
      topic = as.factor(topic)) |>
    ggplot(aes(x = set, y = n)) +
    geom_area(aes(fill = topic,
                  color = topic,
                  text = display),
              show.legend = FALSE) +
    facet_grid(doc_id ~ .) +
    scale_x_continuous(expand = expansion(c(0, 0.1)),
                       labels = scales::label_comma()) +
    theme_minimal() +
    labs(y = element_blank(),
         x = "words",
         fill = "topic") +
    theme(
      plot.title.position = "plot",
      strip.background = element_rect(fill = "white", color = "white"),
      strip.text = element_text(colour = "black",
                                hjust = 0),
      panel.grid.major.x = element_blank(),
      panel.grid.minor = element_blank()) +
    scale_fill_viridis_d(alpha = 0.8) +
    scale_color_viridis_d(alpha = 1) +
    scale_y_continuous(
      labels = scales::label_percent())

  if (title) {
    plot <- plot + ggtitle(df_string)
  }

  # The `text` aesthetic feeds the plotly tooltip; warnings raised while
  # converting the plot are suppressed.
  plot |>
    ggplotly(tooltip = "text", height = height) |>
    hide_legend() |>
    suppressWarnings()
}
# Bar charts of the `top_n` highest-beta words for each requested topic,
# one facet per topic, optionally saved as a PNG in `savedir`.
#
# Arguments:
#   df           fitted topic model
#   topics       topic numbers to plot; facet order follows this vector
#   top_n        number of words per topic
#   expand_bars  free the x scale per facet so bars fill each panel;
#                FALSE keeps one beta scale across topics
#   save         write the plot to `savedir`
#   savedir      output directory; created if missing
# Returns the ggplot object.
visualize_topic_bars <- function(
    df,
    topics,
    top_n = 10,
    expand_bars = TRUE,
    save = TRUE,
    savedir = "plots") {
  df_string <- deparse1(substitute(df))
  # unique() so repeated topic numbers don't create duplicated factor levels.
  topic_levels <- paste("topic", unique(topics))

  plot <- tidy(df) |>
    filter(topic %in% topics) |>
    mutate(topic = factor(paste("topic", topic), levels = topic_levels)) |>
    group_by(topic) |>
    arrange(desc(beta)) |>
    slice_head(n = top_n) |>
    ungroup() |>
    ggplot(aes(y = reorder_within(term, beta, topic),
               x = beta)) +
    geom_col(aes(fill = topic),
             show.legend = FALSE) +
    scale_y_reordered() +
    labs(y = NULL,
         x = NULL,
         title = df_string) +
    theme_minimal() +
    theme(axis.text.x = element_blank(),
          panel.grid = element_blank())

  facet_scales <- if (expand_bars) "free" else "free_y"
  plot <- plot +
    facet_wrap(~ topic, scales = facet_scales)

  if (save) {
    if (!dir.exists(savedir)) {
      dir.create(savedir, recursive = TRUE)
    }
    filename <- file.path(
      savedir,
      paste0(df_string,
             " - topics ",
             paste0(topics, collapse = ", "),
             ".png")
    )
    # theme_minimal() has no plot background; force white like the other
    # saved plots.
    ggsave(filename, plot = plot, bg = "white")
  }
  plot
}
# Draws one word cloud per topic (words sized by beta), saves each as
# "<model> - topic <k>.png" in `savedir`, and returns the saved images via
# `knitr::include_graphics()` for display in a knitted document.
#
# Arguments:
#   df       fitted topic model
#   topics   topic numbers to draw; NULL draws every topic and then shows
#            every matching "<model> - topic *.png" found in `savedir`
#   crop     trim the images' margins with `knitr::plot_crop()` first
#   savedir  directory for the PNG files; created if missing
visualize_topic_wordcloud <- function(
    df,
    topics = NULL,
    crop = TRUE,
    savedir = "plots") {
  # Writes one 12x8 in, 300 dpi PNG word cloud per topic of `df` into `dir`.
  save_topic_wordcloud <- function(
      df,
      topics = NULL,
      dir = "plots",
      count = 150,
      df_string = NULL) {
    if (is.null(df_string)) {
      df_string <- deparse1(substitute(df))
    }
    df <- tidy(df)
    if (!is.null(topics)) {
      df <- df |> filter(topic %in% topics)
    }
    if (!dir.exists(dir)) {
      dir.create(dir, recursive = TRUE)
    }
    for (topic_id in unique(df$topic)) {
      topic_terms <- df |> filter(topic == topic_id)
      filename <- file.path(dir,
                            paste0(df_string, " - topic ", topic_id, ".png"))
      png(filename, width = 12,
          height = 8, units = "in",
          res = 300)
      # Always close the device, even if wordcloud() fails; otherwise later
      # plots would silently be drawn into this file.
      tryCatch(
        wordcloud(words = topic_terms$term,
                  freq = topic_terms$beta,
                  max.words = count,
                  random.order = FALSE,
                  scale = c(3, .3),
                  rot.per = 0.2,
                  colors = viridis::turbo(
                    n = 9,
                    direction = -1)[1:8]),
        finally = dev.off()
      )
    }
    invisible(NULL)
  }

  df_string <- deparse1(substitute(df))
  # Save into `savedir`, since the paths below are built from it.
  save_topic_wordcloud(df, topics, dir = savedir, df_string = df_string)

  prefix <- paste0(df_string, " - topic ")
  if (!is.null(topics)) {
    paths <- file.path(savedir, paste0(prefix, topics, ".png"))
    missing <- !file.exists(paths)
    if (any(missing)) {
      warning("No word cloud was saved for topic(s): ",
              paste(topics[missing], collapse = ", "),
              call. = FALSE)
    }
    paths <- paths[!missing]
  } else {
    # Match the prefix literally: `df_string` may contain regex
    # metacharacters such as "." or "(".
    pngs <- list.files(savedir, pattern = "\\.png$", full.names = TRUE)
    paths <- pngs[startsWith(basename(pngs), prefix)]
  }

  if (crop) {
    for (path in paths) {
      knitr::plot_crop(path)
    }
  }
  knitr::include_graphics(paths)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment