Predictive text with concatenated word vectors. Code examples released under CC0 https://creativecommons.org/choose/zero/, other text released under CC BY 4.0 https://creativecommons.org/licenses/by/4.0/
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predictive text with concatenated word vectors\n",
"\n",
"By [Allison Parrish](http://www.decontextualize.com/)\n",
"\n",
"This notebook demonstrates one way to implement predictive text like [iOS QuickType](https://www.apple.com/sg/ios/whats-new/quicktype/). It works sort of like a [Markov chain text generator](https://github.com/aparrish/rwet/blob/master/ngrams-and-markov-chains.ipynb), but uses nearest-neighbor lookups on a database of concatenated [word vectors](https://github.com/aparrish/rwet/blob/master/understanding-word-vectors.ipynb) instead of n-grams of tokens. You can build this database with any text you want!\n",
"\n",
"To get this code to work, you'll need to [install spaCy](https://spacy.io/usage/#section-quickstart), download a [spaCy model with word vectors](https://spacy.io/usage/models#available) (like `en_core_web_lg`). You'll also need [Simple Neighbors](https://github.com/aparrish/simpleneighbors), a Python library I made for easy nearest neighbor lookups:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"!pip install simpleneighbors"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How it works\n",
"\n",
"The goal of a predictive text interface is to look at what the user has typed so far and then suggest the word that is most likely to come next. The system in this notebook does this by looking at each sequence of words of a particular length `n`, and then looking up the word vector in spaCy for each of those words, concatenating them to create one long vector. It then stores that vector along with the word that *follows* the sequence.\n",
"\n",
"To calculate suggestions for a particular text from this database, you can just look at the last `n` words in the text, concatenate the word vectors for that stretch of words, and then find the entries in the database whose vector is nearest. The words stored along with those sequences (i.e., the words that followed the original sequence) are the words the system suggests as most likely to come next.\n",
"\n",
"So let's implement it! First, we'll import the libraries we need:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from simpleneighbors import SimpleNeighbors"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import spacy"
]
},
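{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building the real index, here's the core idea in miniature, using made-up two-dimensional \"word vectors\" and a brute-force nearest-neighbor search. (The vectors and sentences here are invented purely for illustration; the real version below uses spaCy's 300-dimensional vectors and Simple Neighbors for fast approximate lookups.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# toy 2D \"word vectors\", invented for illustration\n",
"toy_vectors = {'i': (0.9, 0.1), 'you': (0.8, 0.2),\n",
"               'like': (0.1, 0.9), 'love': (0.15, 0.85),\n",
"               'cats': (0.5, 0.5), 'dogs': (0.55, 0.45)}\n",
"# index: (concatenated vectors of a two-word sequence) -> following word\n",
"db = []\n",
"for sent in [['i', 'like', 'cats'], ['you', 'love', 'dogs']]:\n",
"    key = toy_vectors[sent[0]] + toy_vectors[sent[1]]  # tuple concatenation\n",
"    db.append((key, sent[2]))\n",
"def predict(w1, w2):\n",
"    # concatenate the query words' vectors, then return the word stored\n",
"    # with the nearest key (brute-force squared Euclidean distance)\n",
"    query = toy_vectors[w1] + toy_vectors[w2]\n",
"    dist = lambda a: sum((x - y) ** 2 for x, y in zip(a, query))\n",
"    return min(db, key=lambda entry: dist(entry[0]))[1]\n",
"predict('i', 'love')  # 'love' is nearest 'like', so this suggests 'cats'"
]
},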
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And then load the spaCy model. (This will take a few seconds.)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"nlp = spacy.load('en_core_web_lg')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You'll need to have some text to use to build the database. If you're following along, download a [plain text file from Project Gutenberg](https://www.gutenberg.org/) to the same directory as this notebook and put its filename below."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"filename = \"1342-0.txt\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When you're parsing a text with spaCy, it can use up a lot of memory and either throw \"out of memory\" errors or cause your computer to slow down as it swaps memory to disk. To ameliorate this, we're only going to train on the first 500k characters of the text. You can change the number in the cell below if you want even fewer characters (or more)."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"cutoff = 500000"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code in the cell below parses your text file into sentences (this might take a few seconds):"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"doc = nlp(open(filename).read()[:cutoff], \n",
" disable=['tagger'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `concatenate_vectors` function below takes a sequence of spaCy tokens (like those that you get when you parse a text) and returns the *concatenated* word vectors of those tokens. \"Concatenating\" vectors means to make one big vector from several smaller vectors simply by lining them all up. For example, if you had three 2D vectors `a`, `b`, and `c`:\n",
"\n",
" a = (1, 2)\n",
" b = (5, 6)\n",
" c = (11, 12)\n",
" \n",
"The concatenation of these vectors would be this six-dimensional vector:\n",
"\n",
" (1, 2, 5, 6, 11, 12)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(600,)"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"def concatenate_vectors(seq):\n",
" return np.concatenate(np.array([w.vector for w in seq]), axis=0)\n",
"concatenate_vectors(nlp(\"hello there\")).shape"
]
},
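{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check: because concatenation preserves order, reversing a sequence of words produces a different vector, so (as long as the words' individual vectors differ) the index can distinguish \"man bites dog\" from \"dog bites man\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.array_equal(concatenate_vectors(nlp(\"man bites dog\")),\n",
"               concatenate_vectors(nlp(\"dog bites man\")))"
]
},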
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using vectors instead of tokens is a simple way of coping with predicting the next word even for sequences that aren't found in the source text. Using concatenated vectors facilitates finding entries that have both similar meanings and similar word orders (which is important when predicting the next word in a text).\n",
"\n",
"The code in the cell below builds the nearest neighbor index that maps the concatenated vectors for each sequence of words in the source text to the word that follows. You can adjust `n` to change the length of the sequence considered. (In my experiments, values from 2–4 usually work best.)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"n = 3\n",
"nns = SimpleNeighbors(n*300)\n",
"for seq in doc.sents:\n",
" seq = [item for item in seq if item.is_alpha]\n",
" for i in range(len(seq)-n):\n",
" mean = concatenate_vectors(seq[i:i+n])\n",
" next_item = seq[i+n].text\n",
" nns.add_one(next_item, mean)\n",
"nns.build()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the index is built, you can test it out! Plug in a phrase with three words into the `start` variable below and run the cell. You'll see the top-ten most likely words to come next, as suggested by the nearest neighbor lookup."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['acknowledged',\n",
" 'liked',\n",
" 'been',\n",
" 'seen',\n",
" 'heard',\n",
" 'desired',\n",
" 'known',\n",
" 'read',\n",
" 'met',\n",
" 'seen',\n",
" 'supposed',\n",
" 'observed']"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"start = \"I have never\"\n",
"nns.nearest(concatenate_vectors(nlp(start)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interactive web version\n",
"\n",
"The code below starts a [Flask](http://flask.pocoo.org/) web server on your computer to serve up an interactive version of the suggestion code. Run the cell and click on the link that appears below. If you make changes, make sure to interrupt the kernel before re-running the cell. You can interrupt the kernel either via the menu bar (`Kernel > Interrupt`) or by hitting Escape and typing `i` twice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"autocomplete_html = \"\"\"\n",
"<style type=\"text/css\">\n",
" * {\n",
" box-sizing: border-box;\n",
" font-family: sans-serif;\n",
" }\n",
" #suggestions {\n",
" width: 33%;\n",
" }\n",
" #suggestions p {\n",
" margin: 0;\n",
" width: 20%;\n",
" background-color: #aaa;\n",
" color: white;\n",
" float: left;\n",
" padding: 5px;\n",
" border: 1px white solid;\n",
" text-align: center;\n",
" font-size: 16px;\n",
" }\n",
"</style>\n",
"<textarea id=\"typehere\" placeholder=\"type here!\"\n",
" style=\"width: 33%;\n",
" padding: 0.5em;\n",
" font-family: sans-serif;\n",
" font-size: 16px;\"\n",
" rows=\"16\"></textarea>\n",
"<div id=\"suggestions\">\n",
"(suggestions will appear here)\n",
"</div>\n",
"<script>\n",
"function createChoice(val) {\n",
" let tn = document.createTextNode(val);\n",
" let ptag = document.createElement('p');\n",
" ptag.appendChild(tn);\n",
" ptag.onclick = function() {\n",
" addText(\" \" + val);\n",
" }\n",
" return ptag;\n",
"}\n",
"function addText(newText) {\n",
" let el = document.querySelector('#typehere');\n",
" var start = el.selectionStart\n",
" var end = el.selectionEnd\n",
" var text = el.value\n",
" var before = text.substring(0, start)\n",
" var after = text.substring(end, text.length) \n",
" el.value = (before + newText + after)\n",
" el.selectionStart = el.selectionEnd = start + newText.length\n",
" el.focus()\n",
" el.onkeyup()\n",
"}\n",
"document.querySelector('#typehere').onkeyup = async function() {\n",
" console.log(\"hi\");\n",
" let el = document.querySelector('#typehere');\n",
" let val = el.value;\n",
"\n",
" var start = el.selectionStart\n",
" var end = el.selectionEnd\n",
" var text = el.value\n",
" var before = text.substring(0, start)\n",
" \n",
" let resp = await getResp(before);\n",
" console.log(resp);\n",
" let suggestdiv = document.getElementById(\"suggestions\");\n",
" suggestdiv.innerHTML = \"\";\n",
" for (let s of resp) {\n",
" suggestdiv.appendChild(createChoice(s))\n",
" }\n",
"};\n",
"async function getResp(val) {\n",
" let resp = await fetch(\"/suggest.json?text=\" + \n",
" encodeURIComponent(val));\n",
" let data = await resp.json();\n",
" return data['suggestions'];\n",
"}\n",
"</script>\n",
"\"\"\"\n",
"from flask import Flask, request, jsonify\n",
"app = Flask(__name__)\n",
"@app.route(\"/suggest.json\")\n",
"def suggest():\n",
" text = request.args['text']\n",
" parsed = list(nlp(text, disable=['tagger', 'parser']))\n",
" if len(parsed) >= n:\n",
" suggestions = nns.nearest(concatenate_vectors(parsed[-n:]), 5)\n",
" else:\n",
" suggestions = []\n",
" return jsonify(\n",
" {'suggestions': suggestions})\n",
"@app.route(\"/\")\n",
"def home():\n",
" return autocomplete_html\n",
"app.run()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}