@NMZivkovic
Created February 3, 2018 15:11
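import tensorflow as tf  # TensorFlow 1.x (tf.data + tf.estimator APIs)

# Input function that feeds (features, labels) batches to the Estimator.
def train_function(inputs, outputs, batch_size):
    # One (feature dict, label) example per row of the training data.
    dataset = tf.data.Dataset.from_tensor_slices((dict(inputs), outputs))
    # Shuffle with a 1000-example buffer, repeat indefinitely, then batch.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
    # Return the next (features, labels) batch as tensors (TF 1.x iterator API).
    return dataset.make_one_shot_iterator().get_next()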
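The input function chains three tf.data transformations. The plain-Python sketch below only mimics their semantics so the ordering is easy to see; it is not TensorFlow code and not how tf.data is implemented internally. The names `slice_examples`, `buffered_shuffle` and `shuffle_repeat_batch` are hypothetical helpers introduced for illustration.

```python
import random
from typing import Dict, Iterator, List, Sequence, Tuple

# Hypothetical example type: (feature dict, label), like one element of the dataset.
Example = Tuple[Dict[str, float], int]


def slice_examples(features: Dict[str, Sequence[float]], labels: Sequence[int]) -> List[Example]:
    """Pair row i of every feature column with label i (like from_tensor_slices)."""
    names = list(features)
    return [({n: features[n][i] for n in names}, labels[i]) for i in range(len(labels))]


def buffered_shuffle(items: List[Example], buffer_size: int, rng: random.Random) -> Iterator[Example]:
    """Buffer-based shuffle: emit a random buffered item, refill its slot with the next input."""
    buffer: List[Example] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = item  # the incoming item takes the emitted item's slot
    rng.shuffle(buffer)  # drain whatever is left in random order
    yield from buffer


def shuffle_repeat_batch(examples: List[Example], buffer_size: int, batch_size: int,
                         seed: int = 0) -> Iterator[List[Example]]:
    """Endless batches mimicking shuffle(buffer_size).repeat().batch(batch_size)."""
    if not examples:
        return
    rng = random.Random(seed)
    batch: List[Example] = []
    while True:  # repeat(): every pass is a freshly shuffled epoch
        for ex in buffered_shuffle(examples, buffer_size, rng):
            batch.append(ex)
            if len(batch) == batch_size:  # batch() runs after repeat(), so batches can span epochs
                yield batch
                batch = []
```

Two consequences of the chosen order: because `shuffle` comes before `repeat`, each epoch is a complete permutation of the data; because `batch` comes after `repeat`, every batch is full and a batch may mix the end of one epoch with the start of the next. A buffer of 1000 gives a full shuffle only when the dataset has at most 1000 examples; larger datasets are shuffled only locally.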
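# Train the model for 1000 steps, 100 examples per batch.
# `classifier`, `train_x` and `train_y` are defined elsewhere (not part of this gist).
classifier.train(
    input_fn=lambda: train_function(train_x, train_y, 100),
    steps=1000)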
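Because `repeat()` makes the dataset endless, it is `steps` that stops training, and since every batch is full, the number of examples processed is simply `steps * batch_size`. A small sketch of that arithmetic follows; the helper names and the dataset size used in the usage line are hypothetical.

```python
import math
from typing import Tuple


def examples_and_epochs(steps: int, batch_size: int, num_examples: int) -> Tuple[int, float]:
    """Examples processed and the equivalent number of passes over the data."""
    seen = steps * batch_size  # one full batch per training step
    return seen, seen / num_examples


def steps_for_epochs(epochs: float, batch_size: int, num_examples: int) -> int:
    """Smallest step count that covers at least `epochs` passes over the data."""
    return math.ceil(epochs * num_examples / batch_size)


# With steps=1000 and batch_size=100, as in the call above, and a hypothetical
# training set of 1000 examples: 100,000 examples, i.e. 100 epochs.
print(examples_and_epochs(1000, 100, 1000))
```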