Skip to content

Instantly share code, notes, and snippets.

@Echooff3
Last active January 3, 2018 20:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Echooff3/e0fbb868da9a02abc5b56c9e618a5c35 to your computer and use it in GitHub Desktop.
Save Echooff3/e0fbb868da9a02abc5b56c9e618a5c35 to your computer and use it in GitHub Desktop.
Contrived tf.estimator.DNNClassifier
import tensorflow as tf
import numpy as np
import sys
from tensorflow.python import debug as tf_debug
# NOTE(review): this debugger hook list is built but never handed to
# classifier.train/evaluate/predict below; pass `hooks=hooks` to one of
# those calls to run it under the tfdbg command-line debugger.
hooks = [tf_debug.LocalCLIDebugHook()]
tf.logging.set_verbosity(tf.logging.INFO)

# Toy training data: four examples of four 0/1 features each, with one
# 0/1 label per example stored as a column (shape (4, 1)).
trainX = np.array([
    [1, 1, 0, 1],
    [0, 0, 1, 0],
    [1, 0, 1, 1],
    [0, 0, 1, 1],
])
labelX = np.array([[1], [0], [1], [0]])

num_classes = 2
feature_names = ['f1', 'f2', 'f3', 'f4']
# One numeric feature column per feature name, in the same order.
feature_columns = list(map(tf.feature_column.numeric_column, feature_names))

# Feed-forward classifier with a single hidden layer of 10 units;
# n_classes is the number of distinct label values.
classifier = tf.estimator.DNNClassifier(
    hidden_units=[10],
    feature_columns=feature_columns,
    n_classes=num_classes,  # setting number of classes here
)
def input_fn(num_epochs=8, batch_size=32):
    """Build the training/evaluation input pipeline from trainX and labelX.

    Each row of ``trainX`` is split into one tensor per entry of
    ``feature_names`` and paired with the matching row of ``labelX``.
    The dataset is repeated ``num_epochs`` times and grouped into batches
    of ``batch_size`` examples.

    Args:
        num_epochs: number of passes over the data (default 8, the value
            previously hard-coded).
        batch_size: examples per batch (default 32, the value previously
            hard-coded; with the 4-row toy data and 8 epochs this produces
            exactly one batch of 32).

    Returns:
        A ``(features, labels)`` pair from a one-shot iterator, where
        ``features`` is a dict mapping each feature name to a tensor.
    """
    # Derive the split count from feature_names so the two cannot drift
    # apart; zip() would silently drop unmatched names or pieces.
    num_features = len(feature_names)

    def gen1(a, b):
        # Debug output: a Python print inside a Dataset.map function runs
        # when the function is traced, not once per element.
        print(a.shape, b.shape)
        features = tf.split(a, num_features)
        return dict(zip(feature_names, features)), b

    dataset = tf.data.Dataset.from_tensor_slices((trainX, labelX)).map(gen1)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    data, labels = iterator.get_next()
    return data, labels
def input_fn_pred(in_arr, batch_size=None):
    """Build a features-only input pipeline for prediction from ``in_arr``.

    Each row of ``in_arr`` is split into one tensor per entry of
    ``feature_names``.

    Args:
        in_arr: array-like with one row per example; each row is split
            into ``len(feature_names)`` pieces (presumably the same column
            layout as trainX -- confirm against callers).
        batch_size: if given, group rows into batches of this many
            examples. The default (None) keeps the original behaviour of
            emitting each row as its own, unbatched element.

    Returns:
        ``(features, None)`` where ``features`` is a dict mapping each
        feature name to a tensor; there are no labels at prediction time.
    """
    # Derive the split count from feature_names so the two cannot drift
    # apart; zip() would silently drop unmatched names or pieces.
    num_features = len(feature_names)

    def gen1(a):
        features = tf.split(a, num_features)
        return dict(zip(feature_names, features))

    dataset = tf.data.Dataset.from_tensor_slices(in_arr).map(gen1)
    if batch_size is not None:
        dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    data = iterator.get_next()
    return data, None
# check values: uncomment to print one batch from input_fn and exit.
# next_batch = input_fn()
# with tf.Session() as sess:
#     first_batch = sess.run(next_batch)
#     print(first_batch)
# sys.exit(0)

classifier.train(input_fn=lambda: input_fn())

# NOTE: evaluation reuses the training pipeline (input_fn), so these
# metrics are measured on the training data, not a held-out set.
evaluate_result = classifier.evaluate(input_fn=lambda: input_fn())
print("Evaluation results")
# Sorted so the metrics print in a stable order across runs.
for key in sorted(evaluate_result):
    print(" {}, was: {}".format(key, evaluate_result[key]))

test_set = np.array([[1, 1, 1, 1], [0, 0, 0, 0], [0, 0, 1, 1], [1, 1, 0, 0]])
predict_results = classifier.predict(
    input_fn=lambda: input_fn_pred(test_set))
for prediction in predict_results:
    # Function-call form: the bare `print x` statement is a SyntaxError on
    # Python 3, and print(x) with one argument behaves the same on Python 2.
    print(prediction["class_ids"][0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment