Skip to content

Instantly share code, notes, and snippets.

@nathankjer
Created January 19, 2023 04:20
Show Gist options
  • Save nathankjer/8f4c0d8c3f527efa1331920d97c09ae6 to your computer and use it in GitHub Desktop.
Save nathankjer/8f4c0d8c3f527efa1331920d97c09ae6 to your computer and use it in GitHub Desktop.
Code implementing A* pathfinding between two words through a Doc2Vec word-embedding (latent vector) space.
import heapq
class elem_t:
    """A node in the search tree explored by ``a_star_search``.

    Attributes:
        value: the word this node stands for.
        cost: priority used when the node is pushed onto a ``PriorityQueue``
            (``None`` until assigned).
        column_offset: position of ``value`` in the parent's neighbour list;
            stays 0 unless the search sets it.
        parent: the node this one was reached from, or ``None`` for the root.
    """

    def __init__(self, value, parent=None, cost=None):
        # Targets are assigned left to right, so attributes are created in the
        # same order as before: value, cost, column_offset, parent.
        self.value, self.cost, self.column_offset, self.parent = value, cost, 0, parent
class PriorityQueue:
    """Min-priority queue of items ordered by their ``cost`` attribute.

    Items with equal cost come out in insertion order (FIFO). The monotonically
    increasing insertion counter sits between cost and item in each heap entry,
    so the heap never has to compare two items directly. The costs themselves
    must still be mutually comparable; for example, mixing ``None`` with numbers
    raises ``TypeError``.
    """

    def __init__(self):
        self._queue = []
        self._index = 0

    def push(self, item):
        """Add *item*, prioritised by ``item.cost`` (lowest cost pops first)."""
        heapq.heappush(self._queue, (item.cost, self._index, item))
        self._index += 1

    def pop(self):
        """Remove and return the item with the lowest cost.

        Raises:
            IndexError: if the queue is empty.
        """
        _cost, _order, item = heapq.heappop(self._queue)
        return item

    def length(self):
        """Return the number of queued items (kept for existing callers)."""
        return len(self._queue)

    def __len__(self):
        """Support ``len(queue)`` and truthiness (an empty queue is falsy)."""
        return len(self._queue)
def get_transition_cost(word1, word2, doc2vec_model):
    """Return the cost of stepping from *word1* to *word2*.

    The cost is one minus the similarity reported by
    ``doc2vec_model.similarity``, converted to a Python ``float``. More similar
    words are therefore cheaper to step between.

    NOTE(review): in gensim 4.x the similarity methods live on ``model.wv``
    rather than on the Doc2Vec model itself; confirm which object callers pass.
    """
    similarity = float(doc2vec_model.similarity(word1, word2))
    return 1.0 - similarity
def a_star_search(start_word, end_word, doc2vec_model, branching_factor=60, weight=4., *, verbose=True):
    """Find a chain of related words leading from *start_word* to *end_word*.

    Runs weighted A* over the word-similarity graph exposed by *doc2vec_model*:

    - the neighbours of a word are its ``branching_factor`` most similar words;
    - stepping between two words costs ``get_transition_cost`` (one minus
      their similarity);
    - the heuristic is ``weight`` times the transition cost from a word straight
      to *end_word*.

    The heuristic is not guaranteed to be admissible, and ``weight`` inflates
    it further. The search therefore favours speed, and the returned path is
    not guaranteed to be the cheapest one.

    Args:
        start_word: word to start from.
        end_word: word to reach.
        doc2vec_model: object providing ``similarity(w1, w2)`` and
            ``most_similar(word, topn=...)``, the latter returning
            ``(word, score)`` pairs.
        branching_factor: number of nearest neighbours expanded per word.
        weight: multiplier applied to the heuristic term.
        verbose: when true (the default), print a one-line progress counter.

    Returns:
        List of words from *start_word* to *end_word* inclusive. If the frontier
        is exhausted without reaching *end_word*, the list is ``[start_word]``.
    """
    cost_list = {start_word: 0}  # best known path cost (g) per word
    frontier = PriorityQueue()
    start_elem = elem_t(start_word, parent=None, cost=get_transition_cost(start_word, end_word, doc2vec_model))
    frontier.push(start_elem)
    path_end = start_elem
    # A set keeps the per-neighbour "already expanded?" check O(1).
    explored = set()
    while frontier.length() > 0:
        current_node = frontier.pop()
        current_word = current_node.value
        if current_word in explored:
            # Stale duplicate: the word was queued more than once before its
            # first expansion, and its cheapest copy was already expanded.
            # Re-expanding it could never improve any neighbour's cost.
            continue
        explored.add(current_word)
        if current_word == end_word:
            path_end = current_node
            break
        # most_similar yields (word, score) pairs, so filter on the word itself.
        neighbors = [
            word
            for word, _score in doc2vec_model.most_similar(current_word, topn=branching_factor)
            if word != current_word
        ]
        base_cost = cost_list[current_word]
        for rank, neighbor_word in enumerate(neighbors):
            if neighbor_word in explored:
                continue
            cost = base_cost + get_transition_cost(current_word, neighbor_word, doc2vec_model)
            if neighbor_word not in cost_list or cost < cost_list[neighbor_word]:
                cost_list[neighbor_word] = cost
                priority = cost + weight * get_transition_cost(neighbor_word, end_word, doc2vec_model)
                new_elem = elem_t(neighbor_word, parent=current_node, cost=priority)
                new_elem.column_offset = rank  # position in the neighbour list
                frontier.push(new_elem)
        if verbose:
            print("Explored: "+str(len(explored))+", Frontier: "+str(frontier.length())+", Cost: "+str(base_cost)[:5],end='\r')
    if verbose:
        print('')  # move past the carriage-return progress line
    return _reconstruct_path(path_end)


def _reconstruct_path(node):
    """Return the words along *node*'s parent chain, ordered root first."""
    path = []
    while node is not None:
        path.append(node.value)
        node = node.parent
    path.reverse()
    return path
# Train (or reuse) the SentencePiece tokenizer and the Doc2Vec model.
# Each expensive step is skipped when its output file already exists.
# NOTE(review): `os`, `spm` (presumably `import sentencepiece as spm`),
# `Doc2Vec`, `get_documents` and `train_model` are not imported or defined in
# this file as shown; confirm they are provided before this runs.
name = 'reviews'
vocab_size = 12000
epochs = 100
if not os.path.isfile('{0}.model'.format(name)):
    spm.SentencePieceTrainer.Train('--input={0}.txt --model_prefix={0} --vocab_size={1} --split_by_whitespace=True'.format(name,vocab_size))
sp = spm.SentencePieceProcessor()
sp.load('{0}.model'.format(name))
# The existence of the final checkpoint (epoch index epochs-1) is taken to mean
# training already completed.
if not os.path.isfile('./checkpoints/{0}-{1}.w2v'.format(name,epochs-1)):
    tokenized_path = '{0}_tokenized.tsv'.format(name)
    if not os.path.isfile(tokenized_path):
        # Write to a temporary file and rename it into place only once it is
        # complete. Otherwise an interrupted run would leave a truncated TSV
        # that the existence check above would make later runs reuse.
        tmp_path = tokenized_path + '.tmp'
        # Decode explicitly as UTF-8 instead of the locale default, and close
        # both files deterministically (the original never closed the input).
        with open('{0}.txt'.format(name), encoding='utf-8') as src, \
                open(tmp_path, 'w', encoding='utf-8') as f:
            for i, line in enumerate(src):
                ids = sp.EncodeAsIds(line)
                # Keep only lines that encode to more than 10 token ids.
                if len(ids) > 10:
                    # Row: <source line number> TAB <space-separated token ids>
                    f.write('{0}\t{1}\n'.format(i,' '.join(str(x) for x in ids)))
        os.replace(tmp_path, tokenized_path)
    # NOTE(review): epochs=1 here together with train_model(..., epochs)
    # presumably means train_model runs the passes itself and saves
    # ./checkpoints/<name>-<epoch>.w2v; confirm.
    model = Doc2Vec(vector_size=300, window=8, min_count=3, negative=5, workers=16, epochs=1)
    documents = get_documents(tokenized_path,name)
    model.build_vocab(documents)
    model = train_model(model,documents,name,epochs)
model = Doc2Vec.load('./checkpoints/{0}-{1}.w2v'.format(name,epochs-1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment