Skip to content

Instantly share code, notes, and snippets.

@hnakagawa
Created July 6, 2018 11:49
Show Gist options
  • Save hnakagawa/e4d4a7b66a9754dd5f8325d8d3b5ecb9 to your computer and use it in GitHub Desktop.
Save hnakagawa/e4d4a7b66a9754dd5f8325d8d3b5ecb9 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data sets once, at import time; main() reads this
# module-level name. 'MNIST_data' is the directory handed to the tutorial
# loader (presumably where the data files are cached -- TODO confirm).
mnist = input_data.read_data_sets('MNIST_data')
def main(
    train_steps=1,
    model_dir="./tmp/mnist_model",
    export_dir_base="models",
):
    """Train, evaluate and export a two-hidden-layer DNN classifier on MNIST.

    The function reads the module-level ``mnist`` data sets, trains a
    ``tf.estimator.DNNClassifier`` on the training split, prints the
    accuracy measured on the test split, and finally exports the model as a
    SavedModel that accepts serialized ``tf.Example`` protos.

    Args:
        train_steps: Number of training steps passed to ``classifier.train``.
            Defaults to 1, the value the script originally hard-coded.
        model_dir: Directory given to the estimator as ``model_dir``.
        export_dir_base: Base directory given to ``export_savedmodel``.

    Returns:
        None. The test accuracy is printed to stdout.
    """

    def to_arrays(dataset):
        """Return ``(images, labels)`` with the labels cast to int32."""
        return dataset.images, dataset.labels.astype(np.int32)

    # Convert each split exactly once instead of once per keyword argument.
    train_images, train_labels = to_arrays(mnist.train)
    test_images, test_labels = to_arrays(mnist.test)

    # Specify feature: a single numeric column named "x" holding one image.
    feature_columns = [tf.feature_column.numeric_column("x", shape=[28, 28])]

    # Build 2 layer DNN classifier
    classifier = tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[256, 32],
        optimizer=tf.train.AdamOptimizer(1e-4),
        n_classes=10,
        dropout=0.1,
        model_dir=model_dir,
    )

    # Define the training inputs. num_epochs=None leaves the input stream
    # unbounded, so the amount of training is set by `steps` below.
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": train_images},
        y=train_labels,
        num_epochs=None,
        batch_size=50,
        shuffle=True,
    )
    classifier.train(input_fn=train_input_fn, steps=train_steps)

    # Define the test inputs: a single, unshuffled pass over the test split.
    test_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": test_images},
        y=test_labels,
        num_epochs=1,
        shuffle=False,
    )

    # Evaluate accuracy
    accuracy_score = classifier.evaluate(input_fn=test_input_fn)["accuracy"]
    print("\nTest Accuracy: {0:f}%\n".format(accuracy_score*100))

    # Export a SavedModel whose serving input parses tf.Example protos using
    # the parsing spec derived from the same feature columns.
    spec = tf.feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = (
        tf.estimator.export.build_parsing_serving_input_receiver_fn(spec)
    )
    classifier.export_savedmodel(
        export_dir_base=export_dir_base,
        serving_input_receiver_fn=serving_input_receiver_fn,
    )
# Script entry point: run the train / evaluate / export pipeline only when
# this file is executed directly.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment