Skip to content

Instantly share code, notes, and snippets.

@JevinJ
Last active March 27, 2022 07:56
Show Gist options
  • Save JevinJ/b7163af961acf92f50b6d0d6efb39daf to your computer and use it in GitHub Desktop.
Save JevinJ/b7163af961acf92f50b6d0d6efb39daf to your computer and use it in GitHub Desktop.
An interruptible/pausable thread pool in Python. A KeyboardInterrupt stops the pool; it can be caught to restart the pool where it left off.
import functools
import os
import queue
import random
import signal
import threading
import time
class Worker(threading.Thread):
    """Daemon thread that drains a shared queue of zero-argument callables.

    The return value of every task is appended to the shared ``results``
    list. The worker exits when the queue is empty or once
    ``shutdown_flag`` is set; the flag is only checked between tasks, so a
    task that is already executing always runs to completion.
    """

    def __init__(self, tasks, results):
        """Bind the worker to its task queue and result list.

        Args:
            tasks: Queue object providing ``get_nowait()`` and
                ``task_done()`` whose items are zero-argument callables.
            results: List that task return values are appended to.
        """
        super().__init__()
        self.tasks = tasks
        self.results = results
        self.shutdown_flag = threading.Event()
        self.daemon = True

    def run(self):
        """Execute queued tasks until the queue is empty or shutdown is requested."""
        while not self.shutdown_flag.is_set():
            # Guard only the dequeue: a task that itself raises queue.Empty
            # must not be mistaken for "no more work".
            try:
                task = self.tasks.get_nowait()
            except queue.Empty:
                break
            try:
                self.results.append(task())
            finally:
                # Balance every get() with task_done(), even when the task
                # raises, so the queue's unfinished-task count stays correct.
                self.tasks.task_done()
class ThreadPool:
    """Interruptible thread pool with a ``multiprocessing``-style ``map``.

    SIGINT and SIGTERM stop the workers and surface as ``KeyboardInterrupt``.
    Tasks that no worker has picked up yet stay in the queue, so the caller
    can catch the exception and continue the pool afterwards. Results
    gathered so far remain in ``self.results`` until the next ``reset()``.
    """

    def __init__(self):
        """Install the signal handlers and create empty task/result state."""
        # Each new pool replaces the process-wide SIGINT/SIGTERM handlers.
        signal.signal(signal.SIGINT, self.interrupt_event)
        signal.signal(signal.SIGTERM, self.interrupt_event)
        self.tasks = queue.Queue()
        self.threads = []
        self.results = []

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.stop()

    def interrupt_event(self, signum, stack):
        """Signal handler: stop the workers, then raise ``KeyboardInterrupt``."""
        self.stop()
        raise KeyboardInterrupt

    def start(self):
        """Start every worker thread created by ``reset()``."""
        for thread in self.threads:
            thread.start()

    def stop(self):
        """Ask every worker to finish its current task and wait for it to exit."""
        # Flag all workers first so they wind down concurrently.
        for thread in self.threads:
            thread.shutdown_flag.set()
        for thread in self.threads:
            # join() raises RuntimeError on a thread that was never started;
            # a thread that already finished needs no join either.
            if thread.is_alive():
                thread.join()

    def add_to_queue(self, func, *args):
        """Queue the call ``func(*args)`` for execution by a worker."""
        self.tasks.put(functools.partial(func, *args))

    def map(self, func, iterable):
        """Apply ``func`` to every item of ``iterable`` using the worker threads.

        Args:
            func: Callable taking a single item.
            iterable: Items to process.

        Returns:
            List of the values returned by ``func``. Values are appended as
            workers finish, so the order need not match ``iterable``.
        """
        for task in iterable:
            self.add_to_queue(func, task)
        self.reset()
        self.start()
        # Poll with a short sleep while workers drain the queue.
        while self.isRunning() and not self.tasks.empty():
            time.sleep(.01)
        # The queue is drained (or all workers exited); wait for in-flight tasks.
        self.stop()
        return self.results

    def reset(self):
        """Discard previous results and create a fresh set of worker threads."""
        self.results = []
        # os.cpu_count() returns None when the count cannot be determined.
        worker_count = os.cpu_count() or 1
        self.threads = [Worker(self.tasks, self.results) for _ in range(worker_count)]

    def isRunning(self):
        """Return True while at least one worker thread is alive."""
        return any(thread.is_alive() for thread in self.threads)
def fetch(task_num):
    """Run one task and report whether it succeeded.

    Args:
        task_num: Task identifier passed through to ``run``.

    Returns:
        ``(task_num, True)`` when ``run`` completes without raising, or
        ``(task_num, False)`` when it raises, so the caller can retry.
    """
    try:
        run(task_num)
    # Exception rather than a bare except: KeyboardInterrupt and SystemExit
    # must not be recorded as an ordinary task failure.
    except Exception:
        return (task_num, False)
    return (task_num, True)
def run(task_num):
    """Demo task that fails at random so the retry logic can be exercised.

    Draws an integer from 1 to 10; a draw below 5 prints a failure message
    and raises ``ValueError``. Otherwise it prints a success message and
    sleeps for one second to simulate work.
    """
    roll = random.randint(1, 10)
    if roll >= 5:
        print(f'OK: {task_num}')
        time.sleep(1)
        return
    print(f'failed: {task_num}')
    raise ValueError
if __name__ == '__main__':
    nums = list(range(100))
    # Catch KeyboardInterrupt here to continue, or pass to let the program exit.
    try:
        # Keep retrying for as long as map() reports failures.
        while nums:
            with ThreadPool() as pool:
                # The next round consists of the task numbers that failed.
                nums = [num for num, ok in pool.map(fetch, nums) if not ok]
    except KeyboardInterrupt:
        pass
    print('done')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment