Last active
May 22, 2018 21:48
-
-
Save aneesh-joshi/ac2c28707b9f3b7bc4a1c15a8ff657ba to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import pickle

import numpy as np
from gensim.similarity_learning import WikiQAExtractor
# Read the WikiQA training split through the extractor and pull out its data.
train_path = os.path.join("..", "data", "WikiQACorpus", "WikiQA-train.tsv")
wikiqa = WikiQAExtractor(train_path)
data = wikiqa.get_data()

# w2v.pkl holds a pickled dict that maps each word to its vector (a list of
# floats). It was built once from the GloVe text file along these lines
# (pass an explicit encoding: the GloVe file contains non-ASCII tokens):
#
#     w2v = {}
#     with open('glove.6B.50d.txt', encoding='utf-8') as f:
#         for line in f:
#             word, *values = line.split()
#             w2v[word] = [float(v) for v in values]
#     with open('w2v.pkl', 'wb') as f:
#         pickle.dump(w2v, f)
#
# NOTE: unpickling can execute arbitrary code, so only load a w2v.pkl you
# generated yourself.
with open('w2v.pkl', 'rb') as f:
    w2v = pickle.load(f)
def sent2vec(sentence, vectors=None):
    """Return the mean word vector of the in-vocabulary tokens of *sentence*.

    Tokens come from ``sentence.split()`` (whitespace split, case preserved);
    tokens missing from the vector table are skipped.

    Args:
        sentence: Text to embed.
        vectors: Mapping from word to vector. Defaults to the module-level
            ``w2v`` table.

    Returns:
        A 1-D float NumPy array. If no token is in the vocabulary, a zero vector
        of the embedding dimension is returned; averaging an empty array would
        instead give NaN plus a RuntimeWarning.
    """
    if vectors is None:
        vectors = w2v
    known = [vectors[word] for word in sentence.split() if word in vectors]
    if known:
        return np.mean(np.array(known, dtype=float), axis=0)
    # Infer the dimension from any stored vector (0 if the table is empty).
    dim = len(next(iter(vectors.values()), ()))
    return np.zeros(dim)
def cos_sim(vec1, vec2):
    """Return the cosine similarity between two 1-D vectors.

    Args:
        vec1: Array-like of numbers.
        vec2: Array-like of numbers, same length as ``vec1``.

    Returns:
        ``sum(vec1 * vec2) / (|vec1| * |vec2|)``, or ``0.0`` when that value is
        undefined, i.e. either vector is all zeros or contains NaN/inf.
    """
    vec1 = np.asarray(vec1, dtype=float)
    vec2 = np.asarray(vec2, dtype=float)
    denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    # An undefined score would otherwise be NaN, and np.argmax ranks NaN above
    # every real number when choosing the best-scoring candidate.
    if denom == 0 or not np.isfinite(denom):
        return 0.0
    return np.sum(vec1 * vec2) / denom
# Score each candidate answer of each question by the cosine similarity of the
# averaged word vectors of question and answer, then report precision@1: the
# fraction of answerable questions whose top-scored candidate is correct.
#
# NOTE(review): ``data`` is iterated as a sequence of questions, each a sequence
# of (query, answer, label) triples; labels are assumed to be 0/1 (possibly as
# strings) -- confirm against WikiQAExtractor.get_data().
big_y_true = []
big_y_pred = []
for question in data:
    y_true = []
    y_pred = []
    # Use distinct names so the outer loop variable is not rebound here.
    for query, answer, label in question:
        y_pred.append(cos_sim(sent2vec(query), sent2vec(answer)))
        y_true.append(int(label))
    big_y_true.append(y_true)
    big_y_pred.append(y_pred)

# Keep these as ragged Python lists: questions have different numbers of
# candidates, and np.array() on ragged nested lists raises ValueError on
# NumPy >= 1.24.
n_correct = 0
n_answerable = 0
for y_pred, y_true in zip(big_y_pred, big_y_true):
    # A question with no correct candidate (or no candidates at all) has no
    # right top-1 answer, so it is left out of the metric.
    if not any(y_true):
        continue
    n_answerable += 1
    scores = np.asarray(y_pred, dtype=float)
    # np.argmax would pick a NaN score first; make undefined scores lose.
    scores = np.where(np.isnan(scores), -np.inf, scores)
    # A hit when the top-scored candidate is ANY correct answer; comparing with
    # np.argmax(y_true) credited only the first correct candidate.
    if y_true[int(np.argmax(scores))]:
        n_correct += 1

if n_answerable:
    print("Accuracy : ", n_correct / n_answerable)
else:
    print("Accuracy : ", "undefined (no question has a correct answer)")
# Previously reported "Accuracy : 0.2867264997638167" came from the older metric
# (argmax(y_true) == argmax(y_pred) over all questions) and is not comparable.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment