Created January 2, 2015 07:57
-
-
Save anonymous/06e060cdf53257fde1ef to your computer and use it in GitHub Desktop.
Word2vec accuracy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, lowercase=True):
    """
    Compute accuracy of the model. `questions` is a filename where lines are
    4-tuples of words, split into sections by ": SECTION NAME" lines.
    See https://code.google.com/p/word2vec/source/browse/trunk/questions-words.txt for an example.

    The accuracy is reported (=printed to log and returned as a list) for each
    section separately, plus there's one aggregate summary at the end.

    Use `restrict_vocab` to ignore all questions containing a word whose frequency
    is not in the top-N most frequent words (default top 30,000).

    `most_similar` ranks the candidate answers. It is called as
    ``most_similar(self, positive=[b, c], negative=[a], topn=False)``, and the
    result is argsorted; the resulting indices are looked up in `self.index2word`.

    If `lowercase` is true, question words are lowercased before the vocabulary
    lookup. The predicted word and the expected answer are always compared
    lowercased.

    Lines that do not split into exactly four words are logged and skipped, as
    are questions containing any word outside the restricted vocabulary.

    Returns a list of dicts with keys 'section', 'correct' and 'incorrect'
    (the latter two are lists of (a, b, c, expected) tuples); the last entry is
    the 'total' summary over all sections.

    Raises ValueError if a question line appears before any section header.

    This method corresponds to the `compute-accuracy` script of the original C word2vec.
    """
    # keep only the `restrict_vocab` most frequent words
    ok_vocab = dict(sorted(iteritems(self.vocab),
                           key=lambda item: -item[1].count)[:restrict_vocab])
    ok_index = {v.index for v in itervalues(ok_vocab)}

    sections, section = [], None
    with utils.smart_open(questions) as fin:
        for line_no, line in enumerate(fin):
            # TODO: use level3 BLAS (=evaluate multiple questions at once), for speed
            line = utils.to_unicode(line)
            if line.startswith(': '):
                # a new section starts => store the old section
                if section:
                    sections.append(section)
                    self.log_accuracy(section)
                section = {'section': line.lstrip(': ').strip(), 'correct': [], 'incorrect': []}
                continue

            if not section:
                raise ValueError("missing section header before line #%i in %s" % (line_no, questions))
            try:
                if lowercase:
                    a, b, c, expected = [word.lower() for word in line.split()]
                else:
                    a, b, c, expected = line.split()
            except ValueError:
                # wrong number of words on the line; skipping is required here,
                # otherwise a/b/c/expected would be unbound or left over from the
                # previous question
                logger.info("skipping invalid line #%i in %s", line_no, questions)
                continue
            if a not in ok_vocab or b not in ok_vocab or c not in ok_vocab or expected not in ok_vocab:
                logger.debug("skipping line #%i with OOV words: %s", line_no, line.strip())
                continue

            ignore = {a.lower(), b.lower(), c.lower()}  # input words are never accepted as answers
            expected = expected.lower()
            correct = False
            # walk candidates from most to least similar; the first one that is in the
            # restricted vocabulary and is not an input word is the model's prediction
            for index in argsort(most_similar(self, positive=[b, c], negative=[a], topn=False))[::-1]:
                predicted = self.index2word[index].lower()
                if index in ok_index and predicted not in ignore:
                    correct = predicted == expected
                    if not correct:
                        logger.debug("%s: expected %s, predicted %s", line.strip(), expected, predicted)
                    break
            section['correct' if correct else 'incorrect'].append((a, b, c, expected))

    if section:
        # store the last section, too
        sections.append(section)
        self.log_accuracy(section)

    total = {
        'section': 'total',
        'correct': sum((s['correct'] for s in sections), []),
        'incorrect': sum((s['incorrect'] for s in sections), []),
    }
    self.log_accuracy(total)
    sections.append(total)
    return sections
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment