# Source: GitHub Gist jmkacz/2829062, created by @jmkacz on May 29, 2012 15:28.
# Dvir's latest version of TProcessPoolServer
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging, struct, socket
from multiprocessing import Process, Value, Condition, reduction,Lock
from TServer import TServer
from thrift.transport.TTransport import TTransportException
#import prctl
import signal
import os
import sys
import time
class TProcessPoolServer(TServer):
    """
    Server with a fixed size pool of worker subprocesses which service requests.
    Note that if you need shared state between the handlers - it's up to you!
    Written by Dvir Volk, doat.com

    The parent process listens on ``serverTransport``, forks ``numWorkers``
    daemonic worker processes that each call ``serverTransport.accept()``
    themselves, and then blocks in ``serve()`` until ``stop()`` notifies
    ``stopCondition``.
    """

    DEFAULT_NUM_WORKERS = 8
    CLIENT_TIMEOUT = 20

    def __init__(self, *args):
        """Accept the same positional arguments as ``TServer.__init__``.

        Also makes this process the leader of its own process group, so that
        ``stop()`` can signal all forked workers at once with ``os.killpg``.

        Raises:
            OSError: if the process group cannot be set up and this process
                is not already a process-group leader.
        """
        TServer.__init__(self, *args)
        self.numWorkers = TProcessPoolServer.DEFAULT_NUM_WORKERS
        self.workers = []
        # Shared (cross-process) flag; workers keep accepting while it is true.
        self.isRunning = Value('b', False)
        # serve() blocks on this condition until stop() notifies it.
        self.stopCondition = Condition()
        self.postForkCallback = None
        self.shutdownCallback = None
        self.parentPid = os.getpid()
        try:
            os.setpgid(self.parentPid, self.parentPid)
        except OSError:
            # setpgid() fails with EPERM for a session leader, which is already
            # the leader of its own process group; only a failure that leaves
            # us outside our own group is a real error.
            if os.getpgrp() != self.parentPid:
                raise
        # Created before forking so all workers share it; serializes the
        # post-fork callback across workers.
        self._lock = Lock()

    def setPostForkCallback(self, callback):
        """
        Set a callback to be called in the workers AFTER they have forked.
        This is useful for them to start threads, open sockets to databases, etc

        Raises:
            TypeError: if ``callback`` is not callable.
        """
        if not callable(callback):
            raise TypeError("This is not a callback!")
        self.postForkCallback = callback

    def setShutdownCallback(self, callback):
        """
        Set a callback to be called when we need to shut down the server.
        It runs inside each worker after its accept loop ends.

        Raises:
            TypeError: if ``callback`` is not callable.
        """
        if not callable(callback):
            raise TypeError("This is not a callback!")
        self.shutdownCallback = callback

    def setNumWorkers(self, num):
        """Set the number of worker sub procs that should be created"""
        self.numWorkers = num

    def workerProcess(self, workerNum):
        """Body of a forked worker: accept clients and serve them one at a time.

        Runs ``postForkCallback`` first (serialized across workers by
        ``self._lock``), then loops on ``serverTransport.accept()`` while
        ``isRunning`` is set, and finally runs ``shutdownCallback``.

        Returns:
            0 if the post-fork callback was interrupted by KeyboardInterrupt or
            SystemExit; otherwise None.
        """
        self.workerNum = workerNum
        logging.info("Worker starting! %s %s", workerNum, os.getpid())
        if self.postForkCallback:
            try:
                with self._lock:
                    self.postForkCallback()
            # catch system exit while in post forking
            except (KeyboardInterrupt, SystemExit):
                logging.info("Worker closing! %s %s", workerNum, os.getpid())
                return 0
            except Exception as x:
                logging.exception(x)
        while self.isRunning.value:
            try:
                try:
                    client = self.serverTransport.accept()
                except Exception:
                    # Any accept() failure lands here, not only timeouts.
                    logging.warning('socket timed out on accept!')
                    continue
                self.serveClient(client)
            except (KeyboardInterrupt, SystemExit):
                logging.info("Worker closing! %s %s", workerNum, os.getpid())
                break
            except Exception as x:
                logging.exception(x)
        logging.info("Shutting Down")
        # Call the shutdown callback if necessary
        if self.shutdownCallback:
            try:
                self.shutdownCallback()
            except Exception as e:
                logging.exception(e)
        logging.info("Process %s exiting!", os.getpid())

    def serveClient(self, client):
        """Process input/output from a client for as long as possible.

        Repeatedly calls ``processor.process`` until it raises. A
        TTransportException (e.g. the client disconnecting) ends the session
        quietly, and other Exceptions are logged. Both transports are always
        closed. KeyboardInterrupt/SystemExit propagate to ``workerProcess``
        so the worker can shut down.
        """
        itrans = self.inputTransportFactory.getTransport(client)
        otrans = self.outputTransportFactory.getTransport(client)
        iprot = self.inputProtocolFactory.getProtocol(itrans)
        oprot = self.outputProtocolFactory.getProtocol(otrans)
        try:
            while True:
                self.processor.process(iprot, oprot)
        except TTransportException:
            pass
        except Exception as x:
            logging.exception(x)
        finally:
            try:
                itrans.close()
                otrans.close()
            except Exception as e:
                logging.exception(e)

    def serve(self):
        """Listen, fork ``numWorkers`` worker processes, then block until stop().

        Returns after ``stop()`` notifies ``stopCondition`` (or after an
        interrupt), with ``isRunning`` cleared.
        """
        # this is a shared state that can tell the workers to exit when set as false
        self.isRunning.value = True
        # first bind and listen to the port
        self.serverTransport.listen()
        # SO_LINGER on with a 0s timeout (abortive close); this is useful if
        # you're constantly opening/closing connections
        try:
            self.serverTransport.handle.setsockopt(
                socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
        except Exception as e:
            logging.error("could not set linger: %s", e)
        # fork the children
        for i in range(self.numWorkers):
            if not self.isRunning.value:
                break
            try:
                w = Process(target=self.workerProcess, args=(i,),
                            name='ServerWorker-%s' % i)
                w.daemon = True
                self.workers.append(w)
                w.start()
            # catch system exit while forking children
            except (SystemExit, KeyboardInterrupt):
                logging.warning("Got interrupt!")
                self.isRunning.value = False
                break
            except Exception as x:
                logging.exception(x)
        logging.info("Exited forking loop!")
        # wait until the condition is set by stop()
        while self.isRunning.value:
            self.stopCondition.acquire()
            try:
                self.stopCondition.wait()
            except (SystemExit, KeyboardInterrupt):
                logging.warning("Got interrupt!")
                break
            except Exception as x:
                logging.exception(x)
            finally:
                # Always give the lock back, even if wait() raised.
                self.stopCondition.release()
        self.isRunning.value = False

    def stop(self):
        """Stop the server and tear down the worker pool.

        Clears ``isRunning`` and wakes ``serve()``. Joins each worker for up to
        1s, terminating any still alive. Closes the listening transport, sends
        SIGTERM to the process group set up in ``__init__``, and reaps one
        child if any remain.
        """
        self.isRunning.value = False
        self.stopCondition.acquire()
        try:
            self.stopCondition.notify_all()
        finally:
            self.stopCondition.release()
        logging.info("Stopped process pool server")
        for proc in self.workers:
            logging.info("Joining worker %s. alive? %s", proc, proc.is_alive())
            try:
                proc.join(1.0)
                # terminate the process anyways
                if proc.is_alive():
                    proc.terminate()
                logging.info("Worker %s joined!", proc)
            except Exception as e:
                logging.exception(e)
                proc.terminate()
        self.serverTransport.close()
        # send all the workers SIGTERM just in case
        # NOTE(review): this process is the group leader (see __init__), so
        # unless SIGTERM is handled/ignored here this also terminates the
        # calling process itself -- confirm that is intended.
        os.killpg(os.getpgid(self.parentPid), signal.SIGTERM)
        try:
            os.waitpid(0, 0)
        except OSError:
            # ECHILD: no children left to reap (join() already collected them).
            pass
# End of gist jmkacz/2829062.