Skip to content

Instantly share code, notes, and snippets.

@persiyanov
Created May 28, 2018 11:59
Show Gist options
  • Save persiyanov/0a8ca3d9091775bd136cfe6e4674e376 to your computer and use it in GitHub Desktop.
Multistream API: `_train_epoch` — trains one epoch over multiple input data streams using producer/consumer threads.
def _train_epoch(self, data_iterables, cur_epoch=0, total_examples=None,
                 total_words=None, queue_factor=2, report_delay=1.0):
    """Run a single training epoch over several input data streams.

    Spawns one consumer (worker) thread per configured worker and one
    producer thread per stream in ``data_iterables``, all wired to a
    shared job queue, then delegates to ``self._log_epoch_progress``
    (which presumably blocks until the epoch completes — confirm against
    its implementation) and returns its ``(trained_word_count,
    raw_word_count, job_tally)`` result.
    """
    _reset_performance_metrics()

    # Shared queues: producers push training jobs, workers report progress back.
    job_queue = Queue(maxsize=queue_factor * self.workers)
    progress_queue = Queue(maxsize=(queue_factor + 1) * self.workers)

    # Consumer threads -- one per configured worker.
    threads = []
    for _ in xrange(self.workers):
        threads.append(threading.Thread(
            target=self._worker_loop,
            args=(job_queue, progress_queue,)))

    # Producer threads -- one per input data stream.
    for stream in data_iterables:
        threads.append(threading.Thread(
            target=self._job_producer,
            args=(stream, job_queue),
            kwargs={'cur_epoch': cur_epoch, 'total_examples': total_examples, 'total_words': total_words}))

    for t in threads:
        t.daemon = True  # make interrupting the process with ctrl+c easier
        t.start()

    # Collect and log progress until the epoch's work is drained.
    trained_word_count, raw_word_count, job_tally = self._log_epoch_progress(
        progress_queue, job_queue, cur_epoch=cur_epoch, total_examples=total_examples,
        total_words=total_words, report_delay=report_delay)

    return trained_word_count, raw_word_count, job_tally
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment