Skip to content

Instantly share code, notes, and snippets.

@shu-yusa
Last active June 14, 2018 16:16
Show Gist options
  • Save shu-yusa/0423efb98996e02a822a6cddb4dd1c18 to your computer and use it in GitHub Desktop.
Save shu-yusa/0423efb98996e02a822a6cddb4dd1c18 to your computer and use it in GitHub Desktop.
[TensorFlow] MirroredStrategyを用いて複数GPU計算を行う (Multi-GPU computation with MirroredStrategy) ref: https://qiita.com/shu-yusa/items/e93e934a14849541de78
# Create the distribution strategy. With no arguments, MirroredStrategy is
# expected to replicate the model across all available GPUs
# (NOTE(review): confirm for the TF version in use).
distribution = tf.contrib.distribute.MirroredStrategy()
# The strategy is passed only as `train_distribute`, so only Estimator.train
# is distributed by this config.
config = tf.estimator.RunConfig(train_distribute=distribution)
# With a DistributionStrategy configured, every training input_fn must return
# a tf.data.Dataset; returning tensors from an iterator fails with
# "ValueError: dataset_fn() must return a tf.data.Dataset when using a
# DistributionStrategy."
classifier = tf.estimator.Estimator(model_fn=model_fn, config=config)
ValueError: dataset_fn() must return a tf.data.Dataset when using a DistributionStrategy.
class InputFnProvider:
    """Provide MNIST input functions for a ``tf.estimator.Estimator``.

    Both input functions return a ``tf.data.Dataset`` itself rather than the
    tensors produced by an iterator. When the Estimator's RunConfig sets a
    DistributionStrategy (e.g. ``MirroredStrategy``), training otherwise fails
    with ``ValueError: dataset_fn() must return a tf.data.Dataset when using a
    DistributionStrategy.``

    Args:
        train_batch_size: Number of examples per training batch.
        eval_batch_size: Number of examples per evaluation batch. Defaults to
            1, the previously hard-coded value. A larger value makes a full
            pass over the test split much faster.
        shuffle_buffer_size: Buffer size passed to ``Dataset.shuffle`` for the
            training data. Defaults to 1000, the previously hard-coded value.
    """

    def __init__(self, train_batch_size, eval_batch_size=1,
                 shuffle_buffer_size=1000):
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.shuffle_buffer_size = shuffle_buffer_size
        self.__load_data()

    def __load_data(self):
        """Load the MNIST train/test splits into in-memory arrays.

        Sets ``train_data``/``train_labels`` and ``eval_data``/``eval_labels``.
        Labels are converted to int32 numpy arrays.
        """
        # Load training and eval data
        mnist = tf.contrib.learn.datasets.load_dataset("mnist")
        self.train_data = mnist.train.images  # Returns np.array
        self.train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
        self.eval_data = mnist.test.images  # Returns np.array
        self.eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

    def train_input_fn(self):
        """Return the shuffled, endlessly repeating, batched training Dataset.

        Elements are ``({"x": images}, labels)`` batches of
        ``train_batch_size`` examples.
        """
        dataset = tf.data.Dataset.from_tensor_slices(
            ({"x": self.train_data}, self.train_labels))
        # repeat() with no count makes the dataset infinite, so the amount of
        # training is bounded by the `steps` argument of Estimator.train.
        dataset = dataset.shuffle(self.shuffle_buffer_size).repeat()
        return dataset.batch(self.train_batch_size)

    def eval_input_fn(self):
        """Return the batched evaluation Dataset (single pass, no shuffling).

        Elements are ``({"x": images}, labels)`` batches of
        ``eval_batch_size`` test examples.
        """
        dataset = tf.data.Dataset.from_tensor_slices(
            ({"x": self.eval_data}, self.eval_labels))
        return dataset.batch(self.eval_batch_size)
# (remaining code omitted)
# NOTE(review): the setup excerpt builds `classifier`, but this code uses
# `mnist_classifier`, presumably defined in the omitted part; confirm.
# Train the model. The bound method itself (no call parentheses) is passed as
# input_fn; it returns a tf.data.Dataset, as a DistributionStrategy requires.
mnist_classifier.train(
input_fn=input_fn_provider.train_input_fn,
steps=10000)
# Evaluate the model; the results are stored in `eval_results` (printing them
# is not shown in this excerpt).
eval_results = mnist_classifier.evaluate(input_fn=input_fn_provider.eval_input_fn)
# Presumably the earlier (pre-MirroredStrategy) ending of the input functions:
# it returns tensors from a one-shot iterator instead of the Dataset itself.
# That is the form rejected with "dataset_fn() must return a tf.data.Dataset"
# when a DistributionStrategy is configured.
return dataset.make_one_shot_iterator().get_next()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment