Created
January 5, 2020 07:55
-
-
Save sega-yarkin/b74a16c8a253bc013761a5610f5cfdc7 to your computer and use it in GitHub Desktop.
Python RPC-like server module
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import os | |
import sys | |
import time | |
import prctl # pylint: disable=import-error | |
import signal | |
import socket | |
import threading | |
import logging | |
import json | |
import importlib | |
import traceback | |
import io | |
class Manager:
    """Listening parent process of the RPC server.

    Binds a TCP socket on ``PORT`` and forks one child process per accepted
    connection; the child wraps the connection in a ``Worker`` and runs it.
    """

    # Listening port, overridable through the MANAGER_PORT environment variable.
    PORT = int(os.environ.get("MANAGER_PORT", "8201"))

    def __init__(self):
        """Install signal handlers, preload the optional module and start listening."""
        self.pid = os.getpid()
        self.sock = None
        self.set_signal_handlers()
        self.preload()
        self.start_server()

    def start_server(self):
        """Create the listening socket and bind it to all interfaces on ``PORT``."""
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
        self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
        # NOTE(review): this listens on every interface while requests select
        # arbitrary modules/functions to call -- confirm the port is only
        # reachable by trusted clients.
        self.sock.bind(("0.0.0.0", self.PORT))
        self.sock.listen(3)  # backlog
        logging.info(f"Listening on {self.PORT}", extra={'pid': self.pid})

    def stop_server(self):
        """Close the listening socket, if it has been created."""
        # Signal handlers are installed before start_server() runs, so SIGINT
        # can arrive while self.sock is still None.
        if self.sock is None:
            return
        self.sock.close()
        logging.info("Listening socket is closed", extra={'pid': self.pid})

    def on_sigint(self, _signo, _stack):
        """SIGINT handler: close the listening socket and exit with status 0."""
        logging.info("Got SIGINT, terminating", extra={'pid': self.pid})
        self.stop_server()
        sys.exit(0)

    def on_sigchld(self, signo, _stack):
        """SIGCHLD handler: reap exited children and kill their process groups."""
        # Several children may exit before the handler runs (SIGCHLD is not
        # queued), so keep reaping until nothing is left to collect.
        while True:
            try:
                pid, _status, _ru = os.wait3(os.WNOHANG)
            except ChildProcessError:
                return  # no child processes at all
            if pid <= 0:
                return  # children exist, but none has exited
            logging.debug(f"Got SIGCHLD for PID={pid} and signal={signo}", extra={'pid': self.pid})
            try:
                # Each Worker calls os.setpgrp(), so its PID doubles as the
                # ID of the process group it created.
                os.killpg(pid, signal.SIGKILL)
            except OSError:
                # Best effort: the group is usually already gone
                # (ProcessLookupError) once its leader has exited.
                pass

    def set_signal_handlers(self):
        """Register the SIGINT and SIGCHLD handlers."""
        signal.signal(signal.SIGINT, self.on_sigint)
        signal.signal(signal.SIGCHLD, self.on_sigchld)

    def preload(self):
        """Import the module named by the PRELOAD_MODULE environment variable, if set."""
        module = os.environ.get("PRELOAD_MODULE")
        if module:
            importlib.import_module(module)

    def loop(self):
        """Accept connections forever in the parent; run one worker in a child.

        In the parent this never returns. In a forked child it returns after
        ``worker.run()`` has finished.
        """
        while True:
            logging.debug("Waiting for connection...", extra={'pid': self.pid})
            conn, client_address = self.sock.accept()
            worker = self.get_worker(conn, client_address)
            if worker is None:
                continue  # parent process: go back to accepting
            worker.run()
            return

    def get_worker(self, conn, client_id):
        """Fork a process for ``conn``.

        Returns ``None`` in the parent (after closing its copy of ``conn``)
        and a new ``Worker`` in the child (after closing the child's copy of
        the listening socket).
        """
        worker_pid = os.fork()
        if worker_pid > 0:
            logging.debug(f"Forked worker PID={worker_pid} for {client_id}", extra={'pid': self.pid})
            conn.close()
            return None
        self.sock.close()
        return Worker(conn, client_id)
class Worker:
    """Serve a single client connection inside a forked child process.

    The worker reads one command from the connection, executes it, and then
    terminates itself together with every process in its process group.
    """

    CONNECTION_TIMEOUT = 3.0  # NOTE(review): not referenced within this class

    def __init__(self, conn, client_id):
        """Set up process title, signal handling, process group and monitor thread.

        Args:
            conn: connected client socket.
            client_id: client identifier (stored as-is; not used further here).
        """
        prctl.set_proctitle("worker: init")
        self.conn = conn
        self.client_id = client_id
        self.pid = os.getpid()
        # Must exist before the SIGTERM handler is installed: the handler
        # calls terminate(), which reads this flag.
        self.terminating = False
        self.set_signal_handlers()
        prctl.set_pdeathsig(signal.SIGTERM)  # if parent dies
        os.setpgrp()  # create process group
        self.monitor = threading.Thread(target=Worker.conn_monitor, args=(self.conn, self.pid))
        self.monitor.daemon = True
        logging.info("Worker created", extra={'pid': self.pid})

    def on_sigterm(self, _signal, _stack):
        """SIGTERM handler: terminate the worker."""
        self.terminate()

    def set_signal_handlers(self):
        """Register the SIGTERM handler."""
        signal.signal(signal.SIGTERM, self.on_sigterm)

    def get_codec(self):
        """Return the codec used to talk to the client (JSON)."""
        return JsonCodec(self.conn)

    def run(self):
        """Read one command, execute it, log any failure, then terminate."""
        try:
            command = Command(self.get_codec()).read()
            logging.debug("Received: %s", command.request, extra={'pid': self.pid})
            # The monitor is started only after the request has been read; it
            # then watches the connection while the command executes.
            self.monitor.start()
            prctl.set_proctitle(f"worker: {command.full_name}")
            command.execute()
        except socket.timeout:
            logging.error("Connection operation timeout!", extra={'pid': self.pid})
        except Codec.Error as e:
            logging.error("Codec error: %s", e, extra={'pid': self.pid})
        except Command.Error as e:
            logging.error("Command error: %s", e, extra={'pid': self.pid})
        except Exception as e:
            logging.error("Unexpected error: %s", e, extra={'pid': self.pid})
        self.terminate()

    @staticmethod
    def conn_monitor(conn, pid):
        """Thread target: block on ``conn`` and SIGTERM ``pid`` once recv returns.

        Any outcome of the blocking ``recv`` (extra data, end of stream or a
        socket error) results in SIGTERM being sent to ``pid``.
        """
        try:
            data = conn.recv(1)
        except OSError as err:
            # Best effort: a failing recv still leads to the SIGTERM below.
            logging.debug("Connection monitor recv failed: %s", err, extra={'pid': pid})
        else:
            if data != b'':
                logging.error("Data is received while worker is running: %s", data, extra={'pid': pid})
            else:
                logging.error("Connection unexpectedly closed!", extra={'pid': pid})
        os.kill(pid, signal.SIGTERM)

    def terminate(self):
        """Close the connection and SIGKILL the whole process group (once)."""
        if self.terminating:
            return
        self.terminating = True
        try:
            self.conn.shutdown(socket.SHUT_RDWR)
        except OSError:
            # The peer may already have dropped the connection; that must not
            # prevent the process group from being killed below.
            pass
        self.conn.close()
        logging.info("Terminated", extra={'pid': self.pid})
        os.killpg(0, signal.SIGKILL)  # kill all processes in the group
class Command:
    """A single RPC request: parsed from a codec, executed, answered via the codec.

    The request is expected to be a dictionary with the keys ``module``
    (non-empty string), ``function`` (non-empty string, may be a dotted
    attribute path) and ``args`` (list of positional arguments).
    """

    class Error(ValueError):
        """Raised by ``read()`` when the request is malformed."""

        def __init__(self, errmsg):
            super().__init__(errmsg)

    def __init__(self, codec):
        """Bind the command to ``codec``, which provides ``read()`` and ``write()``."""
        self._codec = codec
        self._request = None
        self.module = None
        self.function = None
        self.args = None

    @property
    def request(self):
        """The raw request object as returned by the codec (``None`` before ``read()``)."""
        return self._request

    @property
    def full_name(self):
        """``"<module>.<function>"`` of the requested target."""
        return f"{self.module}.{self.function}"

    @staticmethod
    def _is_nonempty_str(value):
        """Tell whether ``value`` is a string with at least one character."""
        return isinstance(value, str) and value != ""

    def read(self):
        """Read one request from the codec and validate its structure.

        Returns:
            ``self``, so that calls can be chained.

        Raises:
            Command.Error: if the request is not a dictionary or one of its
                fields has the wrong type.
        """
        payload = self._codec.read()
        self._request = payload
        if not isinstance(payload, dict):
            raise Command.Error("Root object is not a dictionary")

        # Fields are stored before validation, as callers may inspect them.
        self.module = payload.get('module')
        self.function = payload.get('function')
        self.args = payload.get('args')

        if not self._is_nonempty_str(self.module):
            raise Command.Error("'module' should be non-empty string")
        if not self._is_nonempty_str(self.function):
            raise Command.Error("'function' should be non-empty string")
        if not isinstance(self.args, list):
            raise Command.Error("'args' should be a list")
        return self

    def respond_ok(self, result):
        """Send a success reply carrying ``result``."""
        reply = {'status': "OK", 'result': result}
        self._codec.write(reply)

    def respond_error(self, error):
        """Send a failure reply carrying ``error``."""
        reply = {'status': "ERROR", 'error': error}
        self._codec.write(reply)

    def execute(self):
        """Import the module, resolve the dotted function path, call it and reply.

        Anything raised while importing, resolving, calling or sending the
        success reply is reported to the client as an error reply containing
        the exception type, its message and the traceback (innermost frame
        first).
        """
        # NOTE(review): module and function names come straight from the
        # request, so a client can invoke any importable callable -- only
        # expose this to trusted peers.
        attr_path = self.function.split(".")
        try:
            target = importlib.import_module(self.module)
            for attr in attr_path:
                target = getattr(target, attr)
            self.respond_ok(target(*self.args))
        except BaseException as exc:  # deliberately as broad as a bare except
            frames = traceback.format_tb(exc.__traceback__)[::-1]
            exc_type = type(exc)
            self.respond_error({
                'type': f"{exc_type.__module__}.{exc_type.__name__}",
                'exception': str(exc),
                'traceback': frames,
            })
class Codec:
    """Base class for codecs that read messages from a connected socket.

    Subclasses override ``_read()`` (and add ``write()``) to implement a
    concrete wire format; this base implementation returns the raw bytes
    received until the peer closes its sending side.
    """

    READ_TIMEOUT = 300.0    # seconds; socket timeout applied while reading
    RECV_BUFF_SIZE = 65536  # buffer size available to subclasses

    class Error(ValueError):
        """Decoding failure; ``orig`` holds the underlying exception, if any."""

        def __init__(self, errmsg, orig=None):
            ValueError.__init__(self, errmsg)
            self.orig = orig

    def __init__(self, conn):
        """Bind the codec to the connected socket ``conn``."""
        self.conn = conn

    def read(self):
        """Read one message with ``READ_TIMEOUT`` applied to the socket.

        The socket's previous timeout is restored afterwards, including when
        ``_read()`` raises.
        """
        orig_timeout = self.conn.gettimeout()
        self.conn.settimeout(self.READ_TIMEOUT)
        try:
            return self._read()
        finally:
            self.conn.settimeout(orig_timeout)

    def _read(self):
        """Return everything received on the connection until end of stream."""
        # makefile('rb') yields a buffered reader, which offers read() but no
        # readall(). Closing the file object does not close the socket.
        with self.conn.makefile('rb') as stream:
            return stream.read()
class JsonCodec(Codec):
    """Codec exchanging JSON documents: one line in, one document out."""

    def write(self, data):
        """Serialize ``data`` as JSON and send it as UTF-8.

        NaN and infinite floats are rejected (``allow_nan=False``), in which
        case ``json.dumps`` raises ``ValueError`` and nothing is sent.
        """
        raw_data = json.dumps(data, allow_nan=False)
        self.conn.sendall(raw_data.encode('utf-8'))

    def _read(self):
        """Read a single line from the connection and decode it as JSON.

        Raises:
            Codec.Error: if the line is not valid JSON or not decodable text.
        """
        # Closing the file object releases the buffer but leaves the socket open.
        with self.conn.makefile('rb', buffering=self.RECV_BUFF_SIZE) as stream:
            data = stream.readline()
        try:
            return json.loads(data)
        except (json.JSONDecodeError, UnicodeDecodeError) as err:
            # json.loads() decodes bytes itself, so undecodable input surfaces
            # as UnicodeDecodeError rather than JSONDecodeError.
            raise Codec.Error("Cannot decode input as JSON", err) from err
if __name__ == "__main__":
    # Use the built-in %(process)s record attribute for the PID prefix rather
    # than a custom %(pid)s field: records emitted without extra={'pid': ...}
    # (e.g. by preloaded or RPC-imported modules) would otherwise fail to format.
    logging.basicConfig(format='[%(process)3s] %(message)s', level=logging.DEBUG)
    manager = Manager()
    manager.loop()
#
# docker run -it --rm -p 8201:8201 -v $PWD:/srv python:3-alpine /bin/sh
# apk add -q --no-cache procps build-base libcap-dev
# easy_install python-prctl
# export PYTHONPATH=/srv
#
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment