@alshell7 · Created April 5, 2019 07:17
NNLM-Tensor.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "NNLM-Tensor.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/alshell7/a366da48903aec6a2d765b4691902521/nnlm-tensor.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
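{
"metadata": {},
"cell_type": "markdown",
"source": [
"The cell below trains a minimal neural network language model (NNLM) in the spirit of Bengio et al. (2003): the first two words of each sentence are one-hot encoded, concatenated into a single input vector $x$, and mapped to next-word logits as\n",
"\n",
"$$y = \\tanh(d + xH)\\,U + b$$\n",
"\n",
"where $H$ is the input-to-hidden matrix, $U$ the hidden-to-output matrix, and $d$, $b$ the corresponding biases, matching the variable names in the code. (The direct $x \\to y$ connection from the original paper is omitted here.)\n"
]
},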
{
"metadata": {
"id": "4dTc8F-IObVT",
"colab_type": "code",
"outputId": "20bfe275-33af-4655-b2ee-cecdea69cce7",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 125
}
},
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"tf.reset_default_graph()\n",
"\n",
"sentences = [ \"i like cat\", \"i love coffee\", \"i hate milk\"]\n",
"\n",
"word_list = \" \".join(sentences).split()\n",
"word_list = list(set(word_list))\n",
"word_dict = {w: i for i, w in enumerate(word_list)}\n",
"number_dict = {i: w for i, w in enumerate(word_list)}\n",
"n_class = len(word_dict) # number of Vocabulary\n",
"\n",
"# NNLM Parameter\n",
"n_step = 2 # number of steps ['i like', 'i love', 'i hate']\n",
"n_hidden = 2 # number of hidden units\n",
"\n",
"def make_batch(sentences):\n",
" input_batch = []\n",
" target_batch = []\n",
"\n",
" for sen in sentences:\n",
" word = sen.split()\n",
" input = [word_dict[n] for n in word[:-1]]\n",
" target = word_dict[word[-1]]\n",
"\n",
" input_batch.append(np.eye(n_class)[input])\n",
" target_batch.append(np.eye(n_class)[target])\n",
"\n",
" return input_batch, target_batch\n",
"\n",
"# Model\n",
"X = tf.placeholder(tf.float32, [None, n_step, n_class]) # [batch_size, number of steps, number of Vocabulary]\n",
"Y = tf.placeholder(tf.float32, [None, n_class])\n",
"\n",
"input = tf.reshape(X, shape=[-1, n_step * n_class]) # [batch_size, n_step * n_class]\n",
"H = tf.Variable(tf.random_normal([n_step * n_class, n_hidden]))\n",
"d = tf.Variable(tf.random_normal([n_hidden]))\n",
"U = tf.Variable(tf.random_normal([n_hidden, n_class]))\n",
"b = tf.Variable(tf.random_normal([n_class]))\n",
"\n",
"tanh = tf.nn.tanh(d + tf.matmul(input, H)) # [batch_size, n_hidden]\n",
"model = tf.matmul(tanh, U) + b # [batch_size, n_class]\n",
"\n",
"cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=model, labels=Y))\n",
"optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)\n",
"prediction =tf.argmax(model, 1)\n",
"\n",
"# Training\n",
"init = tf.global_variables_initializer()\n",
"sess = tf.Session()\n",
"sess.run(init)\n",
"\n",
"input_batch, target_batch = make_batch(sentences)\n",
"\n",
"for epoch in range(5000):\n",
" _, loss = sess.run([optimizer, cost], feed_dict={X: input_batch, Y: target_batch})\n",
" if (epoch + 1)%1000 == 0:\n",
" print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n",
"\n",
"# Predict\n",
"predict = sess.run([prediction], feed_dict={X: input_batch})\n",
"\n",
"# Test\n",
"input = [sen.split()[:2] for sen in sentences]\n",
"print([sen.split()[:2] for sen in sentences], '->', [number_dict[n] for n in predict[0]])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch: 1000 cost = 0.468996\n",
"Epoch: 2000 cost = 0.125292\n",
"Epoch: 3000 cost = 0.045533\n",
"Epoch: 4000 cost = 0.020888\n",
"Epoch: 5000 cost = 0.010938\n",
"[['i', 'like'], ['i', 'love'], ['i', 'hate']] -> ['dog', 'coffee', 'milk']\n"
],
"name": "stdout"
}
]
}
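,
{
"metadata": {},
"cell_type": "markdown",
"source": [
"Note: the cell above uses the TensorFlow 1.x graph API (`tf.placeholder`, `tf.Session`, `tf.train.AdamOptimizer`). A minimal sketch of the usual workaround, assuming a TensorFlow 2.x runtime rather than the 1.x environment this notebook was written for, is to swap the import for the 1.x compatibility module before running the cell:"
]
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Compatibility shim, assuming a TensorFlow 2.x environment:\n",
"# use this in place of `import tensorflow as tf` in the cell above.\n",
"import tensorflow.compat.v1 as tf\n",
"tf.disable_v2_behavior()  # restore graph mode so placeholders and Sessions work"
],
"execution_count": 0,
"outputs": []
}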
]
}