# Gist Torvaney/ee9eb6de8a5522647e57add2988a7c91 (last active March 27, 2018)
# Compare methods of finding skipgram windows tidily
library(tidyverse)
library(tidytext)

# Load data ----
# Slightly different to https://juliasilge.com/blog/word-vectors-take-two/
# just because I have this data locally.
#
# One element of janeaustenr::northangerabbey per row. The cleaning steps are
# carried over from the post above. The HTML-entity patterns must stay as the
# literal entity text (e.g. "&quot;"), not the characters they decode to:
# decoded, the first pattern is not even a valid string literal, and the third
# would strip every "<", ">" and "&" before the tag regex below can run.
austen_text <- tibble(text = janeaustenr::northangerabbey) %>%
  mutate(
    text = str_replace_all(text, "&#x27;|&quot;|&#x2F;", "'"), ## weird encoding
    text = str_replace_all(text, "<a(.*?)>", " "),             ## links
    text = str_replace_all(text, "&gt;|&lt;|&amp;", " "),      ## html yuck
    text = str_replace_all(text, "&#[:digit:]+;", " "),        ## html yuck
    text = str_replace_all(text, "<[^>]*>", " "),              ## mmmmm, more html yuck
    text = str_to_lower(text),                                 ## BT EDIT
    # Assigned before blank lines are dropped, so the IDs are positions in
    # the original vector (i.e. line numbers, not a contiguous sequence)
    postID = row_number()                                      ## Actually more like line ID...
  ) %>%
  filter(text != "")
# Dr Silge ----
# From https://juliasilge.com/blog/tidy-word-vectors/

# Split each document into overlapping `window_size`-word n-grams, number the
# n-grams, then split every n-gram back into one row per word.
#
# tbl:         data frame with the text column and a `postID` column
#              (`postID` is hard-coded in the unite() step below)
# doc_var:     quosure naming the text column to tokenise
# window_size: number of words per window
#
# Returns one row per (window, word) with `postID`, `window_id`,
# `skipgramID` ("<postID>_<window_id>") and `word`, plus any other columns
# carried through from `tbl`.
make_windows_js <- function(tbl, doc_var, window_size) {
  grams <- unnest_tokens(tbl, ngram, !!doc_var,
                         token = "ngrams", n = window_size)
  # Named `window_id` for consistency with the other methods; it is a running
  # row number over all n-grams, not per document
  grams <- mutate(grams, window_id = row_number())
  # remove = FALSE keeps postID and window_id for comparison with other methods
  grams <- unite(grams, skipgramID, postID, window_id, remove = FALSE)
  unnest_tokens(grams, word, ngram)
}
# From https://juliasilge.com/blog/word-vectors-take-two/

# Build sliding windows of `window_size` consecutive words by computing row
# indices into `tbl` directly.
#
# tbl:         one row per word, in document order
# doc_var:     quosure naming the document-ID column
# window_size: number of words per window
#
# Returns the rows of `tbl` repeated once per window they fall in, with an
# added `window_id` column. Windows whose rows span more than one value of
# `doc_var` are dropped.
#
# NOTE(review): rle() merges *adjacent* documents that share the same
# `each_total`. For example, two adjacent documents that are both 3 words
# long (with window_size 8) form one run with a single window, so the second
# document gets no window at all. Start offsets for runs after the first also
# look misaligned. Confirm before treating this method as ground truth.
slide_windows_js <- function(tbl, doc_var, window_size) {
  # each word gets a skipgram (window_size words) starting on the first
  # e.g. skipgram 1 starts on word 1, skipgram 2 starts on word 2
  each_total <- tbl %>%
    group_by(!!doc_var) %>%
    mutate(doc_total = n(),
           each_total = pmin(doc_total, window_size, na.rm = TRUE)) %>%
    pull(each_total)

  rle_each <- rle(each_total)
  counts <- rle_each[["lengths"]]
  counts[rle_each$values != window_size] <- 1

  # each word get a skipgram window, starting on the first
  # account for documents shorter than window
  id_counts <- rep(rle_each$values, counts)
  window_id <- rep(seq_along(id_counts), id_counts)

  # within each skipgram, there are window_size many offsets
  indexer <- (seq_along(rle_each[["values"]]) - 1) %>%
    map2(rle_each[["values"]] - 1,
         ~ seq.int(.x, .x + .y)) %>%
    map2(counts, ~ rep(.x, .y)) %>%
    flatten_int() +
    window_id

  # Trailing windows index past the last row. tibble deprecates out-of-bounds
  # row indices, so request the same all-NA row explicitly with NA_integer_.
  indexer[indexer > nrow(tbl)] <- NA_integer_

  tbl[indexer, ] %>%
    bind_cols(tibble(window_id = window_id)) %>%
    group_by(window_id) %>%
    filter(n_distinct(!!doc_var) == 1) %>%
    ungroup()
}
# Jason Punyon ----
# From https://gist.github.com/JasonPunyon/3bca3bf606e7583c7ea2d8a00f86418e

# Build sliding windows by pairing every word with every position it could
# occupy in a window, keeping only pairings whose whole window lies inside
# the word's document.
#
# tbl:         one row per word, in document order
# doc_var:     quosure naming the document-ID column
# window_size: number of words per window (0 yields no rows)
#
# Returns one row per (window, word) with helper columns `WordId` (0-based
# position in the document), `RowCount` (document length), `InWindowIndex`
# (0-based position within the window) and `window_id`. `window_id` is the
# 1-based start position of the window within its document, so it only
# identifies a window together with the document ID.
slide_windows_jp <- function(tbl, doc_var, window_size) {
  tbl %>%
    group_by(!!doc_var) %>%
    mutate(WordId = row_number() - 1,
           RowCount = n()) %>%
    ungroup() %>%
    # seq_len() rather than 0:(window_size - 1), which is c(0, -1) when
    # window_size is 0 and would fabricate windows
    crossing(InWindowIndex = seq_len(window_size) - 1L) %>%
    filter((WordId - InWindowIndex) >= 0, # starting position of a window must be after the beginning of the document
           (WordId - InWindowIndex + window_size - 1) < RowCount # ending position of a window must be before the end of the document
    ) %>%
    mutate(window_id = WordId - InWindowIndex + 1)
}
# Me ----

# Add a copy of column `col` shifted down by `offset` rows with lag(), named
# by appending `offset` to the column's name (e.g. `word` -> `word3`).
# mutate() evaluates lag() within groups if `tbl` is grouped.
lag_words <- function(tbl, col, offset) {
  lagged_name <- paste0(quo_name(col), offset)
  mutate(tbl, !!lagged_name := lag(!!col, n = offset))
}
# Add `window_size - 1` lagged copies of `col` (lags 1, 2, ...) so each row
# holds its word plus the preceding words of its window, in wide form.
#
# tbl:         data frame; if grouped, lag_words() lags within each group
# col:         quosure naming the word column
# window_size: total window width; a width of 1 adds no columns
create_window_wide <- function(tbl, col, window_size) {
  max_offset <- window_size - 1
  # seq_len() rather than 1:max_offset: for window_size 1 the latter is
  # c(1, 0) and would add spurious lag-1 and lag-0 columns
  reduce(seq_len(max_offset),
         function(acc, offset) lag_words(acc, col, offset),
         .init = tbl)
}
# Build sliding windows by widening each word into itself plus its
# `window_size - 1` predecessors (lag columns), then pivoting to long form.
#
# tbl:         one row per word, in document order
# word_var:    quosure naming the word column
# doc_var:     quosure naming the document-ID column
# window_size: number of words per window
#
# Returns an ungrouped tibble with one row per (window, word): the document
# ID, `window_id` (unique across all documents) and `word`. Rows are ordered
# by `window_id`, and within a window words run oldest to newest. Documents
# shorter than `window_size` produce no windows.
slide_windows_bt <- function(tbl, word_var, doc_var, window_size) {
  word_col <- quo_name(word_var)
  # Names of the columns create_window_wide() yields: word, word1, word2, ...
  wide_cols <- c(word_col, paste0(word_col, seq_len(window_size - 1)))

  tbl %>%
    # Add a marker to remove windows smaller than `window_size` later on
    group_by(!!doc_var) %>%
    mutate(word_position = row_number()) %>%
    # Add lagged columns while still grouped, so lags never cross documents
    create_window_wide(word_var, window_size) %>%
    ungroup() %>%
    mutate(window_id = row_number()) %>%
    # Remove windows that would reach back before the start of the document
    filter(word_position >= window_size) %>%
    # Wide -> long. Pivot only the word columns so any other columns of `tbl`
    # are carried along rather than being mixed into `word`.
    pivot_longer(all_of(wide_cols), names_to = "position",
                 values_to = "word") %>%
    # 1 for the window's own (newest) word, larger for older lagged words
    mutate(lag_offset = match(position, wide_cols)) %>%
    # Explicit sort: arrange() ignores grouping, so the old
    # group_by() + arrange(desc(row_number())) reversed the whole table
    arrange(window_id, desc(lag_offset)) %>%
    select(-position, -lag_offset, -word_position)
}
# Create windows using each method ----
window_size <- 8

# Tokenise the `text` column into one `word` per row.
# (Despite the name, no filtering happens here.)
unnest_and_filter <- function(tbl) {
  unnest_tokens(tbl, word, text)
}

windows_js1 <- make_windows_js(austen_text, quo(text), window_size)
windows_js2 <- slide_windows_js(unnest_and_filter(austen_text),
                                quo(postID), window_size)
windows_jp <- slide_windows_jp(unnest_and_filter(austen_text),
                               quo(postID), window_size)
windows_bt <- slide_windows_bt(unnest_and_filter(austen_text),
                               quo(word), quo(postID), window_size)
# Compare methods' windows ----

# Number of distinct windows per postID (expects `postID` and `window_id`
# columns).
count_windows <- . %>%
  group_by(postID) %>%
  summarise(n = n_distinct(window_id))

# Rename the count column `n` to `n_<suffix>` so per-method counts can be
# joined side by side. rename() keeps the column in place, unlike the former
# mutate() + select() pair, which moved it to the end.
add_label <- function(tbl, suffix) {
  newcol <- paste("n", suffix, sep = "_")
  rename(tbl, !!newcol := n)
}
# One row per postID with each method's window count (`n_<method>`) and
# flags for whether the bt count agrees with each of the other methods.
comparison <- list(js1 = windows_js1,
                   js2 = windows_js2,
                   jp = windows_jp,
                   bt = windows_bt) %>%
  map(count_windows) %>%
  imap(add_label) %>%
  reduce(function(x, y) full_join(x, y, by = "postID")) %>%
  mutate(
    # A postID absent from a method's output (NA after the full join) means
    # that method built no windows for it. Count it as 0 so the flags below
    # are FALSE rather than NA, which `filter(!match)` would silently drop.
    across(starts_with("n_"), ~ coalesce(.x, 0L)),
    match1 = (n_bt == n_js1), # TODO find a more generic way to do this
    match2 = (n_bt == n_js2),
    match3 = (n_bt == n_jp),
    match = match1 & match2 & match3
  )
# View the first line whose number of windows differs between the methods,
# together with the windows each method built for it
unmatched_id <- comparison %>%
  filter(!match) %>%
  pull(postID) %>%
  first()

# first() returns NA when nothing is unmatched; say so instead of printing
# a series of empty tibbles
if (is.na(unmatched_id)) {
  message("No line has a different number of windows between the methods.")
} else {
  print(filter(austen_text, postID == unmatched_id))
  print(filter(windows_js1, postID == unmatched_id))
  print(filter(windows_js2, postID == unmatched_id))
  print(filter(windows_jp, postID == unmatched_id))
  print(filter(windows_bt, postID == unmatched_id))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment