Last active
February 27, 2017 15:11
-
-
Save qbx2/5ecc3b235a24da0b8b793f14663e2796 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import random | |
import copy | |
import sys | |
MINIBATCH_SIZE = 1000 | |
def gen_model():
    """Build a one-step RNN classifier graph.

    Returns (cell, logits, predicted_class, next_state, input_ph, state_ph).
    The input placeholder is [batch, 1, 1] (one scalar bit per step); the
    state placeholder carries the recurrent state between session runs.
    """
    cell = tf.contrib.rnn.BasicRNNCell(128)
    ph_x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='ph_x')
    ph_s = tf.placeholder(tf.float32, shape=[None, cell.state_size], name='ph_s')
    weights = tf.Variable(tf.truncated_normal([cell.output_size, 4], stddev=0.1))
    bias = tf.Variable(tf.constant(0., shape=[4]))
    outputs, next_state = tf.nn.dynamic_rnn(cell, ph_x, None, ph_s)
    # Take the last (and only) time step's output: shape [batch, output_size].
    last_output = tf.transpose(outputs, [1, 0, 2])[-1]
    logits = tf.matmul(last_output, weights) + bias
    return cell, logits, tf.argmax(logits, axis=1), next_state, ph_x, ph_s
def gen_loss(guess):
    """Return (mean softmax cross-entropy over the batch, label placeholder).

    `guess` is the [batch, 4] logits tensor; labels are fed as one-hot rows.
    """
    ph_a = tf.placeholder(tf.float32, shape=[None, 4], name='ph_a')
    per_example = tf.nn.softmax_cross_entropy_with_logits(logits=guess, labels=ph_a)
    mean_loss = tf.reduce_mean(per_example)
    return mean_loss, ph_a
def gen_minimize(loss):
    """Return an RMSProp training op (learning rate 0.1, decay 0.99)."""
    optimizer = tf.train.RMSPropOptimizer(0.1, 0.99)
    return optimizer.minimize(loss)
# One-hot rows for each of the four possible 2-bit input histories (0..3).
onehot = [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]

# Train by default; running the script with an 'eval' argument switches to
# interactive evaluation mode.
TRAIN = sys.argv[1] != 'eval' if len(sys.argv) > 1 else True
with tf.Session() as sess:
    # Build the graph, loss, and training op.
    cell, guess, argmaxguess, state, ph_x, ph_s = gen_model()
    loss, ph_a = gen_loss(guess)
    minimize = gen_minimize(loss)

    saver = tf.train.Saver()
    checkpoint = tf.train.get_checkpoint_state("/tmp/")
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)
        print('Checkpoint loaded')
        # BUG FIX: the original called print(W.eval()) here, but W is a local
        # variable inside gen_model() and is not in scope at module level, so
        # restoring a checkpoint raised NameError. The out-of-scope print is
        # removed.
    else:
        sess.run(tf.global_variables_initializer())

    if TRAIN:
        tmp = [0] * MINIBATCH_SIZE        # per-example 2-bit input history
        tmp_loss = 0                      # running loss over a 1000-step window
        s = sess.run(cell.zero_state(MINIBATCH_SIZE, tf.float32))
        xs = [0] * MINIBATCH_SIZE         # per-example input, shaped [1, 1]
        tmps = [0] * MINIBATCH_SIZE       # per-example one-hot target
        for i in range(1000000):
            for j in range(MINIBATCH_SIZE):
                # Occasionally restart an example's bit stream, clearing both
                # its tracked history and its recurrent state.
                if random.random() < .1:
                    tmp[j] = 0
                    s[j] = [0] * cell.state_size
                a = random.randint(0, 1)
                # Target: the last two bits seen, packed into an int 0..3.
                tmp[j] = ((tmp[j] << 1) | a) & 0b11
                xs[j] = [[a]]
                tmps[j] = onehot[tmp[j]]
            # One training step; carry the RNN state across steps.
            s, out_loss, _ = sess.run(
                [state, loss, minimize],
                feed_dict={ph_x: xs, ph_s: s, ph_a: tmps})
            tmp_loss += out_loss
            if i % 1000 == 0 and i:
                print(tmp_loss / 1000)
                tmp_loss = 0
            if i % 10000 == 0:
                saver.save(sess, "/tmp/model.ckpt")
                print('Saved')
    else:
        # Interactive evaluation: feed one bit at a time, carrying the
        # recurrent state between inputs.
        s = sess.run(cell.zero_state(1, tf.float32))
        while True:
            a = int(input('0 or 1 >>> '))
            xs = [[[a]]]
            g, a, s = sess.run(
                [guess, argmaxguess, state],
                feed_dict={ph_x: xs, ph_s: s})
            print(a[0], g)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment