Word2vec accuracy
def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, lowercase=True):
    """
    Compute accuracy of the model on a word-analogy question file.

    `questions` is a filename whose lines are 4-tuples of whitespace-separated
    words ``a b c expected``, split into sections by ": SECTION NAME" lines.
    See https://code.google.com/p/word2vec/source/browse/trunk/questions-words.txt
    for an example. A question counts as correct when the best-scoring allowed
    vocabulary word for ``positive=[b, c], negative=[a]`` equals ``expected``.

    The accuracy is reported (=printed to log and returned as a list) for each
    section separately, plus there's one aggregate summary at the end.

    Use `restrict_vocab` to ignore all questions containing a word that is not
    among the top-N most frequent words (default top 30,000). Predictions are
    restricted to those same words.

    `most_similar` is called as ``most_similar(self, positive=[b, c],
    negative=[a], topn=False)`` and its result is argsorted. It presumably
    returns one similarity score per vocabulary index (TODO confirm).

    If `lowercase` is true, the question words are lowercased before being
    looked up in the vocabulary. The predicted and expected words are always
    compared case-insensitively.

    Lines that do not contain exactly four words are logged and skipped, as are
    questions containing a word outside the restricted vocabulary.

    Returns a list of dicts with keys 'section', 'correct' and 'incorrect'. The
    latter two hold lists of ``(a, b, c, expected)`` tuples. The final entry is
    the aggregate over all sections, with section name 'total'.

    Raises ValueError if a question line appears before the first section
    header.

    This method corresponds to the `compute-accuracy` script of the original C word2vec.
    """
    # Keep only the `restrict_vocab` most frequent words (ranked by `.count`).
    ok_vocab = dict(sorted(iteritems(self.vocab),
                           key=lambda item: -item[1].count)[:restrict_vocab])
    ok_index = {v.index for v in itervalues(ok_vocab)}
    sections, section = [], None
    with utils.smart_open(questions) as fin:
        for line_no, line in enumerate(fin):
            # TODO: use level3 BLAS (=evaluate multiple questions at once), for speed
            line = utils.to_unicode(line)
            if line.startswith(': '):
                # a new section starts => store the old section
                if section is not None:
                    sections.append(section)
                    self.log_accuracy(section)
                section = {'section': line.lstrip(': ').strip(), 'correct': [], 'incorrect': []}
                continue
            if section is None:
                raise ValueError("missing section header before line #%i in %s" % (line_no, questions))
            try:
                a, b, c, expected = line.split()
            except ValueError:
                # Not exactly four words (this includes blank lines). Skip the line:
                # falling through would reuse a/b/c/expected from the previous
                # question (double-counting it) or fail on the very first bad line.
                logger.info("skipping invalid line #%i in %s", line_no, questions)
                continue
            if lowercase:
                a, b, c, expected = a.lower(), b.lower(), c.lower(), expected.lower()
            if a not in ok_vocab or b not in ok_vocab or c not in ok_vocab or expected not in ok_vocab:
                logger.debug("skipping line #%i with OOV words: %s", line_no, line.strip())
                continue
            ignore = {a.lower(), b.lower(), c.lower()}  # input words never count as predictions
            expected = expected.lower()
            correct = False
            # find the most likely prediction, ignoring OOV words and input words
            for index in argsort(most_similar(self, positive=[b, c], negative=[a], topn=False))[::-1]:
                predicted = self.index2word[index].lower()
                if index in ok_index and predicted not in ignore:
                    if predicted != expected:
                        logger.debug("%s: expected %s, predicted %s", line.strip(), expected, predicted)
                    else:
                        correct = True
                    break  # only the top allowed prediction is judged
            if correct:
                section['correct'].append((a, b, c, expected))
            else:
                section['incorrect'].append((a, b, c, expected))
    if section is not None:
        # store the last section, too
        sections.append(section)
        self.log_accuracy(section)
    # Flatten per-section results (avoids quadratic sum(lists, [])).
    total = {
        'section': 'total',
        'correct': [question for s in sections for question in s['correct']],
        'incorrect': [question for s in sections for question in s['incorrect']],
    }
    self.log_accuracy(total)
    sections.append(total)
    return sections
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment