Skip to content

Instantly share code, notes, and snippets.

@devforfu
Created April 18, 2018 05:47
Show Gist options
  • Save devforfu/1d0a49b6b596e7a42678b7a985b4e4d1 to your computer and use it in GitHub Desktop.
DNNClassifier.fit_generator()
def fit_generator(self, generator, epochs, batches_per_epoch,
                  validation_data=None, callbacks=None):
    """
    Fits the model with a generator yielding batches of (x, y) pairs.

    The generator is expected to indefinitely generate samples from the
    training set. Therefore, we need a "hint" of how many times we would
    like to call the generator's `next()` method during a single epoch.
    The `batches_per_epoch` parameter is used for this purpose.

    Parameters
    ----------
    generator:
        Infinite iterator yielding `(x_batch, y_batch)` pairs; consumed
        via `next()` exactly `batches_per_epoch` times per epoch.
    epochs: int
        Number of training epochs to run (epochs are numbered from 1).
    batches_per_epoch: int
        Number of batches drawn from `generator` per epoch; also the
        divisor when averaging the per-epoch training loss.
    validation_data: tuple, optional
        `(x_valid, y_valid)` pair evaluated once at the end of each
        epoch; adds 'val_loss' and 'val_acc' to the reported metrics.
    callbacks: list, optional
        Callback objects wrapped into a `CallbacksGroup`; they receive
        on_start_training / on_batch / on_epoch / on_end_training events
        and may stop training early by setting `self.stop_training`.
    """
    # Dispose of any session left over from a previous fit so that we
    # start from freshly (re)initialized variables below.
    if self._session is not None:
        self._session.close()
    graph = self._graph
    with graph.as_default():
        # Retrieve tensors/ops stashed into named collections when the
        # graph was built.  NOTE(review): assumes the builder put exactly
        # these items, in this order, into 'inputs', 'metrics' and
        # 'training' — confirm against the graph-construction code.
        x, y, training, dropout = tf.get_collection('inputs')
        loss, accuracy = tf.get_collection('metrics')
        # NOTE(review): `logits` is fetched here but never used in this
        # method; only `training_op` is run.
        logits, training_op = tf.get_collection('training')
        init = tf.global_variables_initializer()
        # Saver must be created inside the graph context so it tracks
        # this graph's variables (kept on `self` for later checkpoints).
        self._saver = tf.train.Saver()
        self._session = session = tf.Session(graph=graph)
        monitor = CallbacksGroup(callbacks or [])
        monitor.set_model(self)
        monitor.on_start_training()
        # Initialize all graph variables before the first training step.
        init.run(session=session)
        for epoch in range(1, epochs + 1):
            # Callbacks (e.g. early stopping) may flip this flag between
            # epochs to terminate training.
            if self.stop_training:
                break
            epoch_loss = 0.0
            for batch_index in range(batches_per_epoch):
                x_batch, y_batch = next(generator)
                # Training mode: dropout placeholder fed 0.5.
                # NOTE(review): whether 0.5 is a keep- or drop-probability
                # depends on how the placeholder is wired — confirm.
                feed = {x: x_batch, y: y_batch, training: True, dropout: 0.5}
                _, batch_loss = session.run([training_op, loss], feed)
                epoch_loss += batch_loss
                monitor.on_batch(epoch, batch_index)
            # Average the accumulated batch losses over the epoch.
            epoch_loss /= batches_per_epoch
            metrics = {'train_loss': epoch_loss}
            if validation_data:
                x_valid, y_valid = validation_data
                # Evaluation mode: dropout disabled (fed 0), whole
                # validation set in a single run call.
                feed = {x: x_valid, y: y_valid, training: False, dropout: 0}
                val_acc, val_loss = session.run([accuracy, loss], feed)
                metrics['val_loss'] = val_loss
                metrics['val_acc'] = val_acc
            monitor.on_epoch(epoch, **metrics)
        monitor.on_end_training()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment