Skip to content

Instantly share code, notes, and snippets.

Created January 2, 2015 07:57
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save anonymous/06e060cdf53257fde1ef to your computer and use it in GitHub Desktop.
Save anonymous/06e060cdf53257fde1ef to your computer and use it in GitHub Desktop.
Word2vec accuracy
def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, lowercase=True):
    """
    Compute accuracy of the model on a file of analogy questions.

    `questions` is a filename where lines are 4-tuples of words, split into
    sections by ": SECTION NAME" lines.
    See https://code.google.com/p/word2vec/source/browse/trunk/questions-words.txt for an example.

    The accuracy is reported (=printed to log and returned as a list) for each
    section separately, plus there's one aggregate summary at the end.

    Use `restrict_vocab` to ignore all questions containing a word whose frequency
    is not in the top-N most frequent words (default top 30,000).

    `most_similar` is the callable used to score candidates; it is invoked as
    `most_similar(self, positive=[b, c], negative=[a], topn=False)` and its result
    is passed to `argsort` (presumably one similarity score per vocabulary
    index -- confirm against the implementation in use).

    If `lowercase` is true, the four words of each question are lowercased before
    the vocabulary lookup. Predicted and expected words are always compared in
    lowercase, regardless of this flag.

    Returns a list of dicts with keys 'section', 'correct' and 'incorrect' (the
    latter two being lists of `(a, b, c, expected)` tuples), one per section, with
    a final entry named 'total' that concatenates all sections.

    Raises ValueError if a question line appears before any section header.
    Lines that do not consist of exactly four words are logged and skipped.

    This method corresponds to the `compute-accuracy` script of the original C word2vec.
    """
    # keep only the `restrict_vocab` most frequent words
    ok_vocab = dict(sorted(iteritems(self.vocab),
                           key=lambda item: -item[1].count)[:restrict_vocab])
    ok_index = set(v.index for v in itervalues(ok_vocab))

    sections, section = [], None
    # `with` makes sure the questions file is closed, even if we raise below
    with utils.smart_open(questions) as fin:
        for line_no, line in enumerate(fin):
            # TODO: use level3 BLAS (=evaluate multiple questions at once), for speed
            line = utils.to_unicode(line)
            if line.startswith(': '):
                # a new section starts => store the old section
                if section:
                    sections.append(section)
                    self.log_accuracy(section)
                section = {'section': line.lstrip(': ').strip(), 'correct': [], 'incorrect': []}
                continue

            if not section:
                raise ValueError("missing section header before line #%i in %s" % (line_no, questions))

            try:
                a, b, c, expected = line.split()
            except ValueError:
                # not exactly four tokens; skip the line, otherwise the code below
                # would run on undefined (or stale, previous-line) words
                logger.info("skipping invalid line #%i in %s", line_no, questions)
                continue
            if lowercase:
                a, b, c, expected = a.lower(), b.lower(), c.lower(), expected.lower()

            if a not in ok_vocab or b not in ok_vocab or c not in ok_vocab or expected not in ok_vocab:
                logger.debug("skipping line #%i with OOV words: %s", line_no, line.strip())
                continue

            ignore = {a.lower(), b.lower(), c.lower()}  # words to ignore
            correct = False
            expected = expected.lower()
            # find the most likely prediction, ignoring OOV words and input words;
            # only the first admissible candidate (best score first) is judged
            for index in argsort(most_similar(self, positive=[b, c], negative=[a], topn=False))[::-1]:
                predicted = self.index2word[index].lower()
                if index in ok_index and predicted not in ignore:
                    if predicted != expected:
                        logger.debug("%s: expected %s, predicted %s", line.strip(), expected, predicted)
                    else:
                        correct = True
                    break

            if correct:
                section['correct'].append((a, b, c, expected))
            else:
                section['incorrect'].append((a, b, c, expected))

    if section:
        # store the last section, too
        sections.append(section)
        self.log_accuracy(section)

    # aggregate over all sections (flattened in one pass instead of sum(lists, []))
    total = {
        'section': 'total',
        'correct': [question for s in sections for question in s['correct']],
        'incorrect': [question for s in sections for question in s['incorrect']],
    }
    self.log_accuracy(total)
    sections.append(total)
    return sections
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment