Created
August 30, 2018 18:53
-
-
Save tmastny/065ec4964e043f5eafa0a690f004a81c to your computer and use it in GitHub Desktop.
Meaning chains with word embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# http://sappingattention.blogspot.com/2018/06/meaning-chains-with-word-embeddings.html | |
#' Build a chain of intermediate words linking two words in an embedding.
#'
#' Looks for a "pivot" word that is at least as similar to both `w1` and `w2`
#' as they are to each other, then repeats the search on each half
#' (`w1` -> pivot and pivot -> `w2`) until no further pivot can be found.
#'
#' @param matrix Word embedding with one named row per word, for which
#'   `matrix[[word]]` returns that word's vector (presumably a wordVectors
#'   `VectorSpaceModel` -- confirm against the caller).
#' @param w1,w2 The words at the two ends of the chain.
#' @param blacklist Character vector of words that may not be used as a pivot.
#' @return Character vector of unique words running from `w1` to `w2`; just
#'   `unique(c(w1, w2))` when no pivot is available.
# Source: http://sappingattention.blogspot.com/2018/06/meaning-chains-with-word-embeddings.html
cpath <- function(matrix, w1, w2, blacklist = c()) {
  p1 <- matrix[[w1]]
  p2 <- matrix[[w2]]
  base_similarity <- cosineSimilarity(p1, p2)[1]

  # Similarity of every word to each of the two endpoints (one column each).
  # NOTE(review): these are raw products, compared below against a cosine
  # similarity; the two only agree if the rows are unit length -- confirm.
  pairsims <- matrix %*% t(rbind(p1, p2))
  best_sim <- apply(pairsims, 1, max)
  worst_sim <- apply(pairsims, 1, min)

  # Words closer to *either* endpoint than the endpoints are to each other.
  # Only used to shrink the search space for the recursive calls.
  decent_words <- rownames(pairsims)[best_sim >= base_similarity]

  # Words closer to *both* endpoints than the endpoints are to each other.
  # The best of these is the intermediary. `drop = FALSE` keeps a single
  # match as a one-row matrix instead of collapsing it to a bare vector.
  greatsims <- pairsims[worst_sim >= base_similarity, , drop = FALSE]
  if (nrow(greatsims) == 0) {
    return(unique(c(w1, w2)))
  }

  # Consider only the four strongest candidates; head() avoids the NA indices
  # that `[1:4]` produces when fewer than four rows are available.
  ranked <- head(order(-rowSums(greatsims)), 4)
  candidates <- rownames(greatsims)[ranked]
  candidates <- candidates[!candidates %in% c(w1, w2, blacklist)]
  pivotword <- candidates[1]

  if (is.null(pivotword) || is.na(pivotword)) {
    # Dead end: none of the top candidates is usable.
    message(w1, w2)
    return(unique(c(w1, w2)))
  }

  # Highlight the pivot (only once we know there is one).
  message(w1, "<-->", toupper(pivotword), "<--->", w2)

  # Really, this should be *different* for the right and left side.
  mat <- matrix[rownames(matrix) %in% decent_words, ]
  excluded <- c(w1, w2, blacklist)
  left <- cpath(mat, w1, pivotword, blacklist = excluded)
  right <- cpath(mat, pivotword, w2, blacklist = excluded)
  unique(c(left, right))
}
#' Plot the chain of words between two words on a two-component PCA projection.
#'
#' Finds the chain with `cpath()` on the global `model`, fits a PCA to the
#' vectors of the chain words only, and projects both the chain and a few
#' randomly sampled "noise" words onto the first two components. Chain words
#' are labelled and joined by a path; noise words are drawn as faint points.
#'
#' @param w1,w2 The words at the two ends of the chain.
#' @param save If `TRUE`, also write the plot to
#'   `~/Pictures/From <w1> to <w2>.png`.
#' @param nnoise Number of random background words to add to the plot.
#' @return The ggplot object.
drawpath <- function(w1, w2, save = FALSE, nnoise = 1) {
  pathed <- cpath(model, w1, w2)

  # The projection is fitted on the chain words alone.
  just_this <- model %>% extract_vectors(pathed)
  r <- prcomp(just_this)$rotation[, c(1, 2)]

  # Draw noise words from outside the chain: a word appearing in both sets
  # would give `ordered()` duplicated levels, which is an error.
  noise_pool <- setdiff(rownames(model), pathed)
  noisewords <- sample(noise_pool, min(nnoise, length(noise_pool)))
  with_noise <- model %>% extract_vectors(c(pathed, noisewords))

  lotsa <- (with_noise %*% r) %>%
    as.data.frame() %>%
    rownames_to_column("word") %>%
    mutate(labelled = word %in% pathed) %>%
    # Ordering by chain position makes geom_path() follow the chain.
    mutate(word = ordered(word, levels = c(pathed, noisewords))) %>%
    arrange(word)

  on_path <- lotsa %>% filter(labelled)
  off_path <- lotsa %>% filter(!labelled)

  g <- ggplot(lotsa) +
    aes(x = PC1, y = PC2, label = word) +
    geom_point(data = off_path, size = .5, alpha = .33) +
    geom_path(data = on_path, alpha = .33) +
    geom_text(data = on_path) +
    labs(title = paste("From", w1, "to", w2))

  if (save) {
    ggsave(
      filename = paste0("~/Pictures/", paste("From", w1, "to", w2), ".png"),
      plot = g,
      width = 10,
      height = 8
    )
  }
  g
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment