Skip to content

Instantly share code, notes, and snippets.

@tanchao90
Created February 13, 2017 01:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tanchao90/012721bbd05b70b1d9b3f30326393579 to your computer and use it in GitHub Desktop.
Save tanchao90/012721bbd05b70b1d9b3f30326393579 to your computer and use it in GitHub Desktop.
Python 实现的一个线程池
# -*- encoding:utf-8 -*-
###### 线程池 - by 汪蔚 soulww@163.com
import threading
import traceback
def default_fail_handle():
    """No-op failure hook; deliberately ignores every call."""
    return None


# Module-level failure hook invoked by UnhandledException; defaults to a no-op.
ExceptionHandle = default_fail_handle
class UnhandledException(Exception):
    """Default ``exceptionHandle`` used by WorkTask.

    WorkTask invokes its handler as ``handler(exc)``. With this class as the
    handler, that call constructs an UnhandledException wrapping ``exc``, so
    the constructor is where the module-level ``ExceptionHandle`` hook runs:
    once per handled failure.
    """

    def __init__(self, *args):
        super(UnhandledException, self).__init__(*args)
        # Resolve the hook at call time (not in the class body, which runs
        # only once at import) so it fires for every failure and so that a
        # later reassignment of the module-level ExceptionHandle is honoured.
        ExceptionHandle()
def _handle_thread_exception(request, exc_info):
traceback.print_exception(*exc_info)
class WorkContext(object):
    """Per-task record: inputs, outcome, and the worker that ran it.

    ``args`` and ``kwds`` fall back to a fresh empty list/dict when a falsy
    value is given. ``__slots__`` prevents adding attributes beyond the
    listed ones.
    """

    __slots__ = ('taskId', 'args', 'kwds', 'result', 'exception', 'workerId')

    def __init__(self, taskId, args=None, kwds=None, result=None, exception=None):
        self.taskId = taskId
        # Filled in by the worker thread that executes the task.
        self.workerId = None
        self.args = args if args else []
        self.kwds = kwds if kwds else {}
        self.result = result
        self.exception = exception

    def __str__(self):
        # Build the dict in the same key order as the fields are listed so
        # the rendered text is stable; the exception is shown as its str().
        plain = ('taskId', 'workerId', 'args', 'kwds', 'result')
        snapshot = dict((name, getattr(self, name)) for name in plain)
        snapshot['exception'] = str(self.exception)
        return str(snapshot)
class WorkTask(object):
    """A unit of work: ``func(*args, **kwds)`` together with its context and hooks.

    Calling the task runs ``func`` with the args/kwds stored on
    ``self.context``, stores the return value in ``context.result`` and then
    invokes ``callback(context)``. Any exception raised by ``func`` or by the
    callback is stored in ``context.exception`` and passed to ``ecHandle``.
    """

    def __init__(self, func, args=None, kwds=None, taskId=None, callback=None,
                 exceptionHandle=UnhandledException):
        self.taskId = self.genTaskId(taskId)
        self.context = WorkContext(self.taskId, args, kwds)
        # Zero-argument closure that runs func using self.context.
        self.callable = self.genCallable(func)
        # Called with the WorkContext after func has run, successfully or not.
        self.callback = callback
        # Called with the exception raised by func or by the callback.
        self.ecHandle = exceptionHandle

    def __call__(self):
        self.callable()

    def genTaskId(self, taskId=None):
        """Return ``hash(taskId)``, or ``id(self)`` when no taskId is given.

        Raises:
            TypeError: if ``taskId`` is not hashable.
        """
        if taskId is None:
            return id(self)
        try:
            return hash(taskId)
        except TypeError:
            raise TypeError("taskId must be hashable")

    def _recordFailure(self, exc):
        """Store ``exc`` on the context and forward it to the exception handler."""
        self.context.exception = exc
        if self.ecHandle:
            self.ecHandle(exc)

    def genCallable(self, func):
        """Wrap ``func`` in a zero-argument closure bound to this task's context."""
        def run():
            try:
                self.context.result = func(*self.context.args, **self.context.kwds)
            except Exception as e:
                print(e)
                # Record the failure too; printing alone left
                # context.exception unset and never reached ecHandle.
                self._recordFailure(e)
            try:
                if self.callback:
                    self.callback(self.context)
            except Exception as e:
                self._recordFailure(e)
        return run
class TaskPool(object):
    """Lock-protected collection of pending tasks keyed by ``task.taskId``.

    ``hasTask`` is released once per newly stored task and decremented
    (non-blocking) per pop; idle workers block on it. ``noTask`` holds a
    permit while the pool is empty: it is taken when a task is stored and
    given back when the last task is popped.
    """

    class Empty(Exception):
        "Exception raised by TaskPool get."

    class Full(Exception):
        "Exception raised by TaskPool put."

    def __init__(self, maxSize=0):
        # A non-positive maxSize makes the pool unbounded.
        self.maxSize = float('inf') if maxSize <= 0 else maxSize
        self.size = 0
        self.tasks = {}
        self.taskIdSet = set()
        self.mutex = threading.Lock()
        self.hasTask = threading.Semaphore(value=0)
        self.noTask = threading.Semaphore(value=1)

    def putTask(self, task):
        """Store ``task`` unless a task with the same taskId is already pending.

        Raises:
            TaskPool.Full: if the pool already holds ``maxSize`` tasks.
        """
        with self.mutex:
            if self.size >= self.maxSize:
                raise TaskPool.Full
            key = task.taskId
            if key in self.tasks:
                return
            self.tasks[key] = task
            self.taskIdSet.add(key)
            self.size += 1
            self.hasTask.release()
            self.noTask.acquire(blocking=False)

    def popTask(self, taskId=None):
        """Remove and return a pending task.

        With ``taskId`` given, that task is removed, or None is returned if it
        is not pending. Otherwise an arbitrary pending task is removed.

        Raises:
            TaskPool.Empty: if no task is pending.
        """
        with self.mutex:
            if self.size == 0:
                raise TaskPool.Empty
            key = self.taskIdSet.pop() if taskId is None else taskId
            if key not in self.tasks:
                return None
            task = self.tasks.pop(key)
            self.taskIdSet.discard(key)
            self.size -= 1
            self.hasTask.acquire(blocking=False)
            if self.size == 0:
                self.noTask.release()
            return task
class WorkerThread(threading.Thread):
    """Thread that repeatedly takes tasks from its ThreadPool and runs them.

    The worker reports each state change ('busy', 'idle', 'dead') to the
    pool via ``threadPool.transfer``. Once stopped it cannot be restarted;
    create a new WorkerThread instead.
    """

    def __init__(self, threadPool):
        super(WorkerThread, self).__init__()
        self.workerId = id(self)
        self.threadPool = threadPool
        self.stopSignal = 0
        self.joinSignal = 0
        # One of the strings 'busy', 'idle' or 'dead'.
        self.state = 'busy'

    def stop(self):
        """Ask the worker to exit after finishing its current task."""
        self.stopSignal = 1

    def rest(self):
        """Go idle and block until the task pool's ``hasTask`` is released."""
        self.state = 'idle'
        self.threadPool.transfer(self, 'busy', 'idle')
        self.threadPool.taskPool.hasTask.acquire()
        self.state = 'busy'
        self.threadPool.transfer(self, 'idle', 'busy')

    def run(self):
        try:
            while not self.stopSignal:
                try:
                    self.state = 'busy'
                    task = self.threadPool.getTask()  # non-blocking
                    task.context.workerId = self.workerId
                    task()
                except TaskPool.Empty:
                    self.rest()
        finally:
            # Always report 'dead', even if a task escaped with an exception;
            # otherwise ThreadPool.stop() would wait for this worker forever.
            lastState = self.state
            self.state = 'dead'
            self.threadPool.transfer(self, lastState, 'dead')
            # Drop the back-reference to break the pool <-> worker cycle.
            self.threadPool = None
class ThreadPool(object):
    """A set of WorkerThreads executing WorkTasks taken from a shared TaskPool.

    Each worker is registered in exactly one group: 'busy', 'idle' or 'dead'.
    ``taskWorkers[state]`` maps workerId -> worker, and
    ``taskWorkersNum[state]`` holds the matching count.
    ``workerStateSem[state]['empty']`` is given a permit when a group becomes
    empty. ``workerStateSem[state]['full']`` is released when a group's count
    reaches ``taskWorkersTotal``.
    """

    def __init__(self, workersNum=0):
        self.taskPool = TaskPool()
        self.mutex = threading.RLock()
        self.taskWorkers = {'busy': {}, 'idle': {}, 'dead': {}}
        self.taskWorkersNum = {'busy': 0, 'idle': 0, 'dead': 0}
        self.taskWorkersTotal = 0
        self.workerStateSem = {
            'busy': {'empty': threading.Semaphore(value=1), 'full': threading.Semaphore(value=0)},
            'idle': {'empty': threading.Semaphore(value=1), 'full': threading.Semaphore(value=0)},
            'dead': {'empty': threading.Semaphore(value=1), 'full': threading.Semaphore(value=0)},
        }
        self.addWorkers(workersNum)

    def addWorkers(self, num=1):
        """Create and start ``num`` workers, registered in the 'busy' group."""
        if num <= 0:
            return
        with self.mutex:
            # Raise the total first: enterState() refuses to count more
            # workers than taskWorkersTotal, so registering before this
            # increment left the new workers uncounted in 'busy'.
            self.taskWorkersTotal += num
            for _ in range(num):
                worker = WorkerThread(self)
                self.enterState(worker, 'busy')
                worker.start()

    def enterState(self, worker, state):
        """Add ``worker`` to group ``state`` and update its semaphores."""
        with self.mutex:
            oldNum = self.taskWorkersNum[state]
            newNum = oldNum + 1
            if newNum > self.taskWorkersTotal:
                return
            self.taskWorkers[state][worker.workerId] = worker
            self.taskWorkersNum[state] = newNum
            if oldNum == 0:
                self.workerStateSem[state]['empty'].acquire(blocking=False)
            if newNum == self.taskWorkersTotal:
                self.workerStateSem[state]['full'].release()

    def leaveState(self, worker, state):
        """Remove ``worker`` from group ``state`` and update its semaphores."""
        with self.mutex:
            oldNum = self.taskWorkersNum[state]
            newNum = oldNum - 1
            if newNum < 0:
                return
            # Actually remove the worker (it used to be re-inserted here, so
            # every group only ever grew).
            self.taskWorkers[state].pop(worker.workerId, None)
            self.taskWorkersNum[state] = newNum
            if oldNum == self.taskWorkersTotal:
                self.workerStateSem[state]['full'].acquire(blocking=False)
            if newNum == 0:
                self.workerStateSem[state]['empty'].release()

    def transfer(self, worker, oldState, newState):
        """Move ``worker`` from ``oldState`` to ``newState`` atomically."""
        # Hold the (reentrant) lock across both steps so no observer, e.g.
        # stop(), can see the worker missing from every group.
        with self.mutex:
            self.leaveState(worker, oldState)
            self.enterState(worker, newState)

    def putTask(self, task):
        self.taskPool.putTask(task)

    def getTask(self):
        return self.taskPool.popTask()

    def stop(self):  # may only be called once
        """Stop every worker, wait for all of them to finish, then join them."""
        # Snapshot all workers under the lock; the groups mutate concurrently.
        with self.mutex:
            workers = [worker
                       for group in self.taskWorkers.values()
                       for worker in group.values()]
        # Ask every worker to exit after its current task.
        for worker in workers:
            worker.stop()
        # Wait until no worker is busy.
        self.workerStateSem['busy']['empty'].acquire()
        # Wake workers blocked in rest() so they notice the stop signal.
        for _ in range(2 * self.taskWorkersTotal):
            self.taskPool.hasTask.release()
        # Wait until every worker is dead. A pool with no workers would
        # never signal this and used to block forever.
        if self.taskWorkersTotal:
            self.workerStateSem['dead']['full'].acquire()
        for worker in workers:
            worker.join()

    def wait(self):
        """Block until the task pool is drained, then stop the pool."""
        self.taskPool.noTask.acquire()
        self.stop()
class TestCase(object):
    """Minimal self-check harness: run tasks on a ThreadPool and compare results."""

    def __init__(self, index):
        self.testIndex = index
        self.tasks = []
        self.mutex = threading.Lock()
        self.output = {}        # {taskId: result reported via the callback}
        self.expectOutput = {}  # {taskId: expected result}

    def getOutput(self, context):
        """Task callback: record ``context.result`` under ``context.taskId``."""
        with self.mutex:
            self.output[context.taskId] = context.result

    def addTask(self, func, inputArgs=None, inputKwargs=None, output=None):
        """Queue ``func(*inputArgs, **inputKwargs)`` and remember its expected output."""
        task = WorkTask(func, inputArgs, inputKwargs, callback=self.getOutput)
        self.tasks.append(task)
        self.expectOutput[task.taskId] = output

    def run(self):
        """Run all queued tasks on an 8-worker pool and print a colored verdict."""
        pool = ThreadPool(8)
        for task in self.tasks:
            pool.putTask(task)
        pool.wait()
        print(pool.taskWorkersNum)
        separator = '*' * 50
        if self.output == self.expectOutput:
            report = [
                '\033[1;32;40m',
                separator,
                '\t\tTestCase[%s] Pass!' % (self.testIndex),
            ]
        else:
            report = [
                '\033[1;31;40m',
                separator,
                '\t\tTestCase[%s] Fail!' % (self.testIndex),
                '- your output is:\n' + str(self.output),
                '- standard output is:\n' + str(self.expectOutput),
            ]
        report.extend([separator, '\033[0m'])
        for line in report:
            print(line)
if __name__ == '__main__':
    def add(a, b):
        return a + b

    def mySort(items):
        return sorted(items)

    # Case 1: integer addition.
    testCase = TestCase(1)
    for pair, expected in (((1, 1), 2), ((1, 2), 3), ((1, 3), 4)):
        testCase.addTask(add, pair, output=expected)
    testCase.run()

    # Case 2: sorting permutations of [1, 2, 3].
    testCase = TestCase(2)
    for data in ([1, 2, 3], [3, 2, 1], [1, 3, 2]):
        testCase.addTask(mySort, (data,), output=[1, 2, 3])
    testCase.run()
@tanchao90
Copy link
Author

使用Python中的Queue对其改进

  • Queue队列的put、get方法是阻塞式操作,只有put、get成功时,程序才会继续执行
  • 不需要在程序内部不断地检查是否有任务,避免程序陷入持续轮询等待的状态

管线的思路+多线程

可完成很多流水式的工作,比如图片处理:

  1. download
  2. resize
  3. save

其中每个功能由一类线程处理,处理结束之后将结果放入队列,交由下一个线程处理

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment