Skip to content

Instantly share code, notes, and snippets.

@sega-yarkin
Created January 5, 2020 07:55
Show Gist options
  • Save sega-yarkin/b74a16c8a253bc013761a5610f5cfdc7 to your computer and use it in GitHub Desktop.
Save sega-yarkin/b74a16c8a253bc013761a5610f5cfdc7 to your computer and use it in GitHub Desktop.
Python RPC-like server module
#!/usr/bin/env python3
import os
import sys
import time
import prctl # pylint: disable=import-error
import signal
import socket
import threading
import logging
import json
import importlib
import traceback
import io
class Manager:
    """Accept TCP connections and fork one Worker process per connection.

    The listening port is taken from the MANAGER_PORT environment variable
    (default 8201). If PRELOAD_MODULE is set, that module is imported once in
    the manager before any fork, so forked workers start with it loaded.
    """

    PORT = int(os.environ.get("MANAGER_PORT", "8201"))

    def __init__(self):
        self.pid = os.getpid()
        self.sock = None
        self.set_signal_handlers()
        self.preload()
        self.start_server()

    def start_server(self):
        """Create the listening socket on all interfaces at ``PORT``.

        The socket is closed again if any setup step fails.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
            sock.bind(("0.0.0.0", self.PORT))
            sock.listen(3)  # backlog
        except OSError:
            sock.close()
            raise
        self.sock = sock
        logging.info(f"Listening on {self.PORT}", extra={'pid': self.pid})

    def stop_server(self):
        """Close the listening socket; a no-op if it was never created."""
        # Signal handlers are installed before start_server(), so SIGINT can
        # arrive while self.sock is still None (e.g. during preload()).
        if self.sock is None:
            return
        self.sock.close()
        logging.info("Listening socket is closed", extra={'pid': self.pid})

    def on_sigint(self, _signo, _stack):
        """SIGINT handler: close the listening socket and exit with status 0."""
        logging.info("Got SIGINT, terminating", extra={'pid': self.pid})
        self.stop_server()
        sys.exit(0)

    def on_sigchld(self, signo, _stack):
        """Reap every exited child and SIGKILL whatever remains in its process group.

        SIGCHLD is not queued, so one delivery may stand for several exited
        children: keep calling wait3() until nothing more is ready.
        """
        while True:
            try:
                pid, status, _ru = os.wait3(os.WNOHANG)
            except ChildProcessError:
                return  # no children at all
            if pid == 0:
                return  # children exist, but none has exited yet
            logging.debug(f"Got SIGCHLD for PID={pid} and signal={signo}, status={status}",
                          extra={'pid': self.pid})
            try:
                # Worker calls os.setpgrp(), so a worker's pid is also its process
                # group id; this kills anything the worker left behind.
                os.killpg(pid, signal.SIGKILL)
            except OSError:
                pass  # best effort: group already empty (ProcessLookupError) or not ours

    def set_signal_handlers(self):
        """Install the manager's SIGINT and SIGCHLD handlers."""
        signal.signal(signal.SIGINT, self.on_sigint)
        signal.signal(signal.SIGCHLD, self.on_sigchld)

    def preload(self):
        """Import the module named by PRELOAD_MODULE, if that variable is set."""
        module = os.environ.get("PRELOAD_MODULE")
        if module:
            importlib.import_module(module)

    def loop(self):
        """Accept connections and hand each one to a forked worker.

        In the manager this loops forever. In a forked child, the Worker is run
        and the loop exits.
        """
        while True:
            logging.debug("Waiting for connection...", extra={'pid': self.pid})
            conn, client_address = self.sock.accept()
            worker = self.get_worker(conn, client_address)
            if worker is None:
                continue
            worker.run()
            return

    def get_worker(self, conn, client_id):
        """Fork a process to serve ``conn``.

        Returns a Worker in the child process and None in the manager. None is
        also returned when fork fails; the connection is then closed.
        """
        try:
            worker_pid = os.fork()
        except OSError as err:
            logging.error("Cannot fork worker for %s: %s", client_id, err, extra={'pid': self.pid})
            conn.close()
            return None
        if worker_pid > 0:
            logging.debug(f"Forked worker PID={worker_pid} for {client_id}", extra={'pid': self.pid})
            conn.close()  # the child owns the connection now
            return None
        self.sock.close()  # the child must never accept on the listening socket
        return Worker(conn, client_id)
class Worker:
    """Serve exactly one request on an accepted connection inside a forked child.

    The worker reads one command, starts a thread that watches the connection
    while the command runs, executes the command (which sends the reply), and
    finally kills its whole process group.
    """

    CONNECTION_TIMEOUT = 3.0  # NOTE(review): not referenced anywhere in this module

    def __init__(self, conn, client_id):
        prctl.set_proctitle("worker: init")
        self.conn = conn
        self.client_id = client_id
        self.pid = os.getpid()
        self.set_signal_handlers()
        prctl.set_pdeathsig(signal.SIGTERM)  # receive SIGTERM if the parent dies
        os.setpgrp()  # create a process group so terminate() can kill it as a whole
        self.terminating = False
        # Set by terminate() so the monitor thread can tell an expected shutdown
        # of the connection from the client going away mid-command.
        self._stopping = threading.Event()
        self.monitor = threading.Thread(
            target=Worker.conn_monitor,
            args=(self.conn, self.pid, self._stopping),
        )
        self.monitor.daemon = True
        logging.info("Worker created", extra={'pid': self.pid})

    def on_sigterm(self, _signal, _stack):
        """SIGTERM handler: shut down the worker and its process group."""
        self.terminate()

    def set_signal_handlers(self):
        """Install the worker's handlers over those inherited from the manager."""
        signal.signal(signal.SIGTERM, self.on_sigterm)
        # fork() copied the manager's SIGCHLD handler, which reaps any child via
        # os.wait3(); in the worker that would steal the exit status of
        # subprocesses started by the executed command.
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)

    def get_codec(self):
        """Return the codec used to talk to the client."""
        return JsonCodec(self.conn)

    def run(self):
        """Read and execute one command, log any failure, then always terminate()."""
        try:
            command = Command(self.get_codec()).read()
            logging.debug("Received: %s", command.request, extra={'pid': self.pid})
            self.monitor.start()
            prctl.set_proctitle(f"worker: {command.full_name}")
            command.execute()
        except socket.timeout:
            logging.error("Connection operation timeout!", extra={'pid': self.pid})
        except Codec.Error as e:
            logging.error("Codec error: %s", e, extra={'pid': self.pid})
        except Command.Error as e:
            logging.error("Command error: %s", e, extra={'pid': self.pid})
        except Exception as e:
            # logging.exception also records the traceback, which is the only
            # clue for an unexpected failure.
            logging.exception("Unexpected error: %s", e, extra={'pid': self.pid})
        self.terminate()

    @staticmethod
    def conn_monitor(conn, pid, stop_event=None):
        """Block on ``conn`` and SIGTERM ``pid`` if the client sends data or disconnects.

        If ``stop_event`` is given and already set when recv() returns, the
        worker is shutting down on purpose and nothing is logged or signalled.
        """
        try:
            data = conn.recv(1)
        except OSError:
            data = None  # socket failed or was closed underneath us
        if stop_event is not None and stop_event.is_set():
            return  # terminate() shut the connection down; this is expected
        if data:
            logging.error("Data is received while worker is running: %s", data, extra={'pid': pid})
        elif data == b'':
            logging.error("Connection unexpectedly closed!", extra={'pid': pid})
        os.kill(pid, signal.SIGTERM)

    def terminate(self):
        """Close the connection and SIGKILL every process in this worker's group.

        Only the first call has an effect.
        """
        if self.terminating:
            return
        self.terminating = True
        self._stopping.set()
        try:
            self.conn.shutdown(socket.SHUT_RDWR)
        except OSError as err:
            # The peer may already be gone (e.g. ENOTCONN); the process group
            # must still be killed below.
            logging.debug("Connection shutdown failed: %s", err, extra={'pid': self.pid})
        self.conn.close()
        logging.info("Terminated", extra={'pid': self.pid})
        os.killpg(0, signal.SIGKILL)  # kill all processes in the group
class Command:
    """One RPC call: an importable ``module``, a dotted ``function`` path, and ``args``.

    SECURITY: the module name, attribute path and arguments all come straight
    from the request, so any peer that can send a request can call any
    importable callable. Only expose this to trusted clients.
    """

    def __init__(self, codec):
        self._codec = codec
        self._request = None
        self.module = self.function = self.args = None

    @property
    def request(self):
        """The raw decoded request, or None before read() is called."""
        return self._request

    @property
    def full_name(self):
        """``module.function`` label for this call."""
        return "{}.{}".format(self.module, self.function)

    def read(self):
        """Decode one request from the codec and validate its fields.

        Returns self so the call can be chained. Raises Command.Error when the
        request is not a dict, when 'module' or 'function' is not a non-empty
        string, or when 'args' is not a list.
        """
        self._request = payload = self._codec.read()
        if not isinstance(payload, dict):
            raise Command.Error("Root object is not a dictionary")
        self.module, self.function, self.args = (
            payload.get('module'),
            payload.get('function'),
            payload.get('args'),
        )
        required_names = (
            (self.module, "'module' should be non-empty string"),
            (self.function, "'function' should be non-empty string"),
        )
        for value, message in required_names:
            if not (isinstance(value, str) and value):
                raise Command.Error(message)
        if not isinstance(self.args, list):
            raise Command.Error("'args' should be a list")
        return self

    def respond_ok(self, result):
        """Send a success reply carrying ``result``."""
        self._codec.write(dict(status="OK", result=result))

    def respond_error(self, error):
        """Send a failure reply carrying the ``error`` description."""
        self._codec.write(dict(status="ERROR", error=error))

    def execute(self):
        """Import the module, walk the dotted path, call it with ``args`` and reply.

        Anything raised, including by sending the success reply, is reported
        through respond_error() with the qualified exception type, its message,
        and the formatted traceback with the innermost frame first.
        """
        path = self.function.split(".")
        try:
            obj = importlib.import_module(self.module)
            for attr in path:
                obj = getattr(obj, attr)
            self.respond_ok(obj(*self.args))
        except BaseException as exc:  # deliberately reports every exception to the client
            exc_type = type(exc)
            frames = traceback.format_tb(exc.__traceback__)[::-1]
            self.respond_error({
                'type': f"{exc_type.__module__}.{exc_type.__name__}",
                'exception': str(exc),
                'traceback': frames,
            })

    class Error(ValueError):
        """Raised when a request is structurally invalid."""

        def __init__(self, errmsg):
            super().__init__(errmsg)
class Codec:
    """Base class for wire codecs that read one request from a socket connection.

    Subclasses provide ``write(data)`` and may override ``_read()``. The base
    ``_read()`` returns the raw bytes the peer sends until it closes its write
    side.
    """

    READ_TIMEOUT = 300.0  # seconds, applied to the socket while reading a request
    RECV_BUFF_SIZE = 65536  # read-buffer size in bytes (used by JsonCodec)

    def __init__(self, conn):
        self.conn = conn

    def read(self):
        """Read one request with ``READ_TIMEOUT`` temporarily applied to the socket.

        The connection's previous timeout is restored even if reading raises.
        """
        orig_timeout = self.conn.gettimeout()
        self.conn.settimeout(self.READ_TIMEOUT)
        try:
            return self._read()
        finally:
            self.conn.settimeout(orig_timeout)

    def _read(self):
        """Return every byte received until the peer closes its write side."""
        # io.BufferedReader has no readall(); read() with no size reads to EOF.
        # The file object is closed explicitly: the socket's fd is only released
        # once every makefile() object is closed as well.
        with self.conn.makefile('rb') as stream:
            return stream.read()

    class Error(ValueError):
        """Raised when input cannot be decoded; ``orig`` keeps the underlying error."""

        def __init__(self, errmsg, orig=None):
            ValueError.__init__(self, errmsg)
            self.orig = orig
class JsonCodec(Codec):
    """Codec speaking JSON: one newline-terminated request in, one document out.

    Replies are serialized with ``allow_nan=False`` (NaN/Infinity raise
    ValueError rather than producing non-standard JSON), encoded as UTF-8,
    and sent without a trailing newline.
    """

    def write(self, data):
        """Serialize ``data`` to JSON and send all of it over the connection."""
        raw_data = json.dumps(data, allow_nan=False)
        self.conn.sendall(raw_data.encode('utf-8'))

    def _read(self):
        """Read one line from the connection and decode it as JSON.

        Raises:
            Codec.Error: the line is not valid JSON or not valid text. This
                includes empty input, i.e. the peer closed without sending.
        """
        # buffering=0 returns the raw socket stream, so exactly one read buffer
        # (RECV_BUFF_SIZE bytes) sits between the socket and readline().
        # Closing these file objects drops their reference to the socket; it
        # does not close the socket itself.
        with self.conn.makefile('rb', buffering=0) as raw:
            with io.BufferedReader(raw, buffer_size=self.RECV_BUFF_SIZE) as buff:
                data = buff.readline()
        try:
            return json.loads(data)
        except ValueError as err:
            # Covers json.JSONDecodeError and UnicodeDecodeError (invalid UTF-8);
            # both subclass ValueError.
            raise Codec.Error("Cannot decode input as JSON", err) from err
if __name__ == "__main__":
    logging.basicConfig(format='[%(pid)3s] %(message)s', level=logging.DEBUG)

    def _default_pid(record):
        """Give log records that lack a ``pid`` attribute the current process id.

        The format string above references %(pid)s. Records logged without
        ``extra={'pid': ...}``, e.g. from preloaded or executed modules, would
        otherwise fail to format.
        """
        if not hasattr(record, 'pid'):
            record.pid = os.getpid()
        return True

    # Handler-level filters also see records propagated from child loggers.
    for _handler in logging.getLogger().handlers:
        _handler.addFilter(_default_pid)

    manager = Manager()
    manager.loop()

#
# Development environment:
# docker run -it --rm -p 8201:8201 -v $PWD:/srv python:3-alpine /bin/sh
# apk add -q --no-cache procps build-base libcap-dev
# pip install python-prctl
# export PYTHONPATH=/srv
#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment