# Created February 21, 2019 18:16
# Gist markns/4c1c834b9ea6678948fe168fe0a63752 (save it to your computer and use it in GitHub Desktop).
# Use of Python's ThreadPoolExecutor that handles cancellation and exceptions correctly.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import collections | |
import concurrent.futures | |
import os | |
import threading | |
import time | |
from random import randint | |
# Outcome of one worker task: counts of items copied and scaled, plus the
# item's name (wait_for() prints its os.path.basename()).
Result = collections.namedtuple("Result", ["copied", "scaled", "name"])

# Aggregate report for a run: number of tasks submitted, total items copied
# and scaled, and whether the run was canceled by the user.
Summary = collections.namedtuple(
    "Summary", ["todo", "copied", "scaled", "canceled"]
)
def main():
    """Run the thread-pool demo with four workers and print a report."""
    print("starting...")
    workers = 4
    outcome = scale(workers)
    summarize(outcome, workers)
def scale(concurrency, jobs=4):
    """Run *jobs* worker tasks on a thread pool and return their Summary.

    Every task runs ``process(pid, event)`` and shares one
    ``threading.Event``; ``wait_for`` sets that event to ask running workers
    to stop when the run is canceled.

    Args:
        concurrency: maximum number of worker threads in the pool.
        jobs: number of tasks to submit (default 4, the original hard-coded
            value).

    Returns:
        The Summary produced by ``wait_for``.
    """
    event = threading.Event()
    with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as executor:
        futures = {executor.submit(process, pid, event) for pid in range(jobs)}
        # Leaving this ``with`` block calls executor.shutdown(wait=True), so an
        # explicit shutdown is unnecessary even when the run was canceled.
        summary = wait_for(futures, event)
    return summary
def _stop_workers(futures, event):
    """Signal workers to stop via *event* and cancel futures not yet started."""
    event.set()
    for future in futures:
        # Future.cancel() only succeeds for futures that are not running or done.
        future.cancel()


def wait_for(futures, event):
    """Collect worker results as they complete and return a Summary.

    KeyboardInterrupt is caught here, rather than in a caller, so the
    copied/scaled counts accumulated so far are still returned in the
    Summary instead of being lost.

    Args:
        futures: the submitted futures; each is expected to resolve to a
            Result (an object with ``copied``, ``scaled`` and ``name``).
        event: the threading.Event shared with the workers; it is set to ask
            them to stop.

    Returns:
        Summary(len(futures), copied, scaled, canceled).

    Raises:
        Any exception raised by a worker. Before it propagates, *event* is
        set and not-yet-started futures are cancelled, so workers that poll
        the event can stop early instead of the pool shutdown waiting for
        all remaining work to finish.
    """
    canceled = False
    copied = scaled = 0
    try:
        for future in concurrent.futures.as_completed(futures):
            # result() re-raises whatever exception the worker raised.
            # Anticipated, recoverable worker errors could be caught around
            # this call and reported per item instead of aborting the run.
            result = future.result()
            copied += result.copied
            scaled += result.scaled
            print("{} {}".format("copied" if result.copied else
                                 "scaled", os.path.basename(result.name)))
    except KeyboardInterrupt:
        print("canceling...")
        canceled = True
        _stop_workers(futures, event)
    except Exception:
        # Unanticipated failure: stop the other workers before propagating.
        _stop_workers(futures, event)
        raise
    return Summary(len(futures), copied, scaled, canceled)
def process(pid, event: threading.Event, iterations: int = 100,
            delay: float = 0.1, fail_one_in: int = 10001):
    """Simulate a cancellable unit of work for worker *pid*.

    Runs up to *iterations* steps and stops early as soon as *event* is set.
    Each step fails with ``Exception('boom')`` with probability
    ``1 / fail_one_in``, to exercise the caller's error handling. Between
    steps it waits up to *delay* seconds.

    Args:
        pid: identifier printed when the work finishes.
        event: set by the caller to request cancellation.
        iterations: maximum number of work steps (default 100).
        delay: seconds to pause between steps (default 0.1).
        fail_one_in: one step in this many raises (default 10001, the same
            odds as the original ``randint(0, 10000) == 4``).

    Returns:
        Result(1, 0, 'blah'), a placeholder result.

    Raises:
        Exception: 'boom', the simulated random failure.
    """
    # The original loop broke on its first pass (0 % 100 == 0), so no work
    # was ever simulated; a bounded for-loop runs the intended steps.
    for _ in range(iterations):
        if event.is_set():
            break
        if randint(1, fail_one_in) == 1:
            raise Exception('boom')
        # Event.wait returns as soon as the event is set, so cancellation is
        # noticed immediately instead of after a full sleep.
        event.wait(delay)
    print(f'complete {pid}')
    return Result(1, 0, 'blah')
def summarize(summary, concurrency):
    """Print a one-line report of *summary*, followed by a blank line.

    The report lists copied and scaled counts, any tasks that produced
    neither ("skipped"), the thread count, and a "[canceled]" marker when
    the run was canceled.
    """
    pieces = ["copied {} scaled {} ".format(summary.copied, summary.scaled)]
    skipped = summary.todo - (summary.copied + summary.scaled)
    if skipped:
        pieces.append("skipped {} ".format(skipped))
    pieces.append("using {} threads".format(concurrency))
    if summary.canceled:
        pieces.append(" [canceled]")
    print("".join(pieces))
    print()
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script.
    main()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.