Created
November 14, 2018 17:39
-
-
Save mtik00/b03c9011d9c5bbdc4e6c8b1791e980f2 to your computer and use it in GitHub Desktop.
Thread pool in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
This sample script is an example use of the threading pool to process a | |
sequence of items. | |
This is both python2 and python3 compatible. | |
""" | |
from __future__ import print_function | |
import time | |
from random import SystemRandom | |
from threading import Thread | |
try: | |
from queue import Empty, Full, Queue | |
except ImportError: | |
from Queue import Empty, Full, Queue | |
def work_function(item):
    """
    Stand-in for real per-item work: sleep for a random duration of up to
    ten seconds so the demo pool has something to chew on.
    """
    delay = SystemRandom().random() * 10
    time.sleep(delay)
class WorkerThread(Thread):
    """
    A single worker thread that keeps pulling items off the shared queue
    and processing them until asked to stop.
    """

    def __init__(self, queue, *args, **kwargs):
        super(WorkerThread, self).__init__(*args, **kwargs)
        self.queue = queue
        # Cooperative shutdown flag, set via ``stop_processing()``.
        self._stop_processing = False
        # Number of items this particular worker has finished.
        self.processed_commit_count = 0

    def __str__(self):
        # Workers identify themselves in log output by their Thread name.
        return self.name

    def stop_processing(self):
        """Ask this worker to exit its run loop after the current item."""
        self._stop_processing = True

    def run(self):
        """
        Main worker loop: pop items from the queue and process them until
        ``stop_processing()`` is called.
        """
        print('%s starting' % self)
        while True:
            if self._stop_processing:
                break
            try:
                item = self.queue.get_nowait()
            except Empty:
                # Nothing available right now; back off briefly and retry.
                print(u'%s: !!! STARVING !!!' % self)
                time.sleep(0.5)
            else:
                self.process_item(item)
        print(u'%s: thread complete' % self)

    def process_item(self, item):
        """
        Process a single item via ``work_function``, with logging around
        the call so progress is visible.
        """
        print(u'%s: doing work on %s' % (self, item))
        work_function(item)
        self.processed_commit_count += 1
        print(u'%s: done doing work on %s' % (self, item))
class ThreadPool(object):
    """
    Create and coordinate a fixed pool of ``WorkerThread`` instances that
    drain a shared, bounded queue of items.
    """

    def __init__(self, threads=4, queue_size_factor=2):
        """
        :param int threads: The total number of threads to create.
        :param int queue_size_factor: We use a limited Queue size to keep
            memory resources low; the queue's ``maxsize`` is
            ``threads * queue_size_factor``.
        """
        self.num_threads = threads
        self._stop_processing = False
        self.queue = Queue(maxsize=threads * queue_size_factor)
        self.threads = []
        self.queued_item_count = 0

    def run(self):
        """
        Spin up our threads, populate the queue, wait for the work to be
        finished, join our threads, and make a quick check to see if we
        missed anything.
        """
        self.threads = tuple(WorkerThread(queue=self.queue) for _ in range(self.num_threads))
        for thread in self.threads:
            thread.start()
        print(u'%d worker threads started' % len(self.threads))
        # These are blocking operations
        self.populate_queue()
        self.wait_for_empty_queue()
        # At this point, all of the items have been assigned to workers.
        # It's safe to tell the threads to stop when they're done processing
        # their current item.
        self.stop()
        # Now it's safe to ``join`` our threads.
        self.join()
        # Make sure we didn't screw anything up
        self.check_count()

    def populate_queue(self):
        """
        *Slowly* put items in to the queue. Keep trying until we don't have
        any more items.

        We have a limited sized queue with a *limitless* number of items.
        Therefore, we keep trying to add more items to the queue if the
        queue tells us it's full.
        """
        items = range(20)  # Fake list of items
        for item in items:
            while not self._stop_processing:
                try:
                    self.queue.put(item, block=True, timeout=1)
                    print(u'added %r to queue' % item)
                    self.queued_item_count += 1
                    break
                except Full:
                    print(u"can't add %r to queue (Full)" % item)
                    time.sleep(0.1)
                except BaseException:
                    # Stop the workers before propagating so we don't leave
                    # threads spinning on an abandoned queue. BaseException
                    # (not bare ``except:``) makes the intent explicit while
                    # still covering KeyboardInterrupt/SystemExit.
                    self.stop()
                    raise
        print(u'all %d items added to the queue' % self.queued_item_count)

    def stop(self):
        """
        This can be used to ask all of the threads to not grab another item
        from the queue. This does not kill the threads!
        """
        self._stop_processing = True
        for thread in self.threads:
            thread.stop_processing()

    def wait_for_empty_queue(self):
        """
        Wait for the queue to empty out or someone asks us to stop.
        """
        # NOTE: the original wrapped this loop in a ``try/except: raise``
        # that was a no-op; any exception simply propagates.
        while not (self.queue.empty() or self._stop_processing):
            time.sleep(1)

    def join(self):
        """
        We must `join` our threads or they may not finish up what they were
        doing. Be nice!
        """
        print(u'joining threads')
        for thread in self.threads:
            thread.join()
        print(u'...finished joining threads')

    def check_count(self):
        """
        This is a very simple test to see whether the threads *think* they
        processed all of the items.

        :raises Exception: if the per-thread totals don't add up to the
            number of items that were queued.
        """
        total_processed = sum(thread.processed_commit_count for thread in self.threads)
        if total_processed != self.queued_item_count:
            print(u'# queued items: %d' % self.queued_item_count)
            print(u'# processed items: %d' % total_processed)
            print(u'# queued items != # processed items')
            raise Exception(u'# queued items != # processed items')
        print(u'item counts are equal')
def main():
    """Run a demo pool and report how long the whole batch took."""
    started = time.time()
    pool = ThreadPool(threads=5)
    pool.run()
    # Break the elapsed seconds down into days/hours/minutes/seconds.
    elapsed = time.time() - started
    minutes, seconds = divmod(elapsed, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    total = "%d days, %d:%02d:%02d" % (days, hours, minutes, seconds)
    print(u'processed %d items in %s' % (pool.queued_item_count, total))
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment