Last active
August 29, 2015 14:13
-
-
Save kcsaff/6bc177f2531f43082d0b to your computer and use it in GitHub Desktop.
Noisy threadpool implementation that automatically handles multiple attempts & such.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import queue | |
import threading | |
from contextlib import contextmanager | |
from collections import Counter | |
import time | |
import traceback | |
class Warn(object):
    """
    Class to automatically elide repeated messages or warnings.  Since this keeps a count for every
    distinct message received, it is best used for a small set of possible warning messages
    -- but see below, it's at least smart enough that you can ignore parameterization if you want.

    `Warn` can take any print function; in this example we'll just store the warnings in a list.

    >>> warnings = list()
    >>> warn = Warn(warnings.append)

    We can tell `Warn` to print a message only once

    >>> warn(1, 'This is a repeated message')
    True

    And we'll find that it's been printed (and `Warn` told us so, by responding `True`).

    >>> warnings
    ['This is a repeated message']

    If we tell it to print that again...

    >>> warn(1, "This is a repeated message")
    False

    It won't be printed again -- only the once.

    >>> warnings
    ['This is a repeated message']

    So now let's print a parameterized warning message.

    >>> del warnings[:]
    >>> warn(1, 'Critical error {errno}!!!', errno=32451)
    True
    >>> warnings
    ['Critical error 32451!!!']

    If we try this again with a different parameter, it will be ignored, because only the format
    string counts.

    >>> warn(1, 'Critical error {errno}!!!', errno=77777)
    False
    >>> warnings
    ['Critical error 32451!!!']

    When a message may appear several times (`count` >= `final_warning_count_cutoff`), its last
    permitted appearance is tagged as the final one -- parameterized messages included.

    >>> del warnings[:]
    >>> warn(2, 'Retrying {task}', task='a')
    True
    >>> warn(2, 'Retrying {task}', task='b')
    True
    >>> warnings
    ['Retrying a', 'Retrying b: FINAL WARNING']

    If you want it to treat these cases differently, you just need to format before handing it
    over to `Warn`.
    """

    def __init__(self, print_function=print, final_warning_count_cutoff=2):
        """
        :param print_function: callable that receives each message string that is let through.
        :param final_warning_count_cutoff: a message whose allowed `count` is at least this value
            gets a ': FINAL WARNING' suffix on its last permitted appearance.
        """
        self._lock = threading.RLock()
        self._final_warning_count_cutoff = final_warning_count_cutoff
        # Occurrence counts, keyed by the *unformatted* message / format string.
        self._warnings = Counter()
        self._print = print_function

    def __call__(self, count, warning, *args, **kwargs):
        """Print `warning` unless it has already been seen more than `count` times.

        :param count: maximum number of times this message may be printed.
        :param warning: the message, or a `str.format` template when `args`/`kwargs` are given.
        :param args: positional format arguments for `warning`.
        :param kwargs: keyword format arguments for `warning`.
        :return: True if the message was printed, False if it was suppressed.
        """
        with self._lock:
            self._warnings[warning] += 1
            # Capture the count under the template key *before* formatting.  The formatted text
            # is never a key in `_warnings` (a Counter returns 0 for missing keys), so reading it
            # afterwards meant the final-warning tag never fired for parameterized messages.
            seen = self._warnings[warning]
            if seen > count:
                return False
            if args or kwargs:
                warning = warning.format(*args, **kwargs)
            # Chained comparison: last allowed appearance AND count large enough to warrant a tag.
            if seen == count >= self._final_warning_count_cutoff:
                self._print('{0}: FINAL WARNING'.format(warning))
            else:
                self._print(warning)
            return True
class TaskSummary(object):
    """Keeps track of task status in a thread-safe way.

    This verbosely prints task status including percentage complete and ETA in a thread-safe way,
    so you know where your process is at in its long slog towards completion.

    This can be configured to:

    - log every N tasks submitted or completed
    - log every N seconds

    The name & message can be customized.

    Here we see a `TaskSummary` with `milestone` set to 10, meaning it will only report every 10
    submissions or completions.

    >>> reports = list()
    >>> ts = TaskSummary(milestone=10, maxtime=None, message='{finished}/{submitted}', print_function=reports.append)
    >>> ts.submit(4)
    >>> reports
    []
    >>> ts.submit(7)
    >>> reports
    ['00/11']
    >>> ts.submit(8)
    >>> ts.finish(7)
    >>> reports
    ['00/11']
    >>> ts.finish(5)
    >>> reports
    ['00/11', '12/19']
    """

    DEFAULT_MESSAGE = '{name} {finished}/{submitted} ({finish_percentage}%) in {elapsed}. ETA: T-{eta}'

    def __init__(self, milestone=1000, maxtime=10.0, name='Processed', message=None, print_function=print, time_function=time.time, report_on_completion=True):
        """
        :param milestone: maximum number of submissions or completions between reports
            (`None` effectively disables count-based reports)
        :param maxtime: maximum number of seconds between reports (`None` effectively disables
            time-based reports); the clock is only checked when a submission or completion is recorded
        :param name: name to use in reports
        :param message: report format: may include `name`, `finished`, `submitted`, `finish_percentage`, `elapsed`, `eta`,
            `succeeded`, `failed`, `succeed_percentage`, `fail_percentage`
        :param print_function: where to print this stuff (`None` discards reports)
        :param time_function: function returning number of elapsed seconds
        :param report_on_completion: also report whenever the finished count catches up with the submitted count
        """
        self._lock = threading.RLock()
        self._submitted = 0
        self._finished = 0
        self._succeeded = 0
        self._failed = 0
        # Very large stand-ins for "never", so the milestone comparisons need no None checks.
        self._milestone = milestone if milestone is not None else 1e12
        self._time = time_function
        self._first_time = self._time()
        self._maxtime = maxtime if maxtime is not None else 1e9
        self._name = name
        self._message = message or self.DEFAULT_MESSAGE
        self._print = print_function or _null_print
        self._report_on_completion = report_on_completion
        # Goals
        # Created as `None` here, but corrected in `_update_goals()`
        self._next_submit_milestone = None
        self._next_finish_milestone = None
        self._next_time = None
        self._update_goals()

    @classmethod
    def quiet(cls, milestone=None, maxtime=None, name='Processed', message=None, print_function=None, time_function=time.time, report_on_completion=False):
        """Return a very quiet version of TaskSummary.

        Use this if you don't want your VerboseThreadPool to be very verbose, after all.
        """
        return cls(
            milestone=milestone,
            maxtime=maxtime,
            name=name,
            message=message,
            print_function=print_function,
            time_function=time_function,
            report_on_completion=report_on_completion
        )

    def submit(self, count=1):
        """Indicate additional task(s) added to queue.

        :param count: rough number of tasks added. Alternatively, mark as > 1 to indicate addition of "harder" task.
        """
        with self._update_context():
            self._submitted += count

    def finish(self, count=1, success=True):
        """Indicate task(s) completed.

        :param count: rough number of tasks completed. Alternatively, mark as > 1 to indicate completion of "harder" task.
        :param success: whether the task(s) succeeded (counted in `succeeded`) or failed (counted in `failed`).
        """
        with self._update_context():
            self._finished += count
            if success:
                self._succeeded += count
            else:
                self._failed += count

    def report(self):
        """Print current status and schedule the next report."""
        with self._lock:
            self._print(self.get_status())
            # Reuse the single definition of the next-report targets instead of duplicating it.
            self._update_goals()

    def get_status(self):
        """Get current formatted status as a string"""
        with self._lock:
            percentage = 100 * self._finished // self._submitted if self._submitted else 0
            time_seconds = int(self._time() - self._first_time)
            time_amount = '{0}:{1:02}'.format(time_seconds // 60, time_seconds % 60)
            if self._finished:
                # Linear extrapolation of remaining time.  Clamp at zero: if more completions than
                # submissions were recorded, the raw estimate is negative and would render as
                # nonsense like "-1:30".
                time_eta_seconds = max(0, time_seconds * self._submitted // self._finished - time_seconds)
            else:
                time_eta_seconds = 0
            time_eta = '{0}:{1:02}'.format(time_eta_seconds // 60, time_eta_seconds % 60)
            # Zero-pad the counts to the width of `submitted` so successive reports line up.
            len_submitted = len(str(self._submitted))
            formatted_finished = str(self._finished).zfill(len_submitted)
            formatted_succeeded = str(self._succeeded).zfill(len_submitted)
            formatted_failed = str(self._failed).zfill(len_submitted)
            succeed_percentage = 100 * self._succeeded // self._submitted if self._submitted else 0
            fail_percentage = 100 * self._failed // self._submitted if self._submitted else 0
            return self._message.format(
                name=self._name,
                finished=formatted_finished,
                submitted=self._submitted,
                finish_percentage=percentage,
                elapsed=time_amount,
                eta=time_eta,
                succeeded=formatted_succeeded,
                failed=formatted_failed,
                succeed_percentage=succeed_percentage,
                fail_percentage=fail_percentage,
            )

    def _update_goals(self):
        """Update targets for next report (call with `self._lock` held, or during construction)."""
        self._next_submit_milestone = self._submitted + self._milestone
        # Never aim past the current submission count, so finishing everything submitted so far
        # can trigger a report.
        self._next_finish_milestone = min((self._finished + self._milestone, self._submitted))
        self._next_time = self._time() + self._maxtime

    @contextmanager
    def _update_context(self):
        """context manager to make sure we print reports when an update suggests we should."""
        with self._lock:
            old_finished = self._finished
            old_submitted = self._submitted
            yield
            # Report if this update crossed a finish/submit milestone, completed all submitted
            # work (when enabled), or the report interval has elapsed.
            if (old_finished < self._next_finish_milestone <= self._finished) \
                    or (old_submitted < self._next_submit_milestone <= self._submitted) \
                    or (self._report_on_completion and self._finished >= self._submitted) \
                    or (self._time() >= self._next_time):
                self.report()
class VerboseThreadPool(object):
    """Thread pool that very noisily keeps you updated on its progress.

    This is intended to be used when you have a process that creates a number of tasks, and then
    waits for all those tasks to be completed by a pool of worker threads.  It keeps you updated on
    how much work it's done so far, and how much it thinks it still needs to do.

    As a bonus, it can automatically attempt tasks multiple times, failing them only after a
    maximum number of attempts have been made.

    This can be used as a context manager: if so, it will wait for all tasks to complete on exit.

    >>> outputs = list()
    >>> with VerboseThreadPool("test pool", 10, print_function=None) as pool:
    ...     for i in range(100):
    ...         pool.submit(3, outputs.append, i)
    >>> len(outputs)
    100
    >>> sorted(outputs)[:10]
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """

    def __init__(self, name, max_workers, worker_timeout=1.0, summary_milestone=1000, summary_time=10.0, summary='auto', print_function=print):
        """
        :param name: The name to be used in messages and for identifying threads.
        :param max_workers: The maximum number of worker threads to create (must be at least 1).
        :param worker_timeout: Time after which a worker dies if no new tasks are available.
        :param summary_milestone: Number of submissions or completions after which to report status
        :param summary_time: Time in seconds after which to report status
        :param summary: full `TaskSummary` object that, if present, overrides the convenience parameters
            `summary_milestone` and `summary_time`.  `'auto'` builds one from those parameters;
            `None` uses `TaskSummary.quiet()`.
        :param print_function: where to print status messages and warnings; `None` silences them.
        :raises ValueError: if `max_workers` is less than 1 (no task could ever run, and
            `shutdown()` would wait forever).
        """
        if max_workers < 1:
            raise ValueError('max_workers must be at least 1, got {0!r}'.format(max_workers))
        self.lock = threading.RLock()
        self.name = name
        self.max_workers = max_workers
        self.workers = set()
        self.queue = queue.Queue()
        self._print = print_function or _null_print
        # Route repeated-failure warnings through the same printer, so that
        # `print_function=None` silences them too (previously they always went to stdout).
        self.warn = Warn(self._print)
        if summary == 'auto':
            self.summary = TaskSummary(name=self.name, milestone=summary_milestone, maxtime=summary_time, print_function=print_function)
        elif summary is None:
            self.summary = TaskSummary.quiet()
        else:
            self.summary = summary
        self.worker_timeout = worker_timeout
        self._is_shutdown = False
        self._workers_created = 0

    def submit(self, max_attempts, fun, *args, **kwargs):
        """Submit a task to be completed on the threadpool.

        :param max_attempts: Maximum number of times to attempt this task: can be just 1 if you don't want to reattempt on exceptions.
        :param fun: Function to call
        :param args: positional arguments to apply to function
        :param kwargs: keyword arguments to apply to function
        :raises RuntimeError: if the pool has already been shut down.
        """
        with self.lock:
            if self._is_shutdown:
                raise RuntimeError('Cannot submit tasks to a shutdown threadpool.')
            self._do_submit(0, max_attempts, fun, args, kwargs)

    def __enter__(self):
        """Enter contextmanager."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Exit contextmanager, shutting down the threadpool & waiting for all tasks to complete or fail.

        Any exception raised in the `with` body propagates after the wait.
        """
        self.shutdown()

    def shutdown(self, wait=True):
        """Shutdown threadpool & no longer accept any new tasks.

        Retries of already-submitted tasks are still queued and run.

        :param wait: This must be true and indicates that we wait for all tasks to complete before continuing.
        :raises NotImplementedError: if `wait` is false.
        """
        if not wait:
            raise NotImplementedError
        with self.lock:
            self._is_shutdown = True
        self._print('Waiting for `{0}` to complete.'.format(self.name))
        self.summary.report()
        self.queue.join()
        self._print('Tasks complete.')
        self.summary.report()

    def _do_submit(self, attempts, max_attempts, fun, args, kwargs):
        """Actually put a task on the task queue, starting a worker thread if one seems needed.

        :param attempts: number of attempts already made (0 for a fresh submission).
        """
        with self.lock:
            # Start a worker if there are none, or if work is already backing up -- up to the cap.
            if (not self.workers or not self.queue.empty()) and len(self.workers) < self.max_workers:
                self._workers_created += 1
                name = '{0} Worker #{1}/{2}'.format(self.name, self._workers_created, self.max_workers)
                # Yes, we can have more workers created than max workers if some old workers have retired
                new_thread = threading.Thread(target=self._worker, name=name)
                self.workers.add(new_thread)
                new_thread.daemon = True
                new_thread.start()
            if not attempts:
                # Retries are not new work, so only first attempts count as submissions.
                self.summary.submit()
            self.queue.put((fun, args, kwargs, attempts, max_attempts))

    def _worker(self):
        """Worker thread body: run queued tasks until no task arrives for `worker_timeout` seconds.

        A failing task is re-queued until it has been attempted `max_attempts` times.
        """
        me = threading.current_thread()
        try:
            while True:
                try:
                    # Get a task
                    fun, args, kwargs, attempts, max_attempts = self.queue.get(True, self.worker_timeout)
                except queue.Empty:
                    # Decide to retire while holding the same lock `_do_submit` holds when it
                    # inspects `self.workers` and enqueues.  Otherwise a task enqueued after our
                    # timeout but before we left `self.workers` would see a live-looking worker,
                    # start no new one, and be stranded -- making `queue.join()` hang forever.
                    with self.lock:
                        if self.queue.empty():
                            self.workers.discard(me)
                            return
                    # Work arrived during the race window: keep going.
                    continue
                try:
                    if not fun:
                        # A falsy `fun` ends this worker.
                        return
                    try:  # do the task
                        fun(*args, **kwargs)
                    except Exception as err:
                        # warn and put back on queue unless max attempts exceeded
                        self.warn(3, str(err))
                        attempts += 1
                        if attempts >= max_attempts:
                            self._print('Max attempts exceeded!')
                            self._print(err)
                            self._print(traceback.format_exc())
                            self.summary.finish(success=False)
                        else:
                            # Re-queue *before* task_done() below, so the queue's unfinished count
                            # never reaches zero while a retry is still pending.
                            self._do_submit(attempts, max_attempts, fun, args, kwargs)
                    else:  # The task is done forever
                        self.summary.finish(success=True)
                finally:
                    self.queue.task_done()  # Tell the queue the task is done as well.
        finally:
            # Covers every other exit path (sentinel task, unexpected error); `discard` because the
            # idle-retirement path above has already removed us.
            with self.lock:
                self.workers.discard(me)
def _null_print(*args, **kwargs): | |
pass | |
if __name__ == '__main__':
    # Running this file as a script executes the doctest examples in its docstrings.
    from doctest import testmod
    testmod()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment