Skip to content

Instantly share code, notes, and snippets.

@leechanwoo
Created May 10, 2018 07:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save leechanwoo/a184656990684a05869c6db3186c5594 to your computer and use it in GitHub Desktop.
Save leechanwoo/a184656990684a05869c6db3186c5594 to your computer and use it in GitHub Desktop.
import tensorflow as tf
def input_fn():
    """Return a tiny constant batch of features and one-hot labels.

    An earlier draft also built a ``tf.data.TFRecordDataset`` pipeline from the
    placeholder path ``'dataset_path'``. Its output was immediately overwritten
    by the constants below. Worse, ``features, label = itr.get_next()`` tried to
    unpack a single batched string tensor, which raises ``TypeError`` in graph
    mode, so the function could never return. That dead pipeline is removed.

    Returns:
        A ``(features, labels)`` pair:
          * ``features`` -- ``{'features': float32 tensor of shape (2, 3, 5)}``,
            i.e. 2 examples, 3 time steps, 5 values per step.
          * ``labels`` -- int32 one-hot tensor of shape (2, 10).
    """
    # TODO(review): to read real TFRecords, map a parse function matching the
    # record schema over the dataset (and shuffle *before* batching).
    data = [
        [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]],
        [[5, 4, 3, 2, 1], [5, 4, 3, 2, 1], [5, 4, 3, 2, 1]],
    ]
    labels = tf.constant(
        [
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        ],
        tf.int32,
    )
    features = tf.constant(data, tf.float32)
    # Constant tensors never raise end-of-input, so callers must bound
    # Estimator.train / Estimator.evaluate with `steps`.
    return {'features': features}, labels
def model_fn(features, labels, mode):
    """Build a single-layer LSTM classifier for ``tf.estimator.Estimator``.

    Args:
        features: dict whose ``'features'`` entry is a float32 tensor passed to
            ``tf.nn.dynamic_rnn``. It is assumed to be (batch, time, feature),
            which is dynamic_rnn's default ``time_major=False`` layout.
        labels: one-hot label tensor whose last dimension must be 10, matching
            the dense output layer. Unused (``None``) in PREDICT mode.
        mode: a ``tf.estimator.ModeKeys`` value.

    Returns:
        A ``tf.estimator.EstimatorSpec`` appropriate for ``mode``.

    Raises:
        ValueError: if ``mode`` is not TRAIN, EVAL or PREDICT.
    """
    inputs = features['features']

    lstm = tf.nn.rnn_cell.BasicLSTMCell(num_units=100)
    # sequence_length=None: every example is unrolled over the full time axis.
    outputs, _ = tf.nn.dynamic_rnn(
        cell=lstm, inputs=inputs, sequence_length=None, dtype=tf.float32)

    # Classify from the LSTM output at the last time step.
    logits = tf.layers.dense(outputs[:, -1, :], 10)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = tf.argmax(logits, axis=1)
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    if mode == tf.estimator.ModeKeys.TRAIN:
        loss = tf.losses.softmax_cross_entropy(labels, logits)
        # Passing global_step makes each update increment the step counter;
        # without it, `steps`-bounded training never terminates.
        train_op = tf.train.GradientDescentOptimizer(1e-4).minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(
            mode=mode, train_op=train_op, loss=loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        loss = tf.losses.softmax_cross_entropy(labels, logits)
        # EstimatorSpec requires a dict of name -> (value, update_op).
        # Accuracy compares class ids, not one-hot rows against raw logits.
        eval_metric_ops = {
            'accuracy': tf.metrics.accuracy(
                labels=tf.argmax(labels, axis=1),
                predictions=tf.argmax(logits, axis=1)),
        }
        return tf.estimator.EstimatorSpec(
            mode=mode, eval_metric_ops=eval_metric_ops, loss=loss)

    raise ValueError('none estimatorspec defined')
if __name__ == "__main__":
    est = tf.estimator.Estimator(model_fn)
    # input_fn returns constant tensors that never signal end-of-input, so both
    # calls must be bounded by `steps` or they run forever. The step count is
    # an arbitrary demo value.
    est.train(input_fn, steps=100)
    # The Estimator API method is `evaluate` (there is no `evaluation`).
    metrics = est.evaluate(input_fn, steps=1)
    print(metrics)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment