This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Define the training inputs | |
def get_train_inputs(batch_size, mnist_data): | |
"""Return the input function to get the training data. | |
Args: | |
batch_size (int): Batch size of training iterator that is returned | |
by the input function. | |
mnist_data (Object): Object holding the loaded mnist data. | |
Returns: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook to initialise a data iterator after the Session is created.

    The hook itself does not know how to initialise the iterator: whoever
    builds the iterator must assign ``iterator_initializer_func`` (a callable
    taking the created session) before the session is created. The hook then
    invokes that callable once the session exists.
    """

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        # Callable taking the created session; must be assigned externally
        # (presumably by the input function that builds the iterator).
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        """Initialise the iterator after the session has been created.

        Args:
            session: The newly created session; passed to the initializer.
            coord: Coordinator argument from the SessionRunHook API; unused.

        Raises:
            RuntimeError: If ``iterator_initializer_func`` was never assigned.
                Without this check the call would fail with an opaque
                ``TypeError`` about calling ``None``.
        """
        if self.iterator_initializer_func is None:
            raise RuntimeError(
                "IteratorInitializerHook.iterator_initializer_func was not "
                "set; assign it before the session is created.")
        self.iterator_initializer_func(session)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Script to illustrate usage of tf.estimator.Estimator in TF v1.3""" | |
import tensorflow as tf | |
from tensorflow.examples.tutorials.mnist import input_data as mnist_data | |
from tensorflow.contrib import slim | |
from tensorflow.contrib.learn import ModeKeys | |
from tensorflow.contrib.learn import learn_runner | |
# Show debugging output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Script to illustrate inference of a trained tf.estimator.Estimator. | |
NOTE: This is dependent on mnist_estimator.py which defines the model. | |
mnist_estimator.py can be found at: | |
https://gist.github.com/peterroelants/9956ec93a07ca4e9ba5bc415b014bcca | |
""" | |
import numpy as np | |
import skimage.io | |
import tensorflow as tf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
return tf.estimator.Estimator( | |
model_fn=model_fn, | |
config=config, | |
params=params, | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
tf.estimator.train_and_evaluate(model_estimator, train_spec, eval_spec) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
dataset = tf.data.Dataset.from_tensor_slices(mnist_data) | |
dataset = dataset.shuffle( | |
buffer_size=1000, reshuffle_each_iteration=True | |
).repeat(count=None).batch(batch_size) |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2.0.0-dev20190205 | |
2019-02-05 18:09:28.956373: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA | |
2019-02-05 18:09:28.959894: I tensorflow/stream_executor/platform/default/dso_loader.cc:161] successfully opened CUDA library libcuda.so.1 locally | |
2019-02-05 18:09:29.161041: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1010] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero | |
2019-02-05 18:09:29.168088: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1010] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero | |
2019-02-05 18:09:29.168686: I tensorflow/compiler/xla/service/service.cc:162] XLA service 0x1a03ef0 executing computations on platform CUDA. Devices: | |
2019-02-05 18:09:29.168697: I tensorflow/compiler/xla/service/service.cc:169] StreamExecuto |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.