Skip to content

Instantly share code, notes, and snippets.

@qbx2
Last active February 27, 2017 15:11
Show Gist options
  • Save qbx2/5ecc3b235a24da0b8b793f14663e2796 to your computer and use it in GitHub Desktop.
Save qbx2/5ecc3b235a24da0b8b793f14663e2796 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import random
import copy
import sys
# Number of independent random bit streams fed through the network per
# training step (sizes the per-stream history, input and label lists below).
MINIBATCH_SIZE = 1000
def gen_model():
    """Build the single-time-step RNN classifier graph.

    The input placeholder has shape ``[batch, 1, 1]`` (one scalar per step)
    and the recurrent state is fed in explicitly through a second
    placeholder, so the caller carries the state from one ``run`` call to
    the next.

    Returns:
        A 6-tuple ``(cell, logits, prediction, state, ph_x, ph_s)``:
        the 128-unit ``BasicRNNCell``, the 4-way logits, the argmax of the
        logits along axis 1, the RNN state after the step, the input
        placeholder and the initial-state placeholder.
    """
    rnn_cell = tf.contrib.rnn.BasicRNNCell(128)

    inputs = tf.placeholder(tf.float32, shape=[None, 1, 1], name='ph_x')
    initial_state = tf.placeholder(
        tf.float32, shape=[None, rnn_cell.state_size], name='ph_s')

    # Output projection from the RNN output to 4 class logits. The weight
    # variable is created before the bias so auto-generated variable names
    # (and therefore existing checkpoints) stay the same.
    weights = tf.Variable(
        tf.truncated_normal([rnn_cell.output_size, 4], stddev=0.1))
    bias = tf.Variable(tf.constant(0., shape=[4]))

    outputs, final_state = tf.nn.dynamic_rnn(
        rnn_cell, inputs, sequence_length=None, initial_state=initial_state)

    # Switch to time-major layout and keep only the last time step.
    time_major = tf.transpose(outputs, [1, 0, 2])
    last_out = time_major[-1]

    logits = tf.matmul(last_out, weights) + bias
    prediction = tf.argmax(logits, axis=1)
    return rnn_cell, logits, prediction, final_state, inputs, initial_state
def gen_loss(guess):
    """Attach a mean softmax cross-entropy loss to the logits ``guess``.

    Args:
        guess: logits tensor with 4 classes per example.

    Returns:
        ``(loss, ph_a)`` where ``loss`` is the scalar mean loss over the
        batch and ``ph_a`` is the ``[batch, 4]`` float placeholder for the
        target distributions.
    """
    targets = tf.placeholder(tf.float32, shape=[None, 4], name='ph_a')
    per_example = tf.nn.softmax_cross_entropy_with_logits(
        labels=targets, logits=guess)
    mean_loss = tf.reduce_mean(per_example)
    return mean_loss, targets
def gen_minimize(loss):
    """Return the RMSProp training op that minimises ``loss``.

    Uses a learning rate of 0.1 and a decay of 0.99.
    """
    optimizer = tf.train.RMSPropOptimizer(learning_rate=0.1, decay=0.99)
    train_op = optimizer.minimize(loss)
    return train_op
# One-hot label vectors for the four classes, indexed by class id 0..3.
onehot = tuple([int(col == row) for col in range(4)] for row in range(4))

# Train by default; passing "eval" as the first command-line argument
# switches to the interactive evaluation loop instead.
TRAIN = not (len(sys.argv) > 1 and sys.argv[1] == 'eval')
# Build the graph, restore (or initialise) the weights, then either train the
# network to report the last two bits of a random bit stream, or let the user
# feed bits interactively and watch the prediction.
with tf.Session() as sess:
    cell, guess, argmaxguess, state, ph_x, ph_s = gen_model()
    loss, ph_a = gen_loss(guess)
    minimize = gen_minimize(loss)
    saver = tf.train.Saver()

    checkpoint = tf.train.get_checkpoint_state("/tmp/")
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)
        print('Checkpoint loaded')
        # NOTE: the output weight variable is local to gen_model() and is not
        # in scope here, so it cannot be dumped for debugging at this point.
    else:
        sess.run(tf.global_variables_initializer())

    if TRAIN:
        LOG_EVERY = 1000     # steps per averaged-loss report
        SAVE_EVERY = 10000   # steps between checkpoints

        # Per-stream 2-bit history; its value (0..3) is the class label.
        tmp = [0] * MINIBATCH_SIZE
        tmp_loss = 0
        # Recurrent state carried across steps, one row per stream.
        s = sess.run(cell.zero_state(MINIBATCH_SIZE, tf.float32))
        xs = [0] * MINIBATCH_SIZE
        tmps = [0] * MINIBATCH_SIZE

        for i in range(1000000):
            for j in range(MINIBATCH_SIZE):
                # Occasionally restart a stream: clear both its bit history
                # and its recurrent state so the two stay consistent.
                if random.random() < .1:
                    tmp[j] = 0
                    s[j] = [0] * cell.state_size
                bit = random.randint(0, 1)
                # Shift the new bit in and keep only the two most recent bits.
                tmp[j] = ((tmp[j] << 1) | bit) & 0b11
                xs[j] = [[bit]]
                tmps[j] = onehot[tmp[j]]

            s, out_loss, _ = sess.run(
                [state, loss, minimize],
                feed_dict={ph_x: xs, ph_s: s, ph_a: tmps})
            tmp_loss += out_loss

            # Report after every LOG_EVERY completed steps so that the sum
            # being averaged always contains exactly LOG_EVERY losses.
            if (i + 1) % LOG_EVERY == 0:
                print(tmp_loss / LOG_EVERY)
                tmp_loss = 0
            if i % SAVE_EVERY == 0:
                saver.save(sess, "/tmp/model.ckpt")
                print('Saved')
    else:
        # Interactive mode: a single stream whose state persists between
        # prompts. Ends cleanly on end-of-input.
        s = sess.run(cell.zero_state(1, tf.float32))
        while True:
            try:
                raw = input('0 or 1 >>> ')
            except EOFError:
                break
            raw = raw.strip()
            if raw not in ('0', '1'):
                # The model was only ever trained on single bits.
                print('Please enter 0 or 1')
                continue
            xs = [[[int(raw)]]]
            g, predicted, s = sess.run(
                [guess, argmaxguess, state],
                feed_dict={ph_x: xs, ph_s: s})
            print(predicted[0], g)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment