unknown_words.pynb
Gist: vierja/bcc5bc20de82a584c72d
Created April 27, 2015 20:38
{
"metadata": {
"name": "",
"signature": "sha256:f17bf110a15335ea4319a4b032f6171c9a8aee0cfdc22972143848ff97881f13"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Unknown words\n",
"Let's try to predict word vectors for unknown words. We need this for parsing, and for other sentence-level models.\n",
"\n",
"I'm going to try to re-use the training function for prediction.\n",
"\n",
"It would be especially cool if I could use the Google News vectors for this somehow, maybe by reconstructing the training matrix."
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"import gensim\n",
"from numpy import dot, argsort, exp, sum as np_sum"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 129
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from nltk.corpus import brown as brown_raw"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 53
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"brown = [filter(lambda a: a not in ['.',',','``',\"''\",'--'], x) for x in brown_raw.sents()]\n",
"with open(\"wsj_corpus\") as fileobject:\n",
"    wsj = map(str.split, fileobject)\n",
"corpus = brown + wsj"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 54
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"model = gensim.models.word2vec.Word2Vec(size=300, min_count=1, sg=0)\n",
"model.build_vocab(corpus)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 238
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# We can initialize with the Google vectors!\n",
"# model_goo = gensim.models.Word2Vec.load_word2vec_format('google_from_wsj')  # reduced vector set\n",
"# for word in model_goo.vocab:\n",
"#     if word in model:\n",
"#         model.syn0[model.vocab[word].index] = model_goo[word]"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 239
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"model.train(corpus)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 240,
"text": [
"1798132"
]
}
],
"prompt_number": 240
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That works! Amazing.\n",
"\n",
"Now, this function is not included in gensim, but it's really handy:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def nearest(self, mean, topn=1):\n",
"    \"\"\"\n",
"    Find the top-N words nearest to the given vector.\n",
"    \"\"\"\n",
"    self.init_sims()\n",
"\n",
"    # cosine similarity against every word vector (syn0norm rows are unit length)\n",
"    dists = dot(self.syn0norm, mean)\n",
"    if not topn:\n",
"        return dists\n",
"    best = argsort(dists)[::-1][:topn]\n",
"    return [(self.index2word[sim], float(dists[sim])) for sim in best]"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 29
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training\n",
"OK, so `syn0` holds the word vectors, and `syn1` holds the weight vector for each inner node of the hierarchical-softmax tree."
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"sentence = wsj[9481]\n",
"words = sentence[0:4] + sentence[6:10]  # context words around the held-out span\n",
"\n",
"word_indices = [model.vocab[w].index for w in words]\n",
"l1 = np_sum(model.syn0[word_indices], axis=0)  # summed context vector\n",
"\n",
"best = {'': float('inf')}  # word -> error; sentinel keeps max() defined\n",
"for word in model.vocab:\n",
"    l2a = model.syn1[model.vocab[word].point]  # 2d matrix, codelen x layer1_size\n",
"    fa = 1. / (1. + exp(-dot(l1, l2a.T)))  # sigmoid at each inner node\n",
"    ssd = sum((1 - model.vocab[word].code - fa)**2)  # squared error against the Huffman code\n",
"    if ssd < max(best.values()):\n",
"        best[word] = ssd\n",
"        best = dict(sorted(best.items(), key=lambda t: t[1])[:10])  # keep the best 10\n",
"\n",
"print ' '.join(sentence),\n",
"sorted(best.items(), key=lambda t: t[1])"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"That 's because municipal-bond interest is exempt from federal income tax and from state and local taxes too for in-state investors"
]
},
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 260,
"text": [
"[(\"'s\", 0.20982208563704319),\n",
" ('%', 0.58884351599817819),\n",
" ('Panama', 1.0371112292142612),\n",
" ('net', 1.0403089068371625),\n",
" ('firm', 1.1973944275741957),\n",
" ('is', 1.2186806353929569),\n",
" ('was', 1.2446461367881057),\n",
" ('earnings', 1.2653812361146635),\n",
" ('Gonzalez', 1.3151883398624165),\n",
" ('distribution', 1.3285012866346193)]"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\n"
]
}
],
"prompt_number": 260
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": []
}
],
"metadata": {}
}
]
}
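The scoring loop in the last cell can be distilled into a standalone sketch. The snippet below reimplements just the per-candidate error computation against toy numpy arrays: `points` and `code` stand in for gensim's `vocab[word].point` (indices of the word's inner tree nodes) and `vocab[word].code` (its Huffman code), and all shapes and values here are invented for illustration, not taken from a real model.

```python
import numpy as np

def hs_error(l1, points, code, syn1):
    """Sum of squared errors between a word's Huffman-code targets and the
    sigmoid activations at its inner tree nodes (lower = better fit)."""
    l2a = syn1[points]                        # codelen x layer1_size
    fa = 1.0 / (1.0 + np.exp(-l2a.dot(l1)))  # sigmoid at each inner node
    return float(np.sum((1.0 - code - fa) ** 2))

rng = np.random.default_rng(0)
syn1 = rng.normal(size=(8, 5))  # 8 hypothetical inner nodes, 5-dim layer
l1 = rng.normal(size=5)         # summed context vector

# two hypothetical vocabulary entries, each with its tree path and code
score_a = hs_error(l1, np.array([0, 3]), np.array([1, 0]), syn1)
score_b = hs_error(l1, np.array([1, 2, 5]), np.array([0, 1, 1]), syn1)
```

Ranking every vocabulary word by this error is what the notebook's loop does; the word whose Huffman-code path best matches the context's activations is taken as the prediction.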