@aymericdelab
Last active October 15, 2019 11:06
Example of a train_input_fn function for the TensorFlow SageMaker Estimator
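import json
import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.estimator.inputs)


def train_input_fn(training_dir, hyperparameters):
    # training_dir is the local directory in the training container into
    # which SageMaker downloads the training channel from the S3 bucket
    path = os.path.join(training_dir, 'train.json')
    with open(path, 'r') as f:
        train = json.load(f)

    X_train = np.array(train['images'], dtype=np.float64)
    y_train = np.array(train['labels'], dtype=np.int64)
    batch_size = hyperparameters.get('batch_size', 32)

    # numpy_input_fn builds an input function; calling it returns a
    # (features, labels) tuple, where features is a dict keyed by "x"
    return tf.estimator.inputs.numpy_input_fn(
        x={"x": X_train},
        y=y_train,
        batch_size=batch_size,
        num_epochs=None,  # cycle through the data indefinitely
        shuffle=True)()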
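To make the expected input file and the effect of numpy_input_fn(..., num_epochs=None, shuffle=True) concrete without installing TensorFlow, here is a minimal pure-Python sketch. It assumes train.json is a JSON object with parallel "images" and "labels" lists (the two keys the function reads). The helpers write_sample_train_json and batches and the toy data are hypothetical, and the generator only approximates the observable behaviour (an endless stream of shuffled, fixed-size batches of ({"x": ...}, labels)); it is not TensorFlow's implementation.

```python
import itertools
import json
import os
import random
import tempfile


def write_sample_train_json(training_dir, n=5):
    """Hypothetical helper: write a toy train.json with the
    'images' and 'labels' keys that train_input_fn reads."""
    data = {
        # each toy "image" is a flattened list of 4 pixel values
        "images": [[float(i)] * 4 for i in range(n)],
        "labels": [i % 2 for i in range(n)],
    }
    path = os.path.join(training_dir, "train.json")
    with open(path, "w") as f:
        json.dump(data, f)
    return path


def batches(training_dir, batch_size=32, seed=0):
    """Pure-Python approximation of numpy_input_fn(num_epochs=None,
    shuffle=True): yields ({"x": images}, labels) batches forever."""
    with open(os.path.join(training_dir, "train.json")) as f:
        train = json.load(f)
    images, labels = train["images"], train["labels"]
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")
    if not labels:
        raise ValueError("train.json contains no examples")

    rng = random.Random(seed)

    def shuffled_indices():
        # reshuffle the example order at the start of every epoch, forever
        while True:
            order = list(range(len(labels)))
            rng.shuffle(order)
            yield from order

    stream = shuffled_indices()
    while True:
        # in this sketch a batch may straddle an epoch boundary
        idx = list(itertools.islice(stream, batch_size))
        yield {"x": [images[i] for i in idx]}, [labels[i] for i in idx]


with tempfile.TemporaryDirectory() as d:
    write_sample_train_json(d, n=5)
    gen = batches(d, batch_size=2)
    features, y = next(gen)
    print(len(features["x"]), len(y))  # 2 2
```

Because num_epochs=None makes the stream endless, training length has to be bounded elsewhere, typically by a step count on the Estimator side.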