Created April 18, 2018 05:47
-
-
Save devforfu/1d0a49b6b596e7a42678b7a985b4e4d1 to your computer and use it in GitHub Desktop.
DNNClassifier.fit_generator()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def fit_generator(self, generator, epochs, batches_per_epoch,
                  validation_data=None, callbacks=None, train_dropout=0.5):
    """
    Fits model with a generator yielding batches of (x, y) pairs.

    The generator is expected to produce training samples indefinitely, so
    it has no natural end of epoch. The ``batches_per_epoch`` parameter
    tells how many times ``next(generator)`` is called during one epoch.

    Any session left over from a previous call is closed and all global
    variables of the graph are re-initialized, so every call trains from
    scratch. The new session is kept on ``self._session`` after training.

    Args:
        generator: Iterator whose ``next()`` returns an
            ``(x_batch, y_batch)`` pair on every call.
        epochs: Maximum number of epochs to run; training stops earlier if
            ``self.stop_training`` becomes true (e.g. set by a callback).
        batches_per_epoch: Number of batches drawn from ``generator`` per
            epoch. Must be positive.
        validation_data: Optional ``(x_valid, y_valid)`` pair, evaluated in
            a single ``session.run`` call at the end of every epoch.
        callbacks: Optional list of callbacks, wrapped in
            ``CallbacksGroup``.
        train_dropout: Value fed to the dropout placeholder for training
            batches (defaults to ``0.5``, the previously hard-coded value).
            Validation always feeds ``0``.

    Raises:
        ValueError: If ``batches_per_epoch`` is not positive. The check
            runs before the current session is closed, so an invalid call
            does not discard an already trained model.
    """
    # Validate before any state is touched: previously a zero value closed
    # the session and re-initialized the weights, then failed with
    # ZeroDivisionError when averaging the epoch loss.
    if batches_per_epoch <= 0:
        raise ValueError(
            'batches_per_epoch must be positive, got %r'
            % (batches_per_epoch,))

    if self._session is not None:
        self._session.close()

    graph = self._graph
    with graph.as_default():
        # Tensors are unpacked in the order they were added to each
        # collection when the graph was built.
        x, y, training, dropout = tf.get_collection('inputs')
        loss, accuracy = tf.get_collection('metrics')
        _, training_op = tf.get_collection('training')  # logits not needed
        init = tf.global_variables_initializer()
        self._saver = tf.train.Saver()

    # Kept open after training; the next fit_generator call closes it.
    self._session = session = tf.Session(graph=graph)

    # Clear a stop request left over from a previous run (e.g. from an
    # early-stopping callback). Without this, the call would re-initialize
    # the weights and then exit before the first epoch. Reset before the
    # callbacks start so they can still request a stop themselves.
    self.stop_training = False

    monitor = CallbacksGroup(callbacks or [])
    monitor.set_model(self)
    monitor.on_start_training()
    init.run(session=session)

    for epoch in range(1, epochs + 1):
        if self.stop_training:
            break

        epoch_loss = 0.0
        for batch_index in range(batches_per_epoch):
            x_batch, y_batch = next(generator)
            # NOTE(review): feeding train_dropout for training and 0 for
            # validation is only consistent if the placeholder is a drop
            # *rate* rather than a keep probability - confirm against the
            # graph builder.
            feed = {x: x_batch, y: y_batch,
                    training: True, dropout: train_dropout}
            _, batch_loss = session.run([training_op, loss], feed)
            epoch_loss += batch_loss
            monitor.on_batch(epoch, batch_index)

        # Reported training loss is the mean of the per-batch losses.
        epoch_loss /= batches_per_epoch
        metrics = {'train_loss': epoch_loss}

        if validation_data:
            x_valid, y_valid = validation_data
            feed = {x: x_valid, y: y_valid, training: False, dropout: 0}
            val_acc, val_loss = session.run([accuracy, loss], feed)
            metrics['val_loss'] = val_loss
            metrics['val_acc'] = val_acc

        monitor.on_epoch(epoch, **metrics)

    monitor.on_end_training()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.