Skip to content

Instantly share code, notes, and snippets.

@trcook
Created April 3, 2018 19:29
Show Gist options
  • Save trcook/a0ccdff16a5c58f83ce37423dd8f3e4d to your computer and use it in GitHub Desktop.
Save trcook/a0ccdff16a5c58f83ce37423dd8f3e4d to your computer and use it in GitHub Desktop.
Basic example of using TensorFlow Estimators (as of TensorFlow 1.4).
def input_fn(total_repeats=int(1e8), batch_size=100):
    """Build the XOR training pipeline and return one (features, labels) batch.

    Args:
        total_repeats: number of times the four XOR examples are repeated.
        batch_size: number of examples per batch.

    Returns:
        A ``(features, labels)`` tuple of tensors from a one-shot iterator:
        float32 features of shape [batch, 1, 2] and int32 labels of shape
        [batch, 1, 1].
    """
    # XOR truth table: the label is 1 exactly when the two inputs differ.
    inputs = [[[0, 1]], [[0, 0]], [[1, 1]], [[1, 0]]]
    targets = [[[1]], [[0]], [[0]], [[1]]]
    dataset = (
        tf.data.Dataset.from_tensor_slices(
            (tf.constant(inputs, dtype=tf.float32),
             tf.constant(targets, dtype=tf.int32)))
        .repeat(total_repeats)
        .batch(batch_size)
    )
    features, labels = dataset.make_one_shot_iterator().get_next()
    return features, labels
def val_fn(total_repeats=1, batch_size=4):
    """Return one (features, labels) batch of the XOR data for evaluation/prediction.

    Uses the same dataset as ``input_fn``. The defaults yield exactly one pass
    over the four XOR examples, delivered as a single batch.

    Args:
        total_repeats: number of times the four XOR examples are repeated.
        batch_size: number of examples per batch.

    Returns:
        A ``(features, labels)`` tuple of tensors, as returned by ``input_fn``.
    """
    # Delegate instead of duplicating the dataset definition, so the training
    # and validation data stay identical by construction.
    return input_fn(total_repeats=total_repeats, batch_size=batch_size)
def model_fn(features, labels, mode, params):
    """Estimator model function: a small batch-normalised MLP regressing XOR.

    Args:
        features: input tensor. ``input_fn``/``val_fn`` in this file yield
            float32 batches of shape [batch, 1, 2].
        labels: target tensor (int32 [batch, 1, 1] from the input functions in
            this file), or ``None`` in PREDICT mode.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: not used by this function; present because the Estimator
            model_fn signature accepts it.

    Returns:
        A ``tf.estimator.EstimatorSpec``. ``predictions`` is always the model
        output. ``loss`` and ``eval_metric_ops`` are set in TRAIN and EVAL
        modes; ``train_op`` is set only in TRAIN mode.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    # Per the batch_norm docs, updates_collections=None updates the moving
    # mean/variance in place, so no explicit UPDATE_OPS dependency is needed.
    normalizer_params = {'is_training': is_training, 'decay': 0.9,
                         'updates_collections': None}
    with sm.arg_scope([sm.fully_connected, sm.conv2d],
                      normalizer_fn=sm.batch_norm,
                      normalizer_params=normalizer_params,
                      weights_regularizer=sm.l2_regularizer(float(1e-04)),
                      weights_initializer=layers.xavier_initializer(),
                      activation_fn=sm.nn.relu):
        net = sm.fully_connected(features, 24)
        net = sm.fully_connected(net, 24)
        # The regression output must be linear. Inheriting ReLU would clamp
        # predictions to >= 0, and batch norm would standardise them across the
        # batch, so both are overridden for this layer only.
        yhat = sm.fully_connected(net, 1, activation_fn=None, normalizer_fn=None)

    loss = None
    train_op = None
    metrics = None
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        float_labels = tf.cast(labels, dtype=tf.float32)
        float_yhat = tf.cast(yhat, dtype=tf.float32)
        # The MSE term is collected by tf.losses; get_total_loss() combines it
        # with the regularisation losses from the weights_regularizer above.
        tf.losses.mean_squared_error(float_labels, float_yhat)
        loss = tf.losses.get_total_loss()
        metrics = {'mae': tf.metrics.mean_absolute_error(float_labels, float_yhat)}
    if is_training:
        # Build the optimizer only when training; EVAL needs just loss/metrics.
        global_step = tf.train.get_or_create_global_step()
        train_op = tf.train.AdamOptimizer(learning_rate=.01).minimize(
            loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(mode=mode, predictions=yhat, loss=loss,
                                      train_op=train_op, eval_metric_ops=metrics)
# Train on the repeated XOR stream, then evaluate and predict on a single pass
# over the four XOR examples (val_fn). Results are printed so they are visible
# when this file is run as a script, not only in an interactive session.
esto = tf.estimator.Estimator(model_fn)
esto.train(input_fn, max_steps=10000)
# steps is omitted: evaluate until val_fn's one-pass dataset is exhausted.
eval_results = esto.evaluate(val_fn)
print(eval_results)
predictions = list(esto.predict(val_fn))
print(predictions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment