Skip to content

Instantly share code, notes, and snippets.

@tilarids
Last active March 22, 2017 04:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tilarids/a5fdc8601dfb96070a2b125cb7648795 to your computer and use it in GitHub Desktop.
Save tilarids/a5fdc8601dfb96070a2b125cb7648795 to your computer and use it in GitHub Desktop.
TensorFlow vs. Theano: a simple linear-regression example
import tensorflow as tf
import numpy as np
# Synthetic 1-D regression data (float32): y = 2*x plus Gaussian noise with
# standard deviation 0.33.
trX = np.linspace(-1, 1, 1001, dtype=np.float32)
trY = trX * 2 + 0.33 * np.random.randn(*trX.shape).astype(np.float32)

EPOCHS = 1000  # passes over the data the input producer allows (num_epochs)
CAP = 1000     # capacity of the input queue

# Queue yielding one stacked (x, y) row per dequeue, in order.
queue_xy = tf.train.input_producer(
    tf.pack([trX, trY], axis=1),
    num_epochs=EPOCHS,
    shuffle=False,
    capacity=CAP,
)

# Model parameter (single scalar weight), optimizer, and global step counter.
W = tf.Variable(0.0, dtype=tf.float32, name="weights")
opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
gs = tf.Variable(0)
def body(i):
    """Run one while_loop iteration: dequeue one sample and take one SGD step.

    Args:
        i: the while_loop counter tensor (started from tf.constant(0)).

    Returns:
        A one-element list holding ``i + 1``. ``tf.tuple`` makes that value
        depend on ``train_op``, so the weight update is executed on every
        iteration of the loop.
    """
    # Each queue element is one row of the stacked [trX, trY] tensor; split
    # it into the scalar input x and target y.
    x,y = tf.unpack(queue_xy.dequeue())
    # Squared error of the linear model y ~ x * W for this single sample.
    cost = tf.square(y - tf.mul(x,W))
    # Gradient-descent update of W; also increments the global step `gs`.
    train_op = opt.minimize(cost, global_step=gs)
    # control_inputs forces train_op to run before the new counter is produced.
    # The loop condition is always True, so the loop is expected to end only
    # when the dequeue above fails once the queue's num_epochs is exhausted.
    return tf.tuple([tf.add(i,1)], control_inputs=[train_op])
# The whole training run is a single in-graph loop. Its condition is always
# True, so it terminates only when `body` fails to dequeue after the input
# queue has been exhausted.
loop = tf.while_loop(lambda _: True, body, [tf.constant(0)])

with tf.Session() as sess:
    # Initialize the global variables (W, gs). input_producer(num_epochs=...)
    # also creates a local epoch counter, which needs the local initializer.
    # NOTE(review): newer TF releases deprecate initialize_all_variables /
    # initialize_local_variables in favour of global_variables_initializer /
    # local_variables_initializer; kept here for the API version this targets.
    tf.initialize_all_variables().run()
    tf.initialize_local_variables().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            sess.run(loop)
    except tf.errors.OutOfRangeError:
        # Signalled once the producer has emitted EPOCHS passes over the data.
        print('Done training')
    finally:
        coord.request_stop()
    coord.join(threads)
    # Function-call print with one argument behaves the same on Python 2 and
    # Python 3; must run while the session is still open.
    print(sess.run(W))
import theano
from theano import tensor as T
import numpy as np
# Synthetic 1-D regression data (float64): y = 2*x plus Gaussian noise with
# standard deviation 0.33.
trX = np.linspace(-1, 1, 1001)
trY = 2 * trX + np.random.randn(*trX.shape) * 0.33

EPOCHS = 1000  # full passes over the data

# Symbolic scalar inputs: the model is trained one (x, y) sample per call.
X = T.scalar()
Y = T.scalar()
W = theano.shared(np.asarray(0., dtype=theano.config.floatX))
Y_p = X * W
# T.mean over a scalar is just the scalar squared error.
cost = T.mean(T.sqr(Y - Y_p))
gradient = T.grad(cost=cost, wrt=W)
# Plain gradient descent with learning rate 0.01 on the shared weight.
updates = [(W, W - gradient * 0.01)]
# allow_input_downcast lets the float64 numpy samples feed floatX inputs.
train = theano.function(inputs=[X, Y], outputs=cost, updates=updates,
                        allow_input_downcast=True)

for _ in range(EPOCHS):
    for x, y in zip(trX, trY):
        train(x, y)
# Function-call print with one argument works on Python 2 and Python 3.
print(W.get_value())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment