Created
December 12, 2017 08:56
-
-
Save andrewyates/138df7a63ac267fea1894d621bcef765 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def avg_vec(terms, model=None):
    """Return the normalized mean embedding of the in-vocabulary *terms*.

    Terms not found in the embedding model are skipped. If no term is found,
    a warning is printed and an all-zero vector of length
    ``model.vector_size`` is returned (``unitvec`` is not applied to it).

    Args:
        terms: iterable of tokens to embed.
        model: embedding lookup supporting ``term in model``, ``model[term]``
            and a ``vector_size`` attribute. Defaults to the module-level
            ``word2vec_model`` (looked up at call time), so existing
            one-argument callers behave exactly as before.

    Returns:
        ``unitvec(mean of the known term vectors)``, or
        ``np.zeros(model.vector_size)`` when none of the terms is known.
    """
    if model is None:
        model = word2vec_model
    vecs = [model[term] for term in terms if term in model]
    if not vecs:
        print("WARNING: sequence of UNKs")
        return np.zeros(model.vector_size)
    return unitvec(np.mean(vecs, axis=0))
# NOTE(review): os.mkdir raises if context_outdir already exists, so a rerun
# stops here; use os.makedirs(context_outdir, exist_ok=True) if overwriting
# previous output is acceptable.
os.mkdir(context_outdir)

# One context vector per query id in qrel, visited in sorted order: avg_vec
# over the query's 'topic' terms followed by its 'desc' terms (assumes both
# entries are token lists that concatenate with `+`).
qid_contexts = {
    qid: avg_vec(data['topic'][qid] + data['desc'][qid])
    for qid in sorted(qrel)
}
# Half-width of the sliding context window: token i is represented by
# avg_vec(doc[i - win : i + win + 1]), clipped to the document bounds.
win = 4
for qid, query_vec in qid_contexts.items():
    labels = qid_cwid_label[qid]
    # cwid -> np.array with one query/context dot-product score per token
    # (a cosine similarity, assuming unitvec returns unit-length vectors).
    sims = {}
    for cwid in labels:
        if cwid not in cwid_docs or len(cwid_docs[cwid]) == 0:
            print("WARNING: missing doc %s for qid %s with label %s" % (cwid, qid, labels[cwid]))
            continue
        doc = cwid_docs[cwid]
        doc_len = len(doc)
        scores = []
        for i in range(doc_len):
            # Clamp the window to the document; a negative slice start would
            # otherwise wrap around to the end of the document.
            begin = max(0, i - win)
            end = min(doc_len, i + win + 1)
            # This depends on how you preprocessed the documents; enable it if
            # needed. Skipping a token means scores[k] no longer lines up with
            # doc[k].
            # if doc[i] not in word2vec_model:
            #     continue
            scores.append(np.dot(query_vec, avg_vec(doc[begin:end])))
        sims[cwid] = np.array(scores)
    # Close the output file deterministically instead of leaking an anonymous
    # handle and relying on garbage collection to flush it.
    with open(os.path.join(context_outdir, '%s.p' % qid), 'wb') as out_file:
        pickle.dump(sims, out_file, protocol=-1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment