@ask
Created September 15, 2009 12:26
Index: Lib/multiprocessing/pool.py
===================================================================
--- Lib/multiprocessing/pool.py (revision 74797)
+++ Lib/multiprocessing/pool.py (working copy)
@@ -12,11 +12,14 @@
# Imports
#
+import os
+import errno
import threading
import Queue
import itertools
import collections
import time
+from signal import signal, SIGUSR1
from multiprocessing import Process, cpu_count, TimeoutError
from multiprocessing.util import Finalize, debug
@@ -42,9 +45,21 @@
# Code run by worker processes
#
-def worker(inqueue, outqueue, initializer=None, initargs=()):
+class TimeLimitExceeded(Exception):
+ """The time limit has been exceeded and the job has been terminated."""
+
+class SoftTimeLimitExceeded(Exception):
+ """The soft time limit has been exceeded. This exception
+ is raised to give the job a chance to clean up."""
+
+def soft_timeout_sighandler(signum, frame):
+ raise SoftTimeLimitExceeded()
+
+def worker(inqueue, outqueue, ackqueue, initializer=None, initargs=()):
+ pid = os.getpid()
put = outqueue.put
get = inqueue.get
+ ack = ackqueue.put
if hasattr(inqueue, '_writer'):
inqueue._writer.close()
outqueue._reader.close()
@@ -52,6 +67,8 @@
if initializer is not None:
initializer(*initargs)
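+ # Install the SIGUSR1 handler so a soft timeout raises
+ # SoftTimeLimitExceeded inside the currently running job.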
+ signal(SIGUSR1, soft_timeout_sighandler)
+
while 1:
try:
task = get()
@@ -64,6 +81,7 @@
break
job, i, func, args, kwds = task
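+ # Send an ack so the pool knows which worker took the job and when.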
+ ack((job, i, time.time(), pid))
try:
result = (True, func(*args, **kwds))
except Exception, e:
@@ -80,9 +98,14 @@
'''
Process = Process
- def __init__(self, processes=None, initializer=None, initargs=()):
+ def __init__(self, processes=None, initializer=None, initargs=(),
+ timeout=None, soft_timeout=None):
self._setup_queues()
self._taskqueue = Queue.Queue()
+ self.timeout = timeout
+ self.soft_timeout = soft_timeout
+ self._initializer = initializer
+ self._initargs = initargs
self._cache = {}
self._state = RUN
@@ -95,16 +118,7 @@
if initializer is not None and not hasattr(initializer, '__call__'):
raise TypeError('initializer must be a callable')
- self._pool = []
- for i in range(processes):
- w = self.Process(
- target=worker,
- args=(self._inqueue, self._outqueue, initializer, initargs)
- )
- self._pool.append(w)
- w.name = w.name.replace('Process', 'PoolWorker')
- w.daemon = True
- w.start()
+ self._pool = [self._add_worker() for i in range(processes)]
self._task_handler = threading.Thread(
target=Pool._handle_tasks,
@@ -114,6 +128,31 @@
self._task_handler._state = RUN
self._task_handler.start()
+ # Thread processing acknowledgements from the ackqueue.
+ self._ack_handler = threading.Thread(
+ target=Pool._handle_ack,
+ args=(self._ackqueue, self._quick_get_ack, self._cache)
+ )
+ self._ack_handler.daemon = True
+ self._ack_handler._state = RUN
+ self._ack_handler.start()
+
+ # Thread killing timed-out jobs.
+ if self.timeout or self.soft_timeout:
+ self._timeout_handler_stopped = threading.Event()
+ self._timeout_handler = threading.Thread(
+ target=Pool._handle_timeouts,
+ args=(self, self._timeout_handler_stopped, self._cache,
+ self.soft_timeout, self.timeout)
+ )
+ self._timeout_handler.daemon = True
+ self._timeout_handler._state = RUN
+ self._timeout_handler.start()
+ else:
+ self._timeout_handler_stopped = None
+ self._timeout_handler = None
+
+ # Thread processing results in the outqueue.
self._result_handler = threading.Thread(
target=Pool._handle_results,
args=(self._outqueue, self._quick_get, self._cache)
@@ -124,17 +163,37 @@
self._terminate = Finalize(
self, self._terminate_pool,
- args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
- self._task_handler, self._result_handler, self._cache),
+ args=(self._taskqueue, self._inqueue, self._outqueue,
+ self._ackqueue, self._pool, self._ack_handler,
+ self._task_handler, self._result_handler, self._cache,
+ self._timeout_handler,
+ self._timeout_handler_stopped),
exitpriority=15
)
+
+ def _add_worker(self):
+ """Add another worker to the pool."""
+ w = self.Process(
+ target=worker,
+ args=(self._inqueue, self._outqueue, self._ackqueue,
+ self._initializer, self._initargs)
+ )
+ w.name = w.name.replace('Process', 'PoolWorker')
+ w.daemon = True
+ w.start()
+ return w
+
+ def grow(self, n=1):
+ """Grow the pool by adding n new workers."""
+ self._pool.extend([self._add_worker() for i in range(n)])
+
def _setup_queues(self):
from .queues import SimpleQueue
self._inqueue = SimpleQueue()
self._outqueue = SimpleQueue()
+ self._ackqueue = SimpleQueue()
self._quick_put = self._inqueue._writer.send
self._quick_get = self._outqueue._reader.recv
+ self._quick_get_ack = self._ackqueue._reader.recv
def apply(self, func, args=(), kwds={}):
'''
@@ -186,12 +245,25 @@
for i, x in enumerate(task_batches)), result._set_length))
return (item for chunk in result for item in chunk)
- def apply_async(self, func, args=(), kwds={}, callback=None):
+ def apply_async(self, func, args=(), kwds={},
+ callback=None, accept_callback=None):
'''
- Asynchronous equivalent of `apply()` builtin
+ Asynchronous equivalent of `apply()` builtin.
+
+ The callback is called when the function's return value is ready.
+ The accept callback is called when the job is accepted for execution.
+
+ Simplified, the flow is like this:
+
+ >>> if accept_callback:
+ ... accept_callback()
+ >>> retval = func(*args, **kwds)
+ >>> if callback:
+ ... callback(retval)
+
'''
assert self._state == RUN
- result = ApplyResult(self._cache, callback)
+ result = ApplyResult(self._cache, callback, accept_callback)
self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
return result
@@ -240,7 +312,6 @@
else:
debug('task handler got sentinel')
-
try:
# tell result handler to finish when cache is empty
debug('task handler sending sentinel to result handler')
@@ -256,14 +327,138 @@
debug('task handler exiting')
@staticmethod
+ def _handle_timeouts(pool, sentinel_event, cache, t_soft, t_hard):
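+ """Signal jobs that exceed the soft limit; terminate jobs past the hard limit."""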
+ thread = threading.current_thread()
+ processes = pool._pool
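+ # Job ids that have already been sent the soft timeout signal.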
+ dirty = set()
+
+ def _process_by_pid(pid):
+ for index, process in enumerate(processes):
+ if process.pid == pid:
+ return process, index
+ return (None, None)
+
+ def _pop_by_pid(pid):
+ process, index = _process_by_pid(pid)
+ if not process:
+ return
+ p = processes.pop(index)
+ assert p is process
+ return process
+
+ def _timed_out(start, timeout):
+ if not start or not timeout:
+ return False
+ return time.time() >= start + timeout
+
+ def _on_soft_timeout(job, i):
+ debug('soft time limit exceeded for %i', i)
+ process, _index = _process_by_pid(job._accept_pid)
+ if not process:
+ return
+
+ try:
+ os.kill(job._accept_pid, SIGUSR1)
+ except OSError, exc:
+ if exc.errno == errno.ESRCH:
+ pass
+ else:
+ raise
+
+ dirty.add(i)
+
+ def _on_hard_timeout(job, i):
+ debug('hard time limit exceeded for %i', i)
+ # Remove from _pool
+ process = _pop_by_pid(job._accept_pid)
+ # Remove from cache and set return value to an exception.
+ job._set(i, (False, TimeLimitExceeded()))
+ if not process:
+ return
+ # Terminate the process and create a new one.
+ process.terminate()
+ pool.grow(1)
+
+ # Inner loop
+ while 1:
+ if sentinel_event.isSet():
+ debug('timeout handler received sentinel.')
+ break
+
+ # Remove dirty items not in cache anymore.
+ if dirty:
+ dirty = set(k for k in dirty if k in cache)
+
+ for i, job in cache.items():
+ ack_time = job._time_accepted
+ if _timed_out(ack_time, t_hard):
+ _on_hard_timeout(job, i)
+ elif i not in dirty and _timed_out(ack_time, t_soft):
+ _on_soft_timeout(job, i)
+
+ time.sleep(1) # Don't waste CPU cycles.
+
+ debug('timeout handler exiting')
+
+ @staticmethod
+ def _handle_ack(ackqueue, get, cache):
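+ """Record job acknowledgements sent by the worker processes."""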
+ thread = threading.current_thread()
+
+ while 1:
+ try:
+ task = get()
+ except (IOError, EOFError), exc:
+ debug('ack handler got %s -- exiting',
+ exc.__class__.__name__)
+ return
+
+ if thread._state:
+ assert thread._state == TERMINATE
+ debug('ack handler found thread._state=TERMINATE')
+ break
+
+ if task is None:
+ debug('ack handler got sentinel')
+ break
+
+ job, i, time_accepted, pid = task
+ try:
+ cache[job]._ack(time_accepted, pid)
+ except (KeyError, AttributeError):
+ # Object gone, or doesn't support _ack (e.g. IMapIterator)
+ pass
+
+ while cache and thread._state != TERMINATE:
+ try:
+ task = get()
+ except (IOError, EOFError), exc:
+ debug('ack handler got %s -- exiting',
+ exc.__class__.__name__)
+ return
+
+ if task is None:
+ debug('ack handler ignoring extra sentinel')
+ continue
+
+ job, i, time_accepted, pid = task
+ try:
+ cache[job]._ack(time_accepted, pid)
+ except (KeyError, AttributeError):
+ pass
+
+ debug('ack handler exiting: len(cache)=%s, thread._state=%s',
+ len(cache), thread._state)
+
+ @staticmethod
def _handle_results(outqueue, get, cache):
thread = threading.current_thread()
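
For reference, here is a minimal usage sketch of what the patch enables,
assuming it is applied on top of Python 2.6's multiprocessing. The
timeout, soft_timeout and accept_callback parameters and the
SoftTimeLimitExceeded exception come from the patch above; the job
function, the URL and the limit values are made up for illustration.

    # Hypothetical usage sketch (Python 2.x), assuming the patch is applied.
    from multiprocessing.pool import Pool, SoftTimeLimitExceeded

    def fetch_feed(url):
        try:
            # ... potentially long-running work ...
            return 'parsed %s' % url
        except SoftTimeLimitExceeded:
            # Raised inside the job (via SIGUSR1) when soft_timeout expires,
            # giving the job a chance to clean up before the hard limit hits.
            return None

    def on_accept(*args):
        # Invoked through the ack handler when a worker picks the job up.
        print 'job accepted:', args

    if __name__ == '__main__':
        # Soft limit of 5s, hard limit of 10s per job. A job that exceeds
        # the hard limit has its worker terminated and replaced, and
        # result.get() then raises TimeLimitExceeded.
        pool = Pool(processes=2, soft_timeout=5, timeout=10)
        result = pool.apply_async(fetch_feed, ('http://example.com/feed',),
                                  accept_callback=on_accept)
        print result.get()
        pool.close()
        pool.join()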