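"""Build TREC-style qrels and run files for the WikiQA dataset.

Reads WikiQA-train.tsv, groups the candidate answer sentences per question
(dropping questions with no correct answer), writes the relevance judgements
to a qrels file, and writes a baseline run file that scores each
(question, sentence) pair by the cosine similarity of averaged GloVe word vectors."""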
import sys
import os
sys.path.append(os.path.join('..'))
import csv
import re
import gensim.downloader as api
from gensim.utils import simple_preprocess
import numpy as np
class MyWikiIterable:
    def __init__(self, fpath):
        # self.type_translator = {'query': 0, 'doc': 1, 'label': 2}
        # self.iter_type = iter_type
        with open(fpath, encoding='utf8') as tsv_file:
            tsv_reader = csv.reader(tsv_file, delimiter='\t', quotechar='"', quoting=csv.QUOTE_NONE)
            # Read the whole .tsv (including the header row) into memory
            self.data_rows = []
            for row in tsv_reader:
                self.data_rows.append(row)
        self.to_print = ""

    def preprocess_sent(self, sent):
        """Utility function to lower, strip and tokenize each sentence.
        Replace this function if you want to handle preprocessing differently."""
        return re.sub("[^a-zA-Z0-9]", " ", sent.strip().lower()).split()
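    # For reference, preprocess_sent("How old is Obama?") returns ['how', 'old', 'is', 'obama'].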
    def get_stuff(self):
        """Group the corpus into per-question lists.

        Consecutive rows sharing a question id form one candidate-document group.
        Groups with no relevant (label == 1) document are filtered out."""
        # Column indices for the WikiQA .tsv format
        QUESTION_ID_INDEX = 0
        QUESTION_INDEX = 1
        ANSWER_ID_INDEX = 4
        ANSWER_INDEX = 5
        LABEL_INDEX = 6

        document_group = []
        label_group = []
        n_relevant_docs = 0
        n_filtered_docs = 0
        query_ids = []
        doc_ids = []
        doc_id_group = []
        queries = []
        docs = []
        labels = []

        # Rows are accessed by index because the grouping logic needs to peek at row i + 1.
        # Row 0 is the header, hence the slice and start=1.
        for i, _row in enumerate(self.data_rows[1:], start=1):
            if i < len(self.data_rows) - 1:  # not the last row, so looking ahead is safe
                if self.data_rows[i][QUESTION_ID_INDEX] == self.data_rows[i + 1][QUESTION_ID_INDEX]:
                    # Same question as the next row: keep accumulating the current group
                    document_group.append(self.preprocess_sent(self.data_rows[i][ANSWER_INDEX]))
                    doc_ids.append(self.data_rows[i][ANSWER_ID_INDEX])
                    label_group.append(int(self.data_rows[i][LABEL_INDEX]))
                    n_relevant_docs += int(self.data_rows[i][LABEL_INDEX])
                else:
                    # Last row for the current question: close the group
                    document_group.append(self.preprocess_sent(self.data_rows[i][ANSWER_INDEX]))
                    doc_ids.append(self.data_rows[i][ANSWER_ID_INDEX])
                    label_group.append(int(self.data_rows[i][LABEL_INDEX]))
                    n_relevant_docs += int(self.data_rows[i][LABEL_INDEX])
                    if n_relevant_docs > 0:
                        docs.append(document_group)
                        labels.append(label_group)
                        queries.append(self.preprocess_sent(self.data_rows[i][QUESTION_INDEX]))
                        query_ids.append(self.data_rows[i][QUESTION_ID_INDEX])
                        doc_id_group.append(doc_ids)
                        # yield [queries[-1], document_group, label_group, query_ids, doc_ids]
                    else:
                        n_filtered_docs += 1
                    n_relevant_docs = 0
                    document_group = []
                    label_group = []
                    doc_ids = []
            else:
                # We are on the last row of the file: close the final group
                document_group.append(self.preprocess_sent(self.data_rows[i][ANSWER_INDEX]))
                label_group.append(int(self.data_rows[i][LABEL_INDEX]))
                doc_ids.append(self.data_rows[i][ANSWER_ID_INDEX])
                n_relevant_docs += int(self.data_rows[i][LABEL_INDEX])
                if n_relevant_docs > 0:
                    # Only record the ids when the group is kept, so all returned lists stay aligned
                    docs.append(document_group)
                    labels.append(label_group)
                    queries.append(self.preprocess_sent(self.data_rows[i][QUESTION_INDEX]))
                    query_ids.append(self.data_rows[i][QUESTION_ID_INDEX])
                    doc_id_group.append(doc_ids)
                    # yield [queries[-1], document_group, label_group, query_ids, doc_ids]
                else:
                    n_filtered_docs += 1
                n_relevant_docs = 0
        return queries, docs, labels, query_ids, doc_id_group
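# get_stuff() returns five parallel lists with one entry per kept question:
#   queries[i]      -> tokenized question
#   doc_group[i]    -> list of tokenized candidate answer sentences
#   label_group[i]  -> list of 0/1 relevance labels, aligned with doc_group[i]
#   query_ids[i]    -> WikiQA question id
#   doc_id_group[i] -> list of WikiQA sentence ids, aligned with doc_group[i]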
wiki_iterable = MyWikiIterable(os.path.join('..', 'experimental_data', 'WikiQACorpus', 'WikiQA-train.tsv'))
queries, doc_group, label_group, query_ids, doc_id_group = wiki_iterable.get_stuff()
print(len(queries))
print(len(doc_group))
print(len(label_group))
print(len(query_ids))
print(len(doc_id_group))
# print(query_ids)
# print(queries)
# print(queries[0], '\n', docs[0],'\n', labels[0],'\n', query_ids[0],'\n', doc_ids[0])
# exit()
# for q, doc, labels, q_id, d_ids in data:
# for d, l, d_id in zip(doc, labels, d_ids):
def print_qrels(fname):
    """Write the relevance judgements in TREC qrels format:
    query_id, iteration (always '0'), doc_id and the 0/1 relevance label, tab separated."""
    with open(fname, 'w') as f:
        for q, doc, labels, q_id, d_ids in zip(queries, doc_group, label_group, query_ids, doc_id_group):
            for d, l, d_id in zip(doc, labels, d_ids):
                # print(q_id + '\t' + '0' + '\t' + str(d_id) + '\t' + str(l) + '\n')
                f.write(q_id + '\t' + '0' + '\t' + str(d_id) + '\t' + str(l) + '\n')
    print("QRELS done")
# 'glove-wiki-gigaword-300' comes back from gensim's downloader API as a KeyedVectors
# instance, so it can be used for vector lookups directly.
kv_model = api.load('glove-wiki-gigaword-300')
dim_size = kv_model.vector_size

def sent2vec(sent):
    """Represent a tokenized sentence as the mean of its in-vocabulary word vectors.
    Falls back to a random vector when the sentence is empty or fully out of vocabulary."""
    if len(sent) == 0:
        print('Sentence is empty, returning a random vector')
        return np.random.random((dim_size,))
    vec = [kv_model[word] for word in sent if word in kv_model]
    if len(vec) == 0:
        print('No words in vocab, returning a random vector')
        return np.random.random((dim_size,))
    return np.mean(np.array(vec), axis=0)

def cosine_similarity(vec1, vec2):
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

def w2v_similarity_fn(q, d):
    """Score a (query, document) pair of token lists by the cosine similarity of their averaged GloVe vectors."""
    return cosine_similarity(sent2vec(q), sent2vec(d))

def print_my_pred(fname, similarity_fn):
    """Write a TREC run file: query_id, 'Q0', doc_id, rank, score, run tag, tab separated.
    The rank column is a constant placeholder ('99'); trec_eval ranks by the score column."""
    i = 0
    with open(fname, 'w') as f:
        for q, doc, labels, q_id, d_ids in zip(queries, doc_group, label_group, query_ids, doc_id_group):
            for d, l, d_id in zip(doc, labels, d_ids):
                my_score = str(similarity_fn(q, d))
                f.write(q_id + '\t' + 'Q0' + '\t' + str(d_id) + '\t' + '99' + '\t' + my_score + '\t' + 'STANDARD' + '\n')
            print(i, "done")
            i += 1

print_qrels('my_test_qrels')
# import cProfile
print_my_pred('my_test_pred_w2v_300', w2v_similarity_fn)
# print(sent2vec(simple_preprocess("asdasdasdasd helloasdasd"), kv_model))

# An alternative scorer using a trained DRMM_TKS model (kept here for reference):
# def print_my_pred(fname, similarity_fn):
#     # del_kv_model = api.load('glove-wiki-gigaword-300')
#     from drmm_tks import DRMM_TKS
#     dtks_model = DRMM_TKS.load('new_hope_dtks_unk_zero_no_normalize_50topk_12ep')
#     i = 0
#     with open(fname, 'w') as f:
#         for q, doc, labels, q_id, d_ids in zip(queries, doc_group, label_group, query_ids, doc_id_group):
#             for d, l, d_id in zip(doc, labels, d_ids):
#                 my_score = dtks_model.predict([q], [[d]])
#                 f.write(q_id + '\t' + 'Q0' + '\t' + str(d_id) + '\t' + '99' + '\t' + str(my_score[0][0]) + '\t' + 'STANDARD' + '\n')
#             print(i, "done")
#             i += 1
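# The two output files can then be scored with the standard trec_eval tool, e.g.
# (assuming trec_eval is installed and on PATH):
#   trec_eval my_test_qrels my_test_pred_w2v_300
# which reports ranking metrics such as MAP for the averaged-GloVe baseline.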