Created
October 4, 2015 20:22
-
-
Save czinn/fc9ace1358a33dd33ffb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# You can find pretrained binary vector files or train your own using tools found at
# https://code.google.com/p/word2vec/ | |
# Sometimes it takes a while to find a path, but #fuckitshipit | |
# You can speed it up by decreasing the number of words loaded | |
# If you don't like your scale, try reversing the two words | |
# Upper bound on how many word vectors are read from the input file; the
# loader takes the smaller of this and the count declared in the file header.
# Lower it to load (and search) faster.
MAX_WORDS_TO_LOAD = 50000
import sys, struct, heapq | |
from queue import PriorityQueue | |
def readUntil(f, stop):
    """Read single bytes from ``f`` until ``stop`` is seen; return them as text.

    Args:
        f: Binary file-like object; only ``f.read(1)`` is called on it.
        stop: One-byte ``bytes`` delimiter (e.g. ``b" "``). It is consumed
            from the stream but not included in the result.

    Returns:
        The bytes read before the delimiter, decoded as UTF-8. Bytes that do
        not form valid UTF-8 are dropped instead of raising. If the stream
        ends before the delimiter appears, whatever was read so far is
        returned.
    """
    collected = bytearray()
    while True:
        b = f.read(1)
        # ``read`` returns b"" at end of file; without this check a missing
        # delimiter would make the loop spin forever.
        if not b or b == stop:
            break
        collected += b
    # Decode once at the end so multi-byte UTF-8 characters survive; decoding
    # byte-by-byte rejects (and silently drops) every non-ASCII character.
    return collected.decode("utf-8", errors="ignore")
def dotProduct(vecA, vecB):
    """Return the dot product of two numeric sequences.

    Elements are paired positionally with ``zip``, so any surplus elements of
    the longer sequence are ignored. Two empty sequences give ``0``.
    """
    pairwise_products = [x * y for x, y in zip(vecA, vecB)]
    return sum(pairwise_products)
def differentEnough(a, b):
    """Tell whether two words are spelled differently enough from each other.

    The lower-cased words are compared character by character at the same
    positions. The result is ``False`` when the number of matching positions
    exceeds three fifths of the shorter word's length, and ``True`` otherwise.
    """
    same_position = 0
    for ch_a, ch_b in zip(a.lower(), b.lower()):
        if ch_a == ch_b:
            same_position += 1
    shorter = min(len(a), len(b))
    # Integer form of "matches / shorter <= 3 / 5".
    return same_position * 5 <= shorter * 3
def nearby(word, word_map, n=5):
    """Return the ``n`` highest-scoring ``(score, other_word)`` pairs for ``word``.

    Args:
        word: Key into ``word_map`` whose neighbours are wanted.
        word_map: Mapping of word -> vector.
        n: Maximum number of pairs to return.

    Returns:
        A list of ``(score, other_word)`` tuples, largest first. The score is
        the dot product of the two vectors, multiplied by the boolean result
        of ``differentEnough`` so that look-alike spellings score zero.

    Raises:
        KeyError: If ``word`` is not a key of ``word_map``.
    """
    target = word_map[word]

    def scored_candidates():
        # Lazily score every entry so only the current best ``n`` are kept.
        for candidate, candidate_vec in word_map.items():
            similarity = dotProduct(target, candidate_vec)
            keep = differentEnough(word, candidate)
            yield (similarity * keep, candidate)

    return heapq.nlargest(n, scored_candidates())
def pathfind(start, end, wm):
    """Find a chain of neighbouring words leading from ``start`` to ``end``.

    The search runs backwards from ``end``: it repeatedly takes the frontier
    word whose vector has the largest dot product with ``start``'s vector and
    expands it to its 10 best ``nearby`` words, until ``start`` is reached.

    Args:
        start: First word of the resulting path; must be a key of ``wm``.
        end: Last word of the resulting path; must be a key of ``wm``.
        wm: Mapping of word -> vector.

    Returns:
        A list of words beginning with ``start`` and ending with ``end``.
        When ``start == end`` the list is ``[end]``.

    Raises:
        KeyError: If either word is missing from ``wm``, or if the search
            runs out of words to expand before reaching ``start``.
    """
    start_vec = wm[start]
    end_vec = wm[end]
    if start == end:
        return [end]

    # Min-heap of (negated similarity to start, word): the most start-like
    # frontier word pops first.
    frontier = [(-dotProduct(start_vec, end_vec), end)]
    # src[x] is the word from which x was discovered (one step closer to
    # ``end``). Seeding ``end`` keeps it from being queued and expanded twice.
    src = {end: None}
    found = False
    while frontier and not found:
        _, w = heapq.heappop(frontier)
        # Use the mapping that was passed in, not a module-level global.
        for _, nw in nearby(w, wm, n=10):
            if nw in src:
                continue
            src[nw] = w
            if nw == start:
                found = True
                break
            heapq.heappush(frontier, (-dotProduct(start_vec, wm[nw]), nw))

    if not found:
        raise KeyError(
            "no path found from {0!r} to {1!r}".format(start, end)
        )

    # Walk the discovery links from ``start`` back to ``end``.
    path = []
    node = start
    while node != end:
        path.append(node)
        node = src[node]
    path.append(end)
    return path
def _load_vectors(path, max_words):
    """Load up to ``max_words`` length-normalised vectors from a binary file.

    The file is expected to start with a text header line holding two
    integers (word count and vector size), followed by records made of a
    space-terminated word and ``size`` native 4-byte floats.

    Args:
        path: Path of the binary vector file.
        max_words: Upper bound on the number of records to read.

    Returns:
        A dict mapping each word to its vector (a list of floats) scaled to
        unit length. All-zero vectors are stored unscaled.
    """
    vectors = {}
    with open(path, "rb") as f:
        words, size = tuple(int(n) for n in readUntil(f, b"\n").split())
        words = min(max_words, words)
        # Roughly 100 progress dots; the floor of 1 keeps the modulus
        # non-zero when fewer than 100 words are loaded.
        dot_every = max(1, words // 100)
        fmt = "{0}f".format(size)
        record_bytes = struct.calcsize(fmt)
        for i in range(words):
            # Some word2vec binary files end each vector with a newline, which
            # would otherwise become the first character of the next word.
            word = readUntil(f, b" ").lstrip("\n")
            if i % dot_every == 0:
                print(".", end="", flush=True)
            raw = f.read(record_bytes)
            if len(raw) < record_bytes:
                # The file is shorter than its header promised; keep what
                # was read so far.
                break
            v = list(struct.unpack(fmt, raw))
            norm = sum(x * x for x in v) ** 0.5
            if norm:
                # Skip scaling for all-zero vectors to avoid dividing by zero.
                v = [x / norm for x in v]
            vectors[word] = v
    return vectors


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Specify a binary vector file (word2vec)")
        sys.exit(1)
    # Bound at module level on purpose: functions in this file may look up
    # ``word_map`` as a global.
    word_map = _load_vectors(sys.argv[1], MAX_WORDS_TO_LOAD)
    print("\nLoaded {0} words".format(len(word_map)))
    print()
    while True:
        print("Enter two words")
        try:
            tokens = input().split()
        except EOFError:
            # End of input (e.g. Ctrl-D or a finished pipe): leave quietly.
            break
        if len(tokens) != 2:
            print("Please enter exactly two words")
            continue
        start, end = tokens
        if start not in word_map:
            print("{0} not found".format(start))
            continue
        if end not in word_map:
            print("{0} not found".format(end))
            continue
        print("Building your scale!")
        try:
            path = pathfind(start, end, word_map)
        except KeyError:
            # pathfind raises KeyError when the search cannot reach ``start``.
            print("No path found between {0} and {1}".format(start, end))
            print()
            continue
        for i, w in enumerate(path, start=1):
            print("{0}.\t{1}".format(i, w))
        print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment