Last active
December 11, 2018 02:06
-
-
Save j-min/481749dcb853b4477c4f441bf7452195 to your computer and use it in GitHub Desktop.
TensorFlow 0.9 implementation of BasicRNNCell based on hunkim's tutorial
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### BasicRNNCell\n", | |
"#### TensorFlow 0.9 implementation based on hunkim's tutorial\n", | |
"https://hunkim.github.io/ml/\n", | |
"\n", | |
"https://www.youtube.com/watch?v=A8wJYfDUYCk&feature=youtu.be" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'o': 3, 'l': 2, 'e': 1, 'h': 0}\n" | |
] | |
} | |
], | |
"source": [ | |
"char_rdic = ['h', 'e', 'l', 'o'] # id -> char\n", | |
"char_dic = {w : i for i, w in enumerate(char_rdic)} # char -> id\n", | |
"print (char_dic)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[0, 1, 2, 2, 3]\n" | |
] | |
} | |
], | |
"source": [ | |
"ground_truth = [char_dic[c] for c in 'hello']\n", | |
"print (ground_truth)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"x_data = np.array([[1,0,0,0], # h\n", | |
" [0,1,0,0], # e\n", | |
" [0,0,1,0], # l\n", | |
" [0,0,1,0]], # l\n", | |
" dtype = 'f')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Tensor(\"one_hot:0\", shape=(4, 4), dtype=float32)\n" | |
] | |
} | |
], | |
"source": [ | |
"x_data = tf.one_hot(ground_truth[:-1], len(char_dic), 1.0, 0.0, -1)\n", | |
"print(x_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Configuration\n", | |
"rnn_size = len(char_dic) # 4\n", | |
"batch_size = 1\n", | |
"output_size = 4" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<tensorflow.python.ops.rnn_cell.BasicRNNCell object at 0x7effb759c9e8>\n" | |
] | |
} | |
], | |
"source": [ | |
"# RNN Model\n", | |
"rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units = rnn_size,\n", | |
" input_size = None, # deprecated at tensorflow 0.9\n", | |
" #activation = tanh,\n", | |
" )\n", | |
"print(rnn_cell)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Tensor(\"zeros:0\", shape=(1, 4), dtype=float32)\n" | |
] | |
} | |
], | |
"source": [ | |
"initial_state = rnn_cell.zero_state(batch_size, tf.float32)\n", | |
"print(initial_state)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Tensor(\"zeros_1:0\", shape=(1, 4), dtype=float32)\n" | |
] | |
} | |
], | |
"source": [ | |
"initial_state_1 = tf.zeros([batch_size, rnn_cell.state_size]) # 위 코드와 같은 결과\n", | |
"print(initial_state_1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[<tf.Tensor 'split:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:1' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:2' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:3' shape=(1, 4) dtype=float32>]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"'\\n[[1,0,0,0]] # h\\n[[0,1,0,0]] # e\\n[[0,0,1,0]] # l\\n[[0,0,1,0]] # l\\n'" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x_split = tf.split(0, len(char_dic), x_data) # 가로축으로 4개로 split\n", | |
"print(x_split)\n", | |
"\"\"\"\n", | |
"[[1,0,0,0]] # h\n", | |
"[[0,1,0,0]] # e\n", | |
"[[0,0,1,0]] # l\n", | |
"[[0,0,1,0]] # l\n", | |
"\"\"\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"outputs, state = tf.nn.rnn(cell = rnn_cell, inputs = x_split, initial_state = initial_state)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[<tf.Tensor 'RNN/BasicRNNCell/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_1/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_2/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_3/Tanh:0' shape=(1, 4) dtype=float32>]\n", | |
"Tensor(\"RNN/BasicRNNCell_3/Tanh:0\", shape=(1, 4), dtype=float32)\n" | |
] | |
} | |
], | |
"source": [ | |
"print (outputs)\n", | |
"print (state)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'\\n[[logit from 1st output],\\n[logit from 2nd output],\\n[logit from 3rd output],\\n[logit from 4th output]]\\n'" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"logits = tf.reshape(tf.concat(1, outputs), # shape = 1 x 16\n", | |
" [-1, rnn_size]) # shape = 4 x 4\n", | |
"logits.get_shape()\n", | |
"\"\"\"\n", | |
"[[logit from 1st output],\n", | |
"[logit from 2nd output],\n", | |
"[logit from 3rd output],\n", | |
"[logit from 4th output]]\n", | |
"\"\"\"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"TensorShape([Dimension(4)])" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"targets = tf.reshape(ground_truth[1:], [-1]) # a shape of [-1] flattens into 1-D\n", | |
"targets.get_shape()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"weights = tf.ones([len(char_dic) * batch_size])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"loss = tf.nn.seq2seq.sequence_loss_by_example([logits], [targets], [weights])\n", | |
"cost = tf.reduce_sum(loss) / batch_size\n", | |
"train_op = tf.train.RMSPropOptimizer(0.01, 0.9).minimize(cost)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
"[1 2 2 3] ['e', 'l', 'l', 'o']\n" | |
] | |
} | |
], | |
"source": [ | |
"# Launch the graph in a session\n", | |
"with tf.Session() as sess:\n", | |
" tf.initialize_all_variables().run()\n", | |
" for i in range(100):\n", | |
" sess.run(train_op)\n", | |
" result = sess.run(tf.argmax(logits, 1))\n", | |
" print(result, [char_rdic[t] for t in result]) " | |
] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python [tensorflow]", | |
"language": "python", | |
"name": "Python [tensorflow]" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
@goddoe 수정했습니다. 지적해주셔서 감사합니다!
좋은 예제 감사합니다. 위 예제를 버전 1.0에서 실행하기 위해서는 링크된 글을 참고하시길 바랍니다.
위에 모든분들 너무나 감사합니다. 혼자 실습중인데 너무나 오류나서 힘들어하고 있었는데 1.0버젼에 맞춰서 코딩수정까지 자료가 있으니 너무나 힘이 됩니다 ! 감사합니다 !
tensor flow 1.3이상 버전에서는
In [17]을
# Launch the graph in a session
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for i in range(100):
sess.run(train_op, )
result = sess.run(tf.argmax(logits, axis=1))
print(result, [char_rdic[t] for t in result])
로 해야 작동됩니다.
1.8.0. 버전에서
[10]
x_split = tf.split(0, len(char_dic), x_data) 을
x_split = tf.split(x_data, len(char_dic), 0)
으로 해야 돌아갑니다.
1.8.0에 맞게 수정해서 올렸습니다.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
4 번째 셀에
x_data = np.array([[1,0,0,0], # h
[0,1,0,0], # e
[0,0,1,0], # l
[0,0,0,1]], # l 이부분이 잘못된 것같습니다 [0,0,1,0] 로 되어야하지 않을까요
dtype = 'f')