Skip to content

Instantly share code, notes, and snippets.

@trcook
Last active March 21, 2018 19:00
Show Gist options
  • Save trcook/9fc8698cf7dc848a953f8e7a7e5f1aad to your computer and use it in GitHub Desktop.
Basic example that mixes TensorFlow and Keras in the same file. It also makes use of tf.data.
from keras import backend as K
import keras as k
from keras.layers import *
import tensorflow as tf
# Training schedule: 100 steps per epoch, 30 epochs, 4 examples per batch.
steps_epoch = 100
total_epochs = 30
batch_size = 4
# Number of passes made over each dataset by Dataset.repeat(). A fit needs
# steps_epoch * total_epochs * batch_size examples in total; since every pass
# over a non-empty dataset yields at least one example, using that same number
# of passes is always enough. (With the 4-example datasets defined below it
# over-provisions by a factor of 4, which is harmless.)
total_repeats = steps_epoch * total_epochs * batch_size

# Start from a clean default graph, open a session on it, and make Keras use
# that same session so the tf.data ops and the Keras layers share one graph.
tf.reset_default_graph()
g = tf.get_default_graph()
sess = tf.Session()
K.set_session(sess)
def _make_dataset(features, labels):
    """Build the repeated, batched tf.data pipeline shared by train_fn/val_fn.

    Args:
        features: nested list of inputs, one ``[[a, b]]`` row per example;
            converted to a float32 constant.
        labels: nested list of targets, one ``[[t]]`` row per example;
            converted to an int32 constant.

    Returns:
        A ``tf.data.Dataset`` of ``(x, y)`` pairs, repeated ``total_repeats``
        times and grouped into batches of ``batch_size``. Both are the
        module-level settings, read when this function is called.
    """
    x = tf.constant(features, dtype=tf.float32)
    y = tf.constant(labels, dtype=tf.int32)
    data = tf.data.Dataset.from_tensor_slices((x, y))
    data = data.repeat(total_repeats)
    return data.batch(batch_size)


def train_fn():
    """Return the training dataset: the four rows of the XOR truth table.

    Each example has a (1, 2) input and a (1, 1) label.
    """
    return _make_dataset(
        [[[0, 1]], [[0, 0]], [[1, 0]], [[1, 1]]],
        [[[1]], [[0]], [[1]], [[0]]],
    )


def val_fn():
    """Return the validation dataset: XOR rows (1, 1) and (1, 0), each twice.

    Each example has a (1, 2) input and a (1, 1) label, matching train_fn so
    both datasets can share one reinitializable iterator.
    """
    return _make_dataset(
        [[[1, 1]], [[1, 0]], [[1, 0]], [[1, 1]]],
        [[[0]], [[1]], [[1]], [[0]]],
    )
def mk_iterators(train_fn, val_fn):
    """Create one reinitializable iterator that can be switched between datasets.

    The training dataset is built first, then the validation dataset. The
    iterator's structure comes from the training dataset, so the validation
    dataset must have compatible output types and shapes.

    Args:
        train_fn: zero-argument callable returning the training tf.data.Dataset.
        val_fn: zero-argument callable returning the validation tf.data.Dataset.

    Returns:
        ``(x, y, val_init, train_init)``:
        - ``x`` and ``y`` are the feature and label tensors from ``get_next()``.
        - ``val_init`` and ``train_init`` are initializer ops. Running one of
          them in the session (re)starts the iterator on that dataset.
    """
    train_ds, val_ds = train_fn(), val_fn()
    iterator = tf.data.Iterator.from_structure(
        train_ds.output_types, train_ds.output_shapes
    )
    val_init = iterator.make_initializer(val_ds)
    train_init = iterator.make_initializer(train_ds)
    features, labels = iterator.get_next()
    return features, labels, val_init, train_init
x, y, val_init, train_init = mk_iterators(train_fn, val_fn)

# Wrap the iterator tensors as Keras inputs so the model is wired directly to
# the tf.data pipeline; fit/evaluate below receive no numpy arrays.
input_ = Input(tensor=x)
labs = Input(tensor=y)

# Small MLP: two 24-unit ReLU layers and a single sigmoid output unit.
net = Dense(24, activation='relu')(input_)
net = Dense(24, activation='relu')(net)
output = Dense(1, activation='sigmoid')(net)
model = k.Model(input_, output)

# The labels are not a model input, so the loss is built with TensorFlow and
# attached via add_loss; compile therefore only needs the optimizer.
loss = tf.losses.mean_squared_error(labs, output)
model.add_loss(loss)
model.compile('adam')

# Point the shared iterator at the training data and train.
sess.run(train_init)
model.fit(steps_per_epoch=steps_epoch, epochs=total_epochs, verbose=0)

# Swap the iterator to the validation data and run one evaluation step.
# Capture and print the result so it is visible when run as a script.
sess.run(val_init)
val_loss = model.evaluate(steps=1)
print("validation loss:", val_loss)

# model.predict() does not work with this tensor-fed setup. Instead, evaluate
# the output tensor directly in the session. Fetching x and y in the same run
# keeps the predictions aligned with the inputs and labels of that batch, and
# each run advances the iterator to the next batch.
predictions = sess.run({"input_": x, "OUTPUT": output, "LABS": y})
print(predictions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment