Skip to content

Instantly share code, notes, and snippets.

@abelsonlive
Created December 6, 2012 17:55
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save abelsonlive/4226539 to your computer and use it in GitHub Desktop.
Save abelsonlive/4226539 to your computer and use it in GitHub Desktop.
topic modeling in R
# Brian Abelson @brianabelson
# Harmony Institute
# December 5, 2012
# lda is a wrapper for lda.collapsed.gibbs.sampler in the "lda" package
# it fits topic models using latent Dirichlet allocation (LDA)
# it provides arguments for cleaning the input text and tuning the parameters of the model
# it also returns a lot of useful information about the topics/documents in a format that you can easily join back to your original data
# this allows you to easily model outcomes based on the distribution of topics within a collection of texts
lda <- function(
  # DATA #
  # a character vector of text documents
  text,
  # a vector of ids, one per document, used to join results back to other
  # variables; defaults to seq_along(text)
  ids = NULL,
  # CLEANING #
  # logical; lower-case the text? (removing stop words also lower-cases it)
  lower_case = TRUE,
  # logical; remove English stop words (stopwords("english"))?
  remove_stop_words = TRUE,
  # character vector of extra stop words (lower-cased before removal)
  stop_words_to_add = NULL,
  # logical; remove numbers?
  remove_numbers = TRUE,
  # logical; replace punctuation with spaces?
  remove_punctuation = TRUE,
  # logical; drop non-ASCII characters?
  remove_non_ascii = TRUE,
  # logical; stem words with Rstem's wordStem()?
  stem_words = FALSE,
  # length-2 numeric: min and max characters per word (inclusive)
  char_range = c(2, 50),
  # minimum number of occurrences across ALL documents for a word to be kept
  min_word_count = 5,
  # MODEL PARAMETERS #
  # number of topics to fit
  n_topics = 10,
  # number of top words to return per topic and per document
  n_topic_words = 20,
  # number of Gibbs sampling iterations (burnin is added on top of this)
  n_iter = 1000,
  # number of extra initial iterations; the sampler runs n_iter + burnin
  burnin = 100,
  # scalar Dirichlet hyperparameter for topic proportions
  alpha = 0.1,
  # scalar Dirichlet hyperparameter for topic multinomials
  eta = 0.1,
  # OUTPUT #
  # number of top topics per document, returned as ass_topic_a, ass_topic_b, ...
  n_assignments = 3,
  # CLEANING (appended last so positional calls keep working) #
  # logical; strip URLs from the text before any other cleaning?
  remove_urls = TRUE
) {
  # Returns a list with:
  #   topic_words    - data.frame of the top n_topic_words words, one column
  #                    per topic (topic_1 ... topic_K)
  #   document_stats - data.frame, one row per document: id, n_topic_*
  #                    (per-topic counts from the sampler), p_topic_*
  #                    (their proportions), ass_topic_* (topic numbers ranked
  #                    by count), then *_raw and *_clean text statistics
  #   document_words - data.frame of the top n_topic_words words per document
  #                    (from the predictive distribution), one column per id

  # LIBRARIES
  # NOTE(review): installing packages at call time is a side effect kept for
  # backward compatibility; consider declaring these as dependencies instead.
  for (pkg in c("tm", "lda", "plyr", "stringr")) {
    .lda_require(pkg)
  }
  if (stem_words) {
    # Rstem is only needed for stemming, so don't make every call depend on it.
    .lda_require("Rstem", repos = "http://www.omegahat.org/R", type = "source")
  }

  # start time (for reporting how long the function takes)
  start <- Sys.time()

  if (length(text) == 0) {
    stop("`text` must contain at least one document.", call. = FALSE)
  }
  if (is.null(ids)) {
    ids <- seq_along(text)
  }
  if (length(ids) != length(text)) {
    stop("`ids` must have the same length as `text`.", call. = FALSE)
  }

  # META VARIABLES - RAW TEXT
  features_raw <- ldply(text, .lda_doc_stats)
  names(features_raw) <- paste0(names(features_raw), "_raw")

  # CLEAN THE INPUT TEXT
  text_clean <- .lda_clean_text(
    text,
    remove_urls = remove_urls,
    lower_case = lower_case,
    remove_stop_words = remove_stop_words,
    stop_words_to_add = stop_words_to_add,
    remove_numbers = remove_numbers,
    remove_punctuation = remove_punctuation,
    remove_non_ascii = remove_non_ascii,
    stem_words = stem_words,
    char_range = char_range
  )

  # META VARIABLES - CLEAN TEXT
  features_clean <- ldply(text_clean, .lda_doc_stats)
  names(features_clean) <- paste0(names(features_clean), "_clean")

  # CREATE / FILTER LEXICON
  message("lexicalizing text...")
  # Pass `lower` explicitly so lower_case = FALSE is honoured here as well.
  lexicon <- lexicalize(text_clean, sep = " ", lower = lower_case, count = 1)
  # Select words by the table's names rather than by position in the vocab.
  word_counts <- word.counts(lexicon$documents, lexicon$vocab)
  keep <- names(word_counts)[word_counts >= min_word_count]
  if (length(keep) == 0) {
    stop(
      "No word occurs at least `min_word_count` times after cleaning.",
      call. = FALSE
    )
  }
  # re-lexicalize using only the kept vocabulary
  documents <- lexicalize(
    text_clean, sep = " ", lower = lower_case, vocab = keep
  )

  # FIT TOPICS (collapsed Gibbs sampling). The burn-in iterations are run in
  # addition to n_iter; the final sampler state is what gets reported.
  message("fitting topics...")
  result <- lda.collapsed.gibbs.sampler(
    documents, n_topics, keep, n_iter + burnin, alpha, eta
  )

  # PREPARE OUTPUT
  message("preparing output...")
  # top words by document, using the model's own hyperparameters
  predictions <- t(predictive.distribution(
    result$document_sums, result$topics, alpha, eta
  ))
  document_words <- data.frame(
    top.topic.words(predictions, n_topic_words, by.score = TRUE)
  )
  names(document_words) <- ids

  # top words by topic
  topic_words <- data.frame(
    top.topic.words(result$topics, num.words = n_topic_words, by.score = TRUE)
  )
  names(topic_words) <- paste0("topic_", seq_len(n_topics))

  # per-document topic counts / proportions / ranked assignments + text stats
  topics <- .lda_topic_stats(result$document_sums, ids, n_assignments)
  document_stats <- data.frame(topics, features_raw, features_clean)

  # CALCULATE JOB LENGTH
  end <- Sys.time()
  job_length <- round(
    as.numeric(difftime(end, start, units = "mins")),
    digits = 2
  )
  message("lda finished at: ", format(end))
  message("job took: ", job_length, " minutes")

  # RETURN OUTPUT
  list(
    topic_words = topic_words,
    document_stats = document_stats,
    document_words = document_words
  )
}

# Attach package `pkg`; if it cannot be loaded, install it first, forwarding
# any extra install.packages() arguments given in `...`.
.lda_require <- function(pkg, ...) {
  if (!require(pkg, character.only = TRUE)) {
    install.packages(pkg, ...)
    library(pkg, character.only = TRUE)
  }
  invisible(NULL)
}

# Simple length statistics for one document `x`: total characters (len),
# mean word length (len_word, NaN when there are no words), number of
# space-separated words (n_feat) and number of distinct words (n_unq_feat).
.lda_doc_stats <- function(x) {
  words <- str_trim(unlist(strsplit(x, " ")))
  words <- words[words != ""]
  data.frame(
    len = nchar(x),
    len_word = mean(nchar(words)),
    n_feat = length(words),
    n_unq_feat = length(unique(words))
  )
}

# Clean a character vector of documents in this order: URLs -> case ->
# stop words -> numbers -> punctuation -> non-ASCII -> whitespace ->
# 255-character guard -> stemming -> char_range filter.
# Returns a character vector the same length as `text`.
.lda_clean_text <- function(text, remove_urls, lower_case, remove_stop_words,
                            stop_words_to_add, remove_numbers,
                            remove_punctuation, remove_non_ascii,
                            stem_words, char_range) {
  # Split one document on spaces, trim the tokens and drop empty ones.
  tokenize <- function(doc) {
    words <- str_trim(unlist(strsplit(doc, " ")))
    words[words != ""]
  }
  # Apply `f` (one document -> one string) to every document. Working on the
  # plain character vector avoids depending on how the installed tm version
  # treats non-tm functions passed to tm_map().
  map_docs <- function(docs, f) {
    vapply(docs, f, character(1), USE.NAMES = FALSE)
  }

  if (remove_urls) {
    # "\\b" is the regex word boundary ("\b" in an R string is a backspace).
    url_pattern <- paste0(
      "\\b(?:(?:https?|ftp|file)://|www\\.|ftp\\.)",
      "[-A-Z0-9+&@#/%=~_|$?!:,.]*[A-Z0-9+&@#/%=~_|$]"
    )
    text <- gsub(url_pattern, "", text, ignore.case = TRUE, perl = TRUE)
  }
  # stop-word removal is documented to lower-case the text as well
  if (lower_case || remove_stop_words) {
    text <- tolower(text)
  }
  if (remove_stop_words) {
    message("removing stop words...")
    # user-supplied stop words are lower-cased so they can match the text
    stop_words <- c(stopwords("english"), tolower(stop_words_to_add))
    text <- removeWords(text, stop_words)
  }

  message("cleaning text...")
  if (remove_numbers) {
    text <- removeNumbers(text)
  }
  if (remove_punctuation) {
    text <- gsub("[[:punct:]]", " ", text)
  }
  if (remove_non_ascii) {
    text <- iconv(text, "latin1", "ASCII", sub = "")
  }
  text <- stripWhitespace(text)

  # drop words longer than 255 characters - these break the stemming function
  text <- map_docs(text, function(doc) {
    words <- tokenize(doc)
    paste(words[nchar(words) <= 255], collapse = " ")
  })

  if (stem_words) {
    message("stemming words...")
    text <- map_docs(text, function(doc) {
      paste(wordStem(tokenize(doc)), collapse = " ")
    })
  }

  # keep only words whose length lies within char_range (inclusive)
  text <- map_docs(text, function(doc) {
    words <- tokenize(doc)
    n_chars <- nchar(words)
    in_range <- n_chars >= char_range[1] & n_chars <= char_range[2]
    paste(words[in_range], collapse = " ")
  })
  stripWhitespace(text)
}

# Build the per-document topic table from the sampler's document_sums matrix
# (topics in rows, documents in columns). Returns a data.frame with id,
# n_topic_1..K (counts), p_topic_1..K (proportions; NaN for a document with
# no counts) and ass_topic_a, ass_topic_b, ... (topic numbers ranked by
# count, highest first; NA when n_assignments exceeds the number of topics).
.lda_topic_stats <- function(document_sums, ids, n_assignments) {
  counts <- t(document_sums)
  topic_idx <- seq_len(ncol(counts))

  ranked <- lapply(seq_len(nrow(counts)), function(d) {
    order(counts[d, ], decreasing = TRUE)[seq_len(n_assignments)]
  })
  assignments <- matrix(unlist(ranked), nrow = nrow(counts), byrow = TRUE)
  colnames(assignments) <- paste0(
    "ass_topic_", letters[seq_len(n_assignments)]
  )

  proportions <- counts / rowSums(counts)
  colnames(counts) <- paste0("n_topic_", topic_idx)
  colnames(proportions) <- paste0("p_topic_", topic_idx)

  data.frame(id = ids, counts, proportions, assignments)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment