Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Python 实现的一个线程池
# -*- encoding:utf-8 -*-
###### 线程池 - by 汪蔚 soulww@163.com
import threading
import traceback
def default_fail_handle():
    """Default global failure hook: intentionally does nothing."""
    pass


# Module-level hook fired whenever an UnhandledException is constructed.
# It takes no arguments and is looked up at call time, so reassigning the
# module attribute ExceptionHandle replaces it for all later failures.
ExceptionHandle = default_fail_handle


class UnhandledException(Exception):
    """Default ``exceptionHandle`` used by WorkTask.

    WorkTask calls its handler as ``handler(exc)``. Using this class as the
    handler therefore wraps the original exception (available as
    ``args[0]``) and fires the global ``ExceptionHandle`` hook each time an
    instance is created.
    """

    def __init__(self, *args):
        super(UnhandledException, self).__init__(*args)
        # The original called ExceptionHandle() in the class body. That ran
        # exactly once, at import time, and never when a task failed.
        ExceptionHandle()
def _handle_thread_exception(request, exc_info):
    """Print the traceback described by ``exc_info`` to stderr.

    ``request`` is ignored. The items of ``exc_info`` are passed as the
    positional arguments of ``traceback.print_exception``. Presumably this is
    the ``(type, value, traceback)`` triple from ``sys.exc_info()``; no caller
    is visible here, so confirm against any external use.
    """
    print_args = tuple(exc_info)
    traceback.print_exception(*print_args)
class WorkContext(object):
    """Per-task record: inputs, outcome, and the id of the worker that ran it."""

    __slots__ = ('taskId', 'args', 'kwds', 'result', 'exception', 'workerId')

    def __init__(self, taskId, args=None, kwds=None, result=None, exception=None):
        # Falsy args/kwds (None, (), {}) are normalised to fresh empty containers.
        self.args = args if args else []
        self.kwds = kwds if kwds else {}
        self.taskId = taskId
        self.result = result
        self.exception = exception
        self.workerId = None  # assigned by the worker thread that runs the task

    def __str__(self):
        # Key order is fixed so the rendered dict is stable.
        fields = (
            ('taskId', self.taskId),
            ('workerId', self.workerId),
            ('args', self.args),
            ('kwds', self.kwds),
            ('result', self.result),
            ('exception', str(self.exception)),
        )
        return str(dict(fields))
class WorkTask(object):
    """A unit of work: a function, its arguments, a callback and an error hook.

    Calling the task runs ``func(*args, **kwds)`` and stores the return value
    in ``self.context.result``.
    """

    def __init__(self, func, args=None, kwds=None, taskId=None, callback=None, exceptionHandle=UnhandledException):
        """
        :param func: callable to run.
        :param args: positional arguments for ``func`` (default: none).
        :param kwds: keyword arguments for ``func`` (default: none).
        :param taskId: optional hashable id; defaults to ``id(self)``.
        :param callback: called as ``callback(context)`` after ``func`` has
            finished, whether or not it raised (check ``context.exception``).
        :param exceptionHandle: called as ``exceptionHandle(exc)`` with any
            exception raised by ``func`` or ``callback``. A falsy value means
            the exception is only recorded in ``context.exception``.
        """
        self.taskId = self.genTaskId(taskId)
        self.context = WorkContext(self.taskId, args, kwds)
        self.callable = self.genCallable(func)  # zero-argument runner for func
        self.callback = callback
        self.ecHandle = exceptionHandle

    def __call__(self):
        self.callable()

    def genTaskId(self, taskId=None):
        """Return ``hash(taskId)``, or ``id(self)`` when no id was given.

        :raises TypeError: if ``taskId`` is not hashable.
        """
        if taskId is None:
            return id(self)
        try:
            return hash(taskId)
        except TypeError:
            raise TypeError("taskId must be hashable")

    def genCallable(self, func):
        """Return a zero-argument function that runs ``func`` for this task.

        The runner stores the result in ``context.result`` and then calls the
        callback (as before, even if ``func`` raised). An exception from
        either step is stored in ``context.exception`` and passed to the
        exception handler.
        """
        def fail(exc):
            self.context.exception = exc
            if self.ecHandle:
                self.ecHandle(exc)

        def run():
            context = self.context
            try:
                context.result = func(*context.args, **context.kwds)
            except Exception as e:
                # The original only printed this and swallowed it, so
                # context.exception stayed None and the handler never ran.
                fail(e)
            try:
                if self.callback:
                    self.callback(context)
            except Exception as e:
                fail(e)

        return run
class TaskPool(object):
    """Lock-protected, optionally bounded pool of tasks keyed by ``taskId``.

    ``hasTask`` is released once per task actually added and acquired
    (non-blocking) once per task removed. ``noTask`` is taken when a task is
    added and released when the pool drains to zero.
    """

    class Empty(Exception):
        "Exception raised by TaskPool get."
        pass

    class Full(Exception):
        "Exception raised by TaskPool put."
        pass

    def __init__(self, maxSize=0):
        # maxSize <= 0 means the pool is unbounded.
        if maxSize <= 0:
            maxSize = float('inf')
        self.maxSize = maxSize
        self.size = 0
        self.tasks = {}          # taskId -> task
        self.taskIdSet = set()   # ids available for an arbitrary pop
        self.mutex = threading.Lock()
        self.hasTask = threading.Semaphore(value=0)
        self.noTask = threading.Semaphore(value=1)

    def putTask(self, task):
        """Add ``task`` to the pool and return its ``taskId``.

        A task whose id is already pooled is ignored (the pooled one is kept).

        :raises TaskPool.Full: if the pool already holds ``maxSize`` tasks.
        """
        with self.mutex:
            if self.size >= self.maxSize:
                raise TaskPool.Full
            if task.taskId not in self.tasks:
                self.size += 1
                self.taskIdSet.add(task.taskId)
                self.tasks[task.taskId] = task
                # Signal only for tasks actually added. The original also
                # signalled for ignored duplicates, leaving hasTask with more
                # permits than pooled tasks.
                self.hasTask.release()
                self.noTask.acquire(blocking=False)
        return task.taskId

    def popTask(self, taskId=None):
        """Remove and return a task.

        :param taskId: id of the task to remove; when None, an arbitrary
            pooled task is taken.
        :returns: the task, or None if ``taskId`` is not in the pool.
        :raises TaskPool.Empty: if the pool is empty.
        """
        with self.mutex:
            if self.size == 0:
                raise TaskPool.Empty
            if taskId is None:
                taskId = self.taskIdSet.pop()
            if taskId not in self.tasks:
                return None
            self.size -= 1
            self.taskIdSet.discard(taskId)
            task = self.tasks.pop(taskId)
            self.hasTask.acquire(blocking=False)
            if self.size == 0:
                self.noTask.release()
            return task
class WorkerThread(threading.Thread):
    """Worker that repeatedly takes tasks from its ThreadPool and runs them.

    ``state`` is one of 'busy', 'idle' or 'dead'. The worker reports each
    transition through ``threadPool.transfer(worker, old, new)``. Once
    stopped, a worker cannot be run again; create a new WorkerThread instead.
    """

    def __init__(self, threadPool):
        super(WorkerThread, self).__init__()
        self.workerId = id(self)
        self.threadPool = threadPool
        self.stopSignal = 0  # set by stop(); checked between tasks
        self.joinSignal = 0  # NOTE(review): never read in this file
        self.state = 'busy'  # 'busy', 'idle' or 'dead'

    def stop(self):
        """Ask the worker to exit after its current task (or once woken)."""
        self.stopSignal = 1

    def rest(self):
        """Go idle and block until the task pool's ``hasTask`` is released."""
        self.state = 'idle'
        self.threadPool.transfer(self, 'busy', 'idle')
        self.threadPool.taskPool.hasTask.acquire()
        self.state = 'busy'
        self.threadPool.transfer(self, 'idle', 'busy')

    def run(self):
        """Run tasks until stop() is called, resting while the pool is empty."""
        pool = self.threadPool
        try:
            while not self.stopSignal:
                self.state = 'busy'
                try:
                    task = pool.getTask()  # does not block
                except TaskPool.Empty:
                    self.rest()
                    continue
                task.context.workerId = self.workerId
                task()
        finally:
            # Always report death, even if a task raised. Otherwise the pool
            # never sees this worker reach 'dead' and its shutdown waits forever.
            lastState = self.state
            self.state = 'dead'
            pool.transfer(self, lastState, 'dead')
            self.threadPool = None  # drop the worker -> pool reference
class ThreadPool(object):
    """Set of WorkerThreads consuming tasks from a shared TaskPool.

    Workers report state changes through ``transfer``. For each state
    ('busy', 'idle', 'dead') the pool keeps the workers in that state
    (``taskWorkers``), their count (``taskWorkersNum``) and two semaphores
    (``workerStateSem``):
    - 'empty' is acquired (non-blocking) when a state's first worker enters
      and released when its last worker leaves.
    - 'full' is released when every worker is in that state and re-acquired
      (non-blocking) when one leaves.
    """

    def __init__(self, workersNum=0):
        self.taskPool = TaskPool()
        self.mutex = threading.RLock()  # reentrant: transfer() nests leave/enter
        self.taskWorkers = {'busy': {}, 'idle': {}, 'dead': {}}
        self.taskWorkersNum = {'busy': 0, 'idle': 0, 'dead': 0}
        self.taskWorkersTotal = 0
        self.workerStateSem = {
            'busy': {'empty': threading.Semaphore(value=1), 'full': threading.Semaphore(value=0)},
            'idle': {'empty': threading.Semaphore(value=1), 'full': threading.Semaphore(value=0)},
            'dead': {'empty': threading.Semaphore(value=1), 'full': threading.Semaphore(value=0)},
        }
        self.addWorkers(workersNum)

    def addWorkers(self, num=1):
        """Create, register as 'busy' and start ``num`` workers (no-op if <= 0)."""
        if num <= 0:
            return
        with self.mutex:
            # Grow the total first. enterState() rejects counts above it, so
            # the original (which grew it after the loop) never counted the
            # new workers as busy.
            self.taskWorkersTotal += num
            for _ in range(num):
                worker = WorkerThread(self)
                self.enterState(worker, 'busy')
                worker.start()

    def enterState(self, worker, state):
        """Record ``worker`` as being in ``state`` and update its semaphores."""
        with self.mutex:
            oldNum = self.taskWorkersNum[state]
            newNum = oldNum + 1
            if newNum > self.taskWorkersTotal:
                return  # defensive: never count more workers than exist
            self.taskWorkers[state][worker.workerId] = worker
            self.taskWorkersNum[state] = newNum
            if oldNum == 0:
                self.workerStateSem[state]['empty'].acquire(blocking=False)
            if newNum == self.taskWorkersTotal:
                self.workerStateSem[state]['full'].release()

    def leaveState(self, worker, state):
        """Remove ``worker`` from ``state`` and update its semaphores."""
        with self.mutex:
            oldNum = self.taskWorkersNum[state]
            newNum = oldNum - 1
            if newNum < 0:
                return
            # The original re-inserted the worker here, so every state's dict
            # accumulated every worker that had ever visited it.
            self.taskWorkers[state].pop(worker.workerId, None)
            self.taskWorkersNum[state] = newNum
            if oldNum == self.taskWorkersTotal:
                self.workerStateSem[state]['full'].acquire(blocking=False)
            if newNum == 0:
                self.workerStateSem[state]['empty'].release()

    def transfer(self, worker, oldState, newState):
        """Move ``worker`` from ``oldState`` to ``newState`` atomically."""
        # Holding the lock across both steps keeps every worker in exactly one
        # state as seen by other lock holders (e.g. stop()).
        with self.mutex:
            self.leaveState(worker, oldState)
            self.enterState(worker, newState)

    def putTask(self, task):
        self.taskPool.putTask(task)

    def getTask(self):
        return self.taskPool.popTask()

    def stop(self):
        """Stop every worker and wait for all of them to exit.

        Workers finish the task they are running; tasks still in the pool
        may be left unrun. Calling this more than once is harmless.
        """
        # Snapshot under the lock: other threads mutate these dicts.
        with self.mutex:
            workers = [w for group in self.taskWorkers.values() for w in group.values()]
        for worker in workers:
            worker.stop()
        # Wake workers blocked on hasTask so they observe the stop flag.
        for _ in range(2 * len(workers)):
            self.taskPool.hasTask.release()
        for worker in workers:
            worker.join()

    def wait(self):
        """Block until every queued task has been taken, then stop() the pool.

        stop() joins the workers, so all taken tasks have finished on return.
        """
        self.taskPool.noTask.acquire()
        self.stop()
class TestCase(object):
    """A batch of tasks with expected outputs, run together on a ThreadPool."""

    def __init__(self, index):
        self.testIndex = index
        self.tasks = []
        self.mutex = threading.Lock()
        self.output = {}        # taskId -> actual result
        self.expectOutput = {}  # taskId -> expected result

    def getOutput(self, context):
        """Task callback: record ``context.result`` under ``context.taskId``."""
        with self.mutex:
            self.output[context.taskId] = context.result

    def addTask(self, func, inputArgs=None, inputKwargs=None, output=None):
        """Queue ``func(*inputArgs, **inputKwargs)`` expecting ``output``."""
        task = WorkTask(func, inputArgs, inputKwargs, callback=self.getOutput)
        self.tasks.append(task)
        self.expectOutput[task.taskId] = output

    def run(self, workersNum=8):
        """Run all tasks on a fresh pool and print an ANSI-coloured verdict.

        :param workersNum: number of worker threads (default 8, as before).
        :returns: True if every recorded result equals its expected output.
        """
        pool = ThreadPool(workersNum)
        for task in self.tasks:
            pool.putTask(task)
        pool.wait()
        print(pool.taskWorkersNum)
        # `cmp` exists only in Python 2; dict equality is what was being tested.
        passed = self.output == self.expectOutput
        if passed:
            print('\033[1;32;40m')
            print('*' * 50)
            print('\t\tTestCase[%s] Pass!' % (self.testIndex))
        else:
            print('\033[1;31;40m')
            print('*' * 50)
            print('\t\tTestCase[%s] Fail!' % (self.testIndex))
            print('- your output is:\n' + str(self.output))
            print('- standard output is:\n' + str(self.expectOutput))
        print('*' * 50)
        print('\033[0m')
        return passed
def main():
    """Run the two built-in self-test cases (addition and sorting)."""
    def add(a, b):
        return a + b

    testCase = TestCase(1)
    testCase.addTask(add, (1, 1), output=2)
    testCase.addTask(add, (1, 2), output=3)
    testCase.addTask(add, (1, 3), output=4)
    testCase.run()

    def mySort(items):  # renamed from `list`, which shadowed the builtin
        return sorted(items)

    testCase = TestCase(2)
    testCase.addTask(mySort, ([1, 2, 3],), output=[1, 2, 3])
    testCase.addTask(mySort, ([3, 2, 1],), output=[1, 2, 3])
    testCase.addTask(mySort, ([1, 3, 2],), output=[1, 2, 3])
    testCase.run()


if __name__ == '__main__':
    main()
@tanchao90

This comment has been minimized.

Copy link
Owner Author

commented Feb 16, 2017

使用Python中的Queue对其改进

  • Queue队列的put、get方法是阻塞式操作,只有put、get成功时,程序才会继续执行
  • 不需要在程序内部不断地检查是否有任务,避免程序陷入持续等待状态

管线的思路+多线程

可完成很多流水式的工作,比如图片处理:

  1. download
  2. resize
  3. save

其中每个功能由一种线程负责处理,处理结束之后放入队列,交由下一个线程处理

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.