@crawles
Created March 5, 2019 02:10
import tensorflow as tf

# Use the entire dataset as one batch, since it is so small.
NUM_EXAMPLES = len(y_train)


def make_input_fn(X, y, n_epochs=None, shuffle=True):
    def input_fn():
        dataset = tf.data.Dataset.from_tensor_slices((dict(X), y))
        if shuffle:
            dataset = dataset.shuffle(NUM_EXAMPLES)
        # For training, cycle through the dataset as many times as needed (n_epochs=None).
        dataset = dataset.repeat(n_epochs)
        # In-memory training uses the whole dataset as a single batch.
        dataset = dataset.batch(NUM_EXAMPLES)
        return dataset
    return input_fn


# Training and evaluation input functions.
train_input_fn = make_input_fn(dftrain, y_train)
eval_input_fn = make_input_fn(dfeval, y_eval, shuffle=False, n_epochs=1)
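
The order of the calls above (shuffle, then repeat, then batch) matters: because the shuffle buffer and the batch size both equal the dataset size, every batch should contain each example exactly once. The sketch below emulates that ordering in plain Python to make the behavior visible. It is only an illustration, not TensorFlow code: `emulate_pipeline`, `toy`, and the `seed` argument are hypothetical names introduced here, and it assumes the default behavior of reshuffling on every pass.

```python
import itertools
import random


def emulate_pipeline(examples, n_epochs=None, shuffle=True, seed=0):
    """Yield batches in the order shuffle -> repeat -> batch.

    Plain-Python stand-in for the tf.data calls above (an assumption-laden
    sketch, not the TensorFlow implementation).
    """
    rng = random.Random(seed)
    batch_size = len(examples)  # plays the role of NUM_EXAMPLES

    def element_stream():
        # n_epochs=None mirrors repeat(None): keep cycling indefinitely.
        epochs = itertools.count() if n_epochs is None else range(n_epochs)
        for _ in epochs:
            epoch = list(examples)
            if shuffle:
                # A buffer as large as the dataset amounts to a full reshuffle.
                rng.shuffle(epoch)
            yield from epoch

    stream = element_stream()
    while True:
        # Take the next batch_size elements from the repeated stream.
        batch = list(itertools.islice(stream, batch_size))
        if not batch:
            return
        yield batch


toy = list(range(5))  # hypothetical toy examples
# Training-style pipeline: infinite, so take only the first three batches.
train_batches = list(itertools.islice(emulate_pipeline(toy), 3))
# Evaluation-style pipeline: one unshuffled pass.
eval_batches = list(emulate_pipeline(toy, n_epochs=1, shuffle=False))
```

In this emulation each training batch is a permutation of the full toy dataset, and the evaluation pipeline yields a single batch in the original order, which matches how `train_input_fn` and `eval_input_fn` are configured.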