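Patch (gist ask/186785) against Python 2.6's Lib/multiprocessing/pool.py, used for celery's pool: it adds a third SimpleQueue, an ack queue. Each worker process acks a task as (job, i) before executing it, a dedicated ack handler thread in the parent consumes those acks, apply_async() grows an ack_callback argument, and ApplyResult gains an accepted() flag so callers can tell when a task has actually been picked up by a worker.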
--- Lib/multiprocessing/pool.py 2009-02-06 15:42:30.000000000 +0100
+++ /opt/devel/celery/celery/_pool.py 2009-09-14 19:24:36.000000000 +0200
@@ -42,9 +42,10 @@
 # Code run by worker processes
 #
 
-def worker(inqueue, outqueue, initializer=None, initargs=()):
+def worker(inqueue, outqueue, ackqueue, initializer=None, initargs=()):
     put = outqueue.put
     get = inqueue.get
+    ack = ackqueue.put
     if hasattr(inqueue, '_writer'):
         inqueue._writer.close()
         outqueue._reader.close()
@@ -64,6 +65,7 @@
             break
 
         job, i, func, args, kwds = task
+        ack((job, i))
         try:
             result = (True, func(*args, **kwds))
         except Exception, e:
@@ -84,6 +86,7 @@
         self._setup_queues()
         self._taskqueue = Queue.Queue()
         self._cache = {}
         self._state = RUN
 
         if processes is None:
@@ -96,7 +99,8 @@
for i in range(processes):
w = self.Process(
target=worker,
- args=(self._inqueue, self._outqueue, initializer, initargs)
+ args=(self._inqueue, self._outqueue, self._ackqueue,
+ initializer, initargs)
)
self._pool.append(w)
w.name = w.name.replace('Process', 'PoolWorker')
@@ -111,6 +115,14 @@
         self._task_handler._state = RUN
         self._task_handler.start()
 
+        self._ack_handler = threading.Thread(
+            target=Pool._handle_ack,
+            args=(self._ackqueue, self._ackqueue._reader.recv, self._cache)
+            )
+        self._ack_handler.daemon = True
+        self._ack_handler._state = RUN
+        self._ack_handler.start()
+
         self._result_handler = threading.Thread(
             target=Pool._handle_results,
             args=(self._outqueue, self._quick_get, self._cache)
@@ -119,9 +131,11 @@
         self._result_handler._state = RUN
         self._result_handler.start()
 
+
         self._terminate = Finalize(
             self, self._terminate_pool,
-            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
+            args=(self._taskqueue, self._inqueue, self._outqueue,
+                  self._ackqueue, self._pool, self._ack_handler,
                   self._task_handler, self._result_handler, self._cache),
             exitpriority=15
             )
@@ -130,6 +144,7 @@
         from multiprocessing.queues import SimpleQueue
         self._inqueue = SimpleQueue()
         self._outqueue = SimpleQueue()
+        self._ackqueue = SimpleQueue()
         self._quick_put = self._inqueue._writer.send
         self._quick_get = self._outqueue._reader.recv
 
@@ -183,12 +198,13 @@
                      for i, x in enumerate(task_batches)), result._set_length))
         return (item for chunk in result for item in chunk)
 
-    def apply_async(self, func, args=(), kwds={}, callback=None):
+    def apply_async(self, func, args=(), kwds={},
+                    callback=None, ack_callback=None):
         '''
         Asynchronous equivalent of `apply()` builtin
         '''
         assert self._state == RUN
-        result = ApplyResult(self._cache, callback)
+        result = ApplyResult(self._cache, callback, ack_callback=ack_callback)
         self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
         return result
 
@@ -251,14 +267,61 @@
         debug('task handler exiting')
 
     @staticmethod
+    def _handle_ack(ackqueue, get, cache):
+        thread = threading.current_thread()
+
+        while 1:
+            try:
+                task = get()
+            except (IOError, EOFError), exc:
+                debug('ack handler got %s -- exiting',
+                      exc.__class__.__name__)
+                return
+            if thread._state:
+                assert thread._state == TERMINATE
+                debug('ack handler found thread._state=TERMINATE')
+                break
+
+            if task is None:
+                debug('ack handler got sentinel')
+                break
+
+            job, i = task
+            try:
+                cache[job]._ack(i)
+            except KeyError:
+                pass
+
+        while cache and thread._state != TERMINATE:
+            try:
+                task = get()
+            except (IOError, EOFError), exc:
+                debug('ack handler got %s -- exiting',
+                      exc.__class__.__name__)
+                return
+
+            if task is None:
+                debug('ack handler ignoring extra sentinel')
+                continue
+            job, i = task
+            try:
+                cache[job]._ack(i)
+            except KeyError:
+                pass
+
+        debug('ack handler exiting: len(cache)=%s, thread._state=%s',
+              len(cache), thread._state)
+
+    @staticmethod
     def _handle_results(outqueue, get, cache):
         thread = threading.current_thread()
 
         while 1:
             try:
                 task = get()
-            except (IOError, EOFError):
-                debug('result handler got EOFError/IOError -- exiting')
+            except (IOError, EOFError), exc:
+                debug('result handler got %s -- exiting',
+                      exc.__class__.__name__)
                 return
 
             if thread._state:
@@ -279,8 +342,9 @@
         while cache and thread._state != TERMINATE:
             try:
                 task = get()
-            except (IOError, EOFError):
-                debug('result handler got EOFError/IOError -- exiting')
+            except (IOError, EOFError), exc:
+                debug('result handler got %s -- exiting',
+                      exc.__class__.__name__)
                 return
 
             if task is None:
@@ -351,14 +415,17 @@
             time.sleep(0)
 
     @classmethod
-    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
-                        task_handler, result_handler, cache):
+    def _terminate_pool(cls, taskqueue, inqueue, outqueue, ackqueue, pool,
+                        ack_handler, task_handler, result_handler, cache):
         # this is guaranteed to only be called once
         debug('finalizing pool')
 
         task_handler._state = TERMINATE
         taskqueue.put(None)                 # sentinel
 
+        ack_handler._state = TERMINATE
+        ackqueue.put(None)
+
         debug('helping task handler/workers to finish')
         cls._help_stuff_finish(inqueue, task_handler, len(pool))
 
@@ -378,6 +445,9 @@
         debug('joining result handler')
         result_handler.join(1e100)
 
+        debug('joining ack handler')
+        ack_handler.join(1e100)
+
         if pool and hasattr(pool[0], 'terminate'):
             debug('joining pool workers')
             for p in pool:
@@ -389,14 +459,19 @@
 class ApplyResult(object):
 
-    def __init__(self, cache, callback):
+    def __init__(self, cache, callback, ack_callback=None):
         self._cond = threading.Condition(threading.Lock())
         self._job = job_counter.next()
         self._cache = cache
         self._ready = False
+        self._accepted = False
         self._callback = callback
+        self._ack_callback = ack_callback
         cache[self._job] = self
 
+    def accepted(self):
+        return self._accepted
+
     def ready(self):
         return self._ready
 
     def successful(self):
@@ -421,6 +496,11 @@
         else:
             raise self._value
 
+    def _ack(self, i):
+        self._accepted = True
+        if self._ack_callback:
+            self._ack_callback()
+
     def _set(self, i, obj):
         self._success, self._value = obj
         if self._callback and self._success:
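
A minimal usage sketch (not part of the patch), assuming the patched pool is importable as celery._pool (the target path in the +++ header); the work() and on_ack() names are illustrative. The ack_callback fires in the parent's ack handler thread as soon as a worker takes the task off the inqueue, before the task runs, while ready()/get() still track the result itself:

    import time

    from celery._pool import Pool

    def work(x):
        return x * 2

    def on_ack():
        # Runs in the parent's ack handler thread, not in the worker.
        print 'task accepted by a worker'

    if __name__ == '__main__':
        pool = Pool(processes=2)
        result = pool.apply_async(work, (21,), ack_callback=on_ack)

        # accepted() turns True once a worker acks the task;
        # ready() only turns True once the result has come back.
        while not result.accepted():
            time.sleep(0.1)

        print result.get()
        pool.close()
        pool.join()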