@MichaelSnowden
Last active September 4, 2018 02:21
An autocorrect tool that maps words to vectors so that the Euclidean distance between two vectors approximates the Levenshtein distance between the corresponding words
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Neural Autocorrect"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"\n",
"Train a dense model that minimizes the squared error between the Euclidean distance of two word vectors and the Levenshtein distance of the corresponding words. Euclidean distance has the advantage that the vectors can be indexed for fast nearest-neighbor retrieval. If the index contains only correctly spelled words, the nearest neighbors of a misspelled word can serve as autocorrect suggestions.\n",
"\n",
"The model is structured like this:\n",
"\n",
"```\n",
"embedding1 = FFNN(concatenated_char_embeddings1)\n",
"embedding2 = FFNN(concatenated_char_embeddings2)\n",
"euclidean_distance = l2_distance(embedding1, embedding2)\n",
"loss = average_squared_error(euclidean_distance, levenshtein_distance)\n",
"```\n",
"\n",
"Words under the character limit are zero-padded on the right, and the index 0 is treated as its own character, with its own embedding."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving epoch: 1, loss = 3.4607551097869873\n",
"saving epoch: 10, loss = 2.0638227462768555\n",
"saving epoch: 20, loss = 1.3220642805099487\n",
"saving epoch: 30, loss = 1.1543822288513184\n",
"saving epoch: 40, loss = 1.0564409494400024\n",
"saving epoch: 50, loss = 0.999796450138092\n",
"saving epoch: 60, loss = 1.011827826499939\n",
"saving epoch: 70, loss = 0.9703212976455688\n",
"saving epoch: 80, loss = 0.9654067158699036\n",
"saving epoch: 90, loss = 0.9628726243972778\n",
"saving epoch: 100, loss = 0.9222937822341919\n",
"saving epoch: 110, loss = 0.8789178133010864\n",
"saving epoch: 120, loss = 0.9058818817138672\n",
"saving epoch: 130, loss = 0.8917433023452759\n",
"saving epoch: 140, loss = 0.911494255065918\n",
"saving epoch: 150, loss = 0.9118928909301758\n",
"saving epoch: 160, loss = 0.8912183046340942\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-66829473bd9d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword1\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mword2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 81\u001b[0;31m \u001b[0mtrue_distances_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlevenshtein\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mword2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 82\u001b[0m \u001b[0mword1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0mword2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-9-66829473bd9d>\u001b[0m in \u001b[0;36mlevenshtein\u001b[0;34m(s, t)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mcost\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mv1\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv1\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv0\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv0\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcost\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mv0\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mv1\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import traceback\n",
"import os.path\n",
"\n",
"tf.reset_default_graph()\n",
"\n",
"n_chars = 128\n",
"max_length = 32\n",
"char_embedding_dim = n_chars+1\n",
"word_embedding_dim = 512\n",
"hidden_layer_size = word_embedding_dim\n",
"\n",
"words1 = tf.placeholder(dtype=tf.int32, shape=[None, max_length])\n",
"words2 = tf.placeholder(dtype=tf.int32, shape=[None, max_length])\n",
"true_distances = tf.placeholder(dtype=tf.float32, shape=[None])\n",
"char_embeddings = tf.Variable(tf.random_normal(shape=[n_chars+1, char_embedding_dim]))\n",
"layer1 = tf.layers.Dense(hidden_layer_size, activation=tf.nn.relu)\n",
"layer2 = tf.layers.Dense(word_embedding_dim, activation=None)\n",
"word_embeddings1 = layer2(layer1(tf.reshape(\n",
"    tf.nn.embedding_lookup(char_embeddings, words1), [-1, max_length * char_embedding_dim])))\n",
"word_embeddings2 = layer2(layer1(tf.reshape(\n",
"    tf.nn.embedding_lookup(char_embeddings, words2), [-1, max_length * char_embedding_dim])))\n",
"pred_distances = tf.sqrt(tf.reduce_sum(tf.square(word_embeddings1 - word_embeddings2), axis=1))\n",
"error = pred_distances - true_distances\n",
"avg_squared_error = tf.reduce_mean(tf.square(error))\n",
"# Average absolute error between the levenshtein distance and the euclidean distance \n",
"# If this is 1.0, for example, that means that the true levenshtein distance is, on average,\n",
"# 1 unit away from our predicted distance\n",
"avg_absolute_error = tf.reduce_mean(tf.abs(error))\n",
"optimize = tf.train.AdamOptimizer().minimize(avg_squared_error)\n",
"\n",
"# https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance#Python\n",
"def levenshtein(s, t):\n",
"    ''' From Wikipedia article; Iterative with two matrix rows. '''\n",
"    if len(s) == len(t) and np.all(s == t): return 0\n",
"    elif len(s) == 0: return len(t)\n",
"    elif len(t) == 0: return len(s)\n",
"    v0 = [None] * (len(t) + 1)\n",
"    v1 = [None] * (len(t) + 1)\n",
"    for i in range(len(v0)):\n",
"        v0[i] = i\n",
"    for i in range(len(s)):\n",
"        v1[0] = i + 1\n",
"        for j in range(len(t)):\n",
"            cost = 0 if s[i] == t[j] else 1\n",
"            v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost)\n",
"        for j in range(len(v0)):\n",
"            v0[j] = v1[j]\n",
"\n",
"    return v1[len(t)]\n",
"\n",
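"def random_word():\n",
"    # Sample a random-length word of character ids in 1..n_chars (0 is reserved for padding).\n",
"    # Note: replace=False samples without replacement, so no id repeats within a word.\n",
"    length = np.random.randint(1, max_length+1)\n",
"    return np.random.choice(n_chars, length, replace=False) + 1\n",
"\n",
"def pad(word):\n",
"    # Right-pad with zeros up to max_length.\n",
"    return np.concatenate((word, np.zeros((max_length - len(word)))))\n",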
"\n",
"batch_size = 2048\n",
"\n",
"saver = tf.train.Saver()\n",
"with tf.Session() as sess:\n",
"    sess.run(tf.global_variables_initializer())\n",
"    if os.path.isfile('/tmp/model.ckpt.meta'):\n",
"        try:\n",
"            # Restore into the existing graph; import_meta_graph would add a duplicate copy of it.\n",
"            saver.restore(sess, \"/tmp/model.ckpt\")\n",
"        except Exception:\n",
"            traceback.print_exc()\n",
"            print('error while restoring')\n",
"    e = 0\n",
"    while True:\n",
"        words1_ = []\n",
"        words2_ = []\n",
"        true_distances_ = []\n",
"        for b in range(batch_size):\n",
"            word1 = random_word()\n",
"            word2 = random_word()\n",
"            if len(word1) == len(word2) and np.all(word1 == word2):\n",
"                continue\n",
"            true_distances_.append(levenshtein(word1, word2))\n",
"            word1 = pad(word1)\n",
"            word2 = pad(word2)\n",
"            words1_.append(word1)\n",
"            words2_.append(word2)\n",
"        feed = {\n",
"            words1: words1_,\n",
"            words2: words2_,\n",
"            true_distances: true_distances_\n",
"        }\n",
"        avg_absolute_error_, _ = sess.run([avg_absolute_error, optimize], feed)\n",
"        if e == 0 or (e+1) % 10 == 0:\n",
"            saver.save(sess, \"/tmp/model.ckpt\")\n",
"            print('saving epoch: {}, loss = {}'.format(e+1, avg_absolute_error_))\n",
"        e += 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Indexing\n",
"\n",
"Build a KDTree with the Euclidean metric over the word vectors of every word in `/usr/share/dict/words`."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /tmp/model.ckpt\n",
"index created with 100411 words\n"
]
}
],
"source": [
"from sklearn.neighbors import KDTree\n",
"\n",
"with open('/usr/share/dict/words') as f:\n",
"    words = [line.rstrip('\\n') for line in f.readlines()]\n",
"\n",
"chars = {}\n",
"for w in words:\n",
"    for c in w:\n",
"        if c not in chars:\n",
"            chars[c] = len(chars) + 1\n",
"\n",
"def encode(word):\n",
"    return pad([chars[c] for c in word])\n",
" \n",
"with tf.Session() as sess:\n",
"    sess.run(tf.global_variables_initializer())\n",
"    # Restore the trained weights into the graph defined in the training cell.\n",
"    saver.restore(sess, \"/tmp/model.ckpt\")\n",
"    word_embeddings_ = sess.run(word_embeddings1, {\n",
"        words1: [encode(w) for w in words]\n",
"    })\n",
"    tree = KDTree(word_embeddings_, metric='euclidean')\n",
"    print('index created with {} words'.format(len(word_embeddings_)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Retrieval\n",
"\n",
"Fetch the nearest neighbors of arbitrary input text. Characters outside the indexed character set are dropped, and the input is truncated to 32 characters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /tmp/model.ckpt\n",
"Enter some text: time\n",
"sanitized text = time\n",
"nearest neighbors\n",
"\ttime euclidean = 0.00, levenshtein = 0\n",
"\ttame euclidean = 1.94, levenshtein = 1\n",
"\tdime euclidean = 2.00, levenshtein = 1\n",
"\tlime euclidean = 2.02, levenshtein = 1\n",
"\trime euclidean = 2.03, levenshtein = 1\n",
"\tdim euclidean = 2.07, levenshtein = 2\n",
"\tMimi euclidean = 2.13, levenshtein = 2\n",
"\tvim euclidean = 2.17, levenshtein = 2\n",
"\ttome euclidean = 2.29, levenshtein = 1\n",
"\ttire euclidean = 2.32, levenshtein = 1\n",
"Enter some text: hello\n",
"sanitized text = hello\n",
"nearest neighbors\n",
"\thello euclidean = 0.00, levenshtein = 0\n",
"\tjello euclidean = 1.94, levenshtein = 1\n",
"\tcello euclidean = 2.05, levenshtein = 1\n",
"\thellos euclidean = 2.53, levenshtein = 1\n",
"\thell euclidean = 2.59, levenshtein = 1\n",
"\tGallo euclidean = 2.66, levenshtein = 2\n",
"\tbelle euclidean = 2.80, levenshtein = 2\n",
"\ttell euclidean = 2.85, levenshtein = 2\n",
"\theels euclidean = 2.88, levenshtein = 2\n",
"\ttells euclidean = 2.95, levenshtein = 2\n",
"Enter some text: schwarzenegger\n",
"sanitized text = schwarzenegger\n",
"nearest neighbors\n",
"\tSchwarzenegger euclidean = 3.33, levenshtein = 1\n",
"\tSchwarzenegger's euclidean = 7.59, levenshtein = 3\n",
"\tstewardess's euclidean = 10.04, levenshtein = 9\n",
"\tstatesmanlike euclidean = 10.14, levenshtein = 11\n",
"\tschoolteacher euclidean = 10.28, levenshtein = 8\n",
"\tstewardesses euclidean = 10.30, levenshtein = 8\n",
"\tSchwarzkopf's euclidean = 10.32, levenshtein = 8\n",
"\tschoolchildren euclidean = 10.39, levenshtein = 10\n",
"\tdemonstrative euclidean = 10.54, levenshtein = 13\n",
"\tschmaltziest euclidean = 10.59, levenshtein = 9\n"
]
}
],
"source": [
"with tf.Session() as sess:\n",
"    sess.run(tf.global_variables_initializer())\n",
"    # Restore the trained weights into the graph defined in the training cell.\n",
"    saver.restore(sess, \"/tmp/model.ckpt\")\n",
"    while True:\n",
"        text = input('Enter some text: ')\n",
"        # Keep only characters seen in the dictionary, then truncate to max_length.\n",
"        text = ''.join(c for c in text if c in chars)[:max_length]\n",
"        word = encode(text)\n",
"        embedding = sess.run(word_embeddings1, {\n",
"            words1: [word]\n",
"        })[0]\n",
"        distances, indices = tree.query([embedding], 10)\n",
"        max_width = max(len(words[i]) for i in indices[0])\n",
"        print('Sanitized text = {}'.format(text))\n",
"        print('Nearest neighbors')\n",
"        for d, i in zip(distances[0], indices[0]):\n",
"            lev = levenshtein(text, words[i])\n",
"            print('\\t{} euclidean = {:.2f}, levenshtein = {}'.format(words[i].ljust(max_width+1), d, lev))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
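
As a point of reference for the retrieval step in the notebook above, the sketch below ranks a word list by exact Levenshtein distance with a brute-force scan. This is the search that the KDTree over learned embeddings is meant to approximate, not the notebook's own method. It uses only the standard library; `nearest_by_levenshtein` and the tiny `vocabulary` list are hypothetical names introduced for illustration, whereas the notebook indexes `/usr/share/dict/words`.

```python
import heapq


def levenshtein(s, t):
    """Edit distance between two sequences, keeping two rows of the DP table."""
    prev = list(range(len(t) + 1))
    for i, a in enumerate(s):
        curr = [i + 1]
        for j, b in enumerate(t):
            cost = 0 if a == b else 1
            curr.append(min(curr[j] + 1,       # insertion
                            prev[j + 1] + 1,   # deletion
                            prev[j] + cost))   # substitution or match
        prev = curr
    return prev[-1]


def nearest_by_levenshtein(query, vocabulary, k=3):
    """Return the k closest (distance, word) pairs, found by scanning every word."""
    return heapq.nsmallest(k, ((levenshtein(query, w), w) for w in vocabulary))


# Hypothetical miniature vocabulary, for illustration only.
vocabulary = ["time", "tame", "dime", "lime", "hello", "jello", "cello"]
print(nearest_by_levenshtein("tima", vocabulary))
```

The scan costs one Levenshtein computation per dictionary word for every query, which is the cost the embedding index is intended to avoid. Comparing its ranking with the KDTree output is one way to judge how closely the learned Euclidean distances follow the true edit distances.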