Created
February 13, 2017 01:59
-
-
Save tanchao90/012721bbd05b70b1d9b3f30326393579 to your computer and use it in GitHub Desktop.
Python 实现的一个线程池
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- encoding:utf-8 -*- | |
###### 线程池 - by 汪蔚 soulww@163.com | |
import threading | |
import traceback | |
def default_fail_handle():
    """Do nothing: placeholder failure hook used when no other hook is installed."""
    return None


# Module-level failure hook; UnhandledException's class body calls it once
# when that class statement is executed.
ExceptionHandle = default_fail_handle
class UnhandledException(Exception):
    """Default value of WorkTask's ``exceptionHandle`` argument.

    WorkTask calls its handler with the caught exception, so with this
    default the call merely constructs an UnhandledException wrapping that
    exception; nothing is raised by the call itself.
    """

    # Executed once, at class-definition time: invokes the module-level hook.
    ExceptionHandle()
def _handle_thread_exception(request, exc_info):
    """Print the traceback described by ``exc_info``.

    ``request`` is accepted but never used.  ``exc_info`` is unpacked
    positionally into ``traceback.print_exception`` (e.g. the tuple returned
    by ``sys.exc_info()``).
    """
    return traceback.print_exception(*exc_info)
class WorkContext(object):
    """Mutable record attached to a task: its id, call arguments and outcome."""

    __slots__ = ('taskId', 'args', 'kwds', 'result', 'exception', 'workerId')

    def __init__(self, taskId, args=None, kwds=None, result=None, exception=None):
        """Store the given fields; ``workerId`` always starts as None.

        Falsy ``args`` / ``kwds`` (None or an empty container) are replaced
        by a fresh empty list / dict.
        """
        self.taskId, self.workerId = taskId, None
        self.args = args if args else []
        self.kwds = kwds if kwds else {}
        self.result, self.exception = result, exception

    def __str__(self):
        """Return ``str()`` of a dict holding every field (exception stringified)."""
        plain_fields = ('taskId', 'workerId', 'args', 'kwds', 'result')
        snapshot = dict((name, getattr(self, name)) for name in plain_fields)
        snapshot['exception'] = str(self.exception)
        return str(snapshot)
class WorkTask(object):
    """A unit of work: a function, its arguments, a callback and an error handler.

    Calling the task runs ``func(*args, **kwds)`` and stores the return value
    in ``context.result``.  If ``func`` raises, the exception is printed and
    recorded in ``context.exception`` (so the callback can tell a failure from
    a ``None`` result), and the callback still runs.  If the callback itself
    raises, that exception replaces ``context.exception`` and is passed to
    ``exceptionHandle``.

    Parameters:
        func: callable to execute.
        args / kwds: positional / keyword arguments for ``func``.
        taskId: optional hashable key; its hash becomes the task id.  When
            omitted, ``id(self)`` is used.
        callback: optional callable invoked with the WorkContext after
            ``func`` has run.
        exceptionHandle: optional callable invoked with the exception raised
            while running the callback.

    Raises:
        TypeError: if ``taskId`` is given but is not hashable.
    """

    def __init__(self, func, args=None, kwds=None, taskId=None, callback=None, exceptionHandle=UnhandledException):
        self.taskId = self.genTaskId(taskId)
        self.context = WorkContext(self.taskId, args, kwds)
        self.callable = self.genCallable(func)  # zero-argument wrapper around func
        self.callback = callback  # completion callback; receives the context
        self.ecHandle = exceptionHandle  # error handler; receives the exception

    def __call__(self):
        self.callable()

    def genTaskId(self, taskId=None):
        """Return ``hash(taskId)``, or ``id(self)`` when no taskId is given."""
        if taskId is None:
            return id(self)
        try:
            return hash(taskId)
        except TypeError:
            raise TypeError("taskId must be hashable")

    def genCallable(self, func):
        """Build the zero-argument wrapper that runs ``func`` and the callback."""
        def runTask():
            context = self.context
            try:
                try:
                    context.result = func(*context.args, **context.kwds)
                except Exception as e:
                    # Keep the failure visible to the callback instead of
                    # only printing it; the callback still runs afterwards.
                    context.exception = e
                    print(e)
                if self.callback:
                    self.callback(context)
            except Exception as e:
                # Reached only when the callback (or the print above) fails.
                context.exception = e
                if self.ecHandle:
                    self.ecHandle(e)
        return runTask
class TaskPool(object):
    """Lock-protected store of pending tasks keyed by ``taskId``.

    Two semaphores are exposed for coordination: ``hasTask`` is released each
    time a task is stored, and ``noTask`` is released when a pop leaves the
    pool empty.
    """

    class Empty(Exception):
        "Exception raised by TaskPool get."

    class Full(Exception):
        "Exception raised by TaskPool put."

    def __init__(self, maxSize=0):
        # A non-positive maxSize means the pool is unbounded.
        self.maxSize = maxSize if maxSize > 0 else float('inf')
        self.size = 0
        self.tasks = {}
        self.taskIdSet = set()
        self.mutex = threading.Lock()
        self.hasTask = threading.Semaphore(value=0)
        self.noTask = threading.Semaphore(value=1)

    def putTask(self, task):
        """Store ``task`` under ``task.taskId``; returns None.

        A task whose id is already stored is ignored.

        Raises:
            TaskPool.Full: when the pool already holds ``maxSize`` tasks.
        """
        with self.mutex:
            if self.size >= self.maxSize:
                raise TaskPool.Full
            if task.taskId in self.tasks:
                return
            self.size += 1
            self.taskIdSet.add(task.taskId)
            self.tasks[task.taskId] = task
            self.hasTask.release()
            # The pool is no longer empty: take the "no task" token if it is there.
            self.noTask.acquire(blocking=False)

    def popTask(self, taskId=None):
        """Remove and return a task.

        With ``taskId`` None an arbitrary stored task is taken.  When a
        specific ``taskId`` is not stored, None is returned and nothing changes.

        Raises:
            TaskPool.Empty: when the pool holds no tasks.
        """
        with self.mutex:
            if self.size == 0:
                raise TaskPool.Empty
            if taskId is None:
                taskId = self.taskIdSet.pop()
            if taskId not in self.tasks:
                return None
            self.size -= 1
            self.taskIdSet.discard(taskId)
            popped = self.tasks.pop(taskId)
            self.hasTask.acquire(blocking=False)
            if self.size == 0:
                # Last task taken: hand the "no task" token back.
                self.noTask.release()
            return popped
class WorkerThread(threading.Thread):
    """Thread that repeatedly takes tasks from its pool and runs them.

    ``state`` is one of the strings 'busy', 'idle' or 'dead'.
    """

    def __init__(self, threadPool):
        super(WorkerThread, self).__init__()
        self.workerId = id(self)
        self.threadPool = threadPool
        self.stopSignal = 0
        self.joinSignal = 0
        self.state = 'busy'

    def stop(self):
        """Ask the run loop to exit the next time it checks the flag."""
        self.stopSignal = 1

    def rest(self):
        """Go idle, block until the task pool's ``hasTask`` semaphore is released, then go busy."""
        pool = self.threadPool
        self.state = 'idle'
        pool.transfer(self, 'busy', 'idle')
        pool.taskPool.hasTask.acquire()
        self.state = 'busy'
        pool.transfer(self, 'idle', 'busy')

    def run(self):
        """Main loop.  Once stopped a worker cannot run again; create a new WorkerThread instead."""
        while True:
            if self.stopSignal:
                break
            try:
                self.state = 'busy'
                task = self.threadPool.getTask()  # does not block
                task.context.workerId = self.workerId
                task()
            except TaskPool.Empty:
                self.rest()
        self.state = 'dead'
        self.threadPool.transfer(self, 'busy', 'dead')
        self.threadPool = None  # drop the back-reference to avoid a reference cycle
class ThreadPool(object):
    """Fixed set of WorkerThreads fed from a shared TaskPool.

    Workers are tracked per state ('busy', 'idle', 'dead'):
    ``taskWorkers[state]`` maps workerId to worker and
    ``taskWorkersNum[state]`` counts them.  For every state,
    ``workerStateSem[state]['empty']`` holds a token while no worker is in
    that state and ``['full']`` is released when every worker is in it.
    All bookkeeping is guarded by the re-entrant ``mutex``.
    """

    def __init__(self, workersNum=0):
        self.taskPool = TaskPool()
        self.mutex = threading.RLock()
        self.taskWorkers = {'busy': {}, 'idle': {}, 'dead': {}}
        self.taskWorkersNum = {'busy': 0, 'idle': 0, 'dead': 0}
        self.taskWorkersTotal = 0
        self.workerStateSem = {
            'busy': {'empty': threading.Semaphore(value=1), 'full': threading.Semaphore(value=0)},
            'idle': {'empty': threading.Semaphore(value=1), 'full': threading.Semaphore(value=0)},
            'dead': {'empty': threading.Semaphore(value=1), 'full': threading.Semaphore(value=0)},
        }
        self.addWorkers(workersNum)

    def addWorkers(self, num=1):
        """Create, register (as 'busy') and start ``num`` workers; no-op when ``num`` <= 0."""
        if num <= 0:
            return
        with self.mutex:
            oldTotal = self.taskWorkersTotal
            if oldTotal > 0:
                # A state that held every worker is no longer "full" once
                # the total grows, so take its token back.
                for state, count in self.taskWorkersNum.items():
                    if count == oldTotal:
                        self.workerStateSem[state]['full'].acquire(blocking=False)
            # Raise the total BEFORE registering: enterState refuses to count
            # past taskWorkersTotal, so the old order rejected every new worker.
            self.taskWorkersTotal = oldTotal + num
            for _ in range(num):
                worker = WorkerThread(self)
                self.enterState(worker, 'busy')
                worker.start()

    def enterState(self, worker, state):
        """Record ``worker`` in ``state``; ignored if the count would exceed the total."""
        with self.mutex:
            oldNum = self.taskWorkersNum[state]
            newNum = oldNum + 1
            if newNum > self.taskWorkersTotal:
                return
            self.taskWorkers[state][worker.workerId] = worker
            self.taskWorkersNum[state] = newNum
            sems = self.workerStateSem[state]
            if oldNum == 0:
                sems['empty'].acquire(blocking=False)
            if newNum == self.taskWorkersTotal:
                sems['full'].release()

    def leaveState(self, worker, state):
        """Remove ``worker`` from ``state``; ignored if the count is already zero."""
        with self.mutex:
            oldNum = self.taskWorkersNum[state]
            newNum = oldNum - 1
            if newNum < 0:
                return
            # Actually drop the worker (it used to be re-inserted here).
            self.taskWorkers[state].pop(worker.workerId, None)
            self.taskWorkersNum[state] = newNum
            sems = self.workerStateSem[state]
            if oldNum == self.taskWorkersTotal:
                sems['full'].acquire(blocking=False)
            if newNum == 0:
                sems['empty'].release()

    def transfer(self, worker, oldState, newState):
        """Move ``worker`` from ``oldState`` to ``newState`` as one locked step."""
        # Holding the re-entrant lock across both steps means a worker is
        # never observable in no state at all.
        with self.mutex:
            self.leaveState(worker, oldState)
            self.enterState(worker, newState)

    def putTask(self, task):
        self.taskPool.putTask(task)

    def getTask(self):
        return self.taskPool.popTask()

    def stop(self):
        """Stop and join every worker.  May only be called once.

        Returns immediately when the pool has no workers (waiting for
        "all workers dead" would otherwise block forever).
        """
        if self.taskWorkersTotal == 0:
            return
        # Snapshot under the lock: worker threads mutate these dicts concurrently.
        with self.mutex:
            workers = [worker
                       for group in self.taskWorkers.values()
                       for worker in group.values()]
        # Signal every worker to stop.
        for worker in workers:
            worker.stop()
        # Wait until no worker is busy.
        self.workerStateSem['busy']['empty'].acquire()
        # Wake idle workers so they notice the stop signal and die.
        for _ in range(2 * self.taskWorkersTotal):
            self.taskPool.hasTask.release()
        # Wait until every worker is dead.
        self.workerStateSem['dead']['full'].acquire()
        # Join the worker threads.
        for worker in workers:
            worker.join()

    def wait(self):
        """Block until the task pool has been emptied, then stop the pool."""
        self.taskPool.noTask.acquire()
        self.stop()
class TestCase(object):
    """Small harness: runs queued tasks on a ThreadPool and compares results with expectations."""

    def __init__(self, index):
        self.testIndex = index
        self.tasks = []
        self.mutex = threading.Lock()
        self.output = {}  # {taskId: actual result}
        self.expectOutput = {}  # {taskId: expected result}

    def getOutput(self, context):
        """Task callback: record ``context.result`` under ``context.taskId``."""
        with self.mutex:
            self.output[context.taskId] = context.result

    def addTask(self, func, inputArgs=None, inputKwargs=None, output=None):
        """Queue ``func`` with its arguments and remember the expected ``output``."""
        task = WorkTask(func, inputArgs, inputKwargs, callback=self.getOutput)
        self.tasks.append(task)
        self.expectOutput[task.taskId] = output

    def run(self):
        """Run all queued tasks on an 8-worker pool and print a coloured pass/fail banner."""
        pool = ThreadPool(8)
        for task in self.tasks:
            pool.putTask(task)
        pool.wait()
        print(pool.taskWorkersNum)
        stars = '*' * 50
        if self.output == self.expectOutput:
            banner = [
                '\033[1;32;40m',
                stars,
                '\t\tTestCase[%s] Pass!' % (self.testIndex),
                stars,
                '\033[0m',
            ]
        else:
            banner = [
                '\033[1;31;40m',
                stars,
                '\t\tTestCase[%s] Fail!' % (self.testIndex),
                '- your output is:\n%s' % (self.output,),
                '- standard output is:\n%s' % (self.expectOutput,),
                stars,
                '\033[0m',
            ]
        for line in banner:
            print(line)
if __name__ == '__main__':
    # Self-test: run two small cases through the pool and print the verdicts.

    def add(a, b):
        """Return the sum of the two arguments."""
        return a + b

    def mySort(items):
        """Return a new sorted list built from ``items``."""
        return sorted(items)

    # Case 1: three additions.
    testCase = TestCase(1)
    testCase.addTask(add, (1, 1), output=2)
    testCase.addTask(add, (1, 2), output=3)
    testCase.addTask(add, (1, 3), output=4)
    testCase.run()

    # Case 2: three permutations that must all sort to [1, 2, 3].
    testCase = TestCase(2)
    testCase.addTask(mySort, ([1, 2, 3],), output=[1, 2, 3])
    testCase.addTask(mySort, ([3, 2, 1],), output=[1, 2, 3])
    testCase.addTask(mySort, ([1, 3, 2],), output=[1, 2, 3])
    testCase.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
使用Python中的Queue对其改进
管线的思路+多线程
可完成很多流水式的工作,比如图片处理:
其中每个功能采用一种线程处理,处理结束之后放入队列,交由下一个线程处理