Skip to content

Instantly share code, notes, and snippets.

@reiver-dev
Last active April 23, 2019 00:26
Show Gist options
  • Save reiver-dev/83834a391e69941830a4640408568e8b to your computer and use it in GitHub Desktop.
Save reiver-dev/83834a391e69941830a4640408568e8b to your computer and use it in GitHub Desktop.
Run processes and pass their stdio file descriptors over Unix sockets; this might be useful for sidecar containers.
import os
import sys
from array import array
from socket import (socket, AF_UNIX, SOCK_SEQPACKET,
CMSG_LEN, SOL_SOCKET, SCM_RIGHTS)
from errno import EADDRINUSE
from asyncio import (
Future, Task, AbstractEventLoop,
get_running_loop as _get_running_loop, set_event_loop,
create_subprocess_exec,
wait as async_wait, FIRST_COMPLETED
)
from asyncio.unix_events import ( # type: ignore
_UnixSelectorEventLoop as UnixLoop
)
from asyncio.runners import ( # type: ignore
_cancel_all_tasks as cancel_all_tasks
)
from asyncio.subprocess import Process
import logging
from subprocess import DEVNULL
from pathlib import Path
from argparse import ArgumentParser, ArgumentTypeError
import json
import struct
import signal
from contextlib import contextmanager
from typing import (Sequence, Set, Mapping,
Optional, Tuple, Any, Iterable)
# Buffer-size limit constant: 64 KiB.
DEFAULT_LIMIT = 1 << 16

# One ancillary-data item in the form accepted by socket.sendmsg():
# (cmsg_level, cmsg_type, cmsg_data).
AncData = Tuple[int, int, bytes]
# Shape of the socket.recvmsg() result: (data, ancdata, msg_flags, address).
Msg = Tuple[bytes, Tuple[AncData, ...], int, Any]

# Logger name: the module name, or the script's file stem when run directly.
NAME = Path(__file__).stem if __name__ == '__main__' else __name__
_log = logging.getLogger(NAME)
debug = _log.debug
# 4-byte little-endian signed integer codec for the small binary messages
# (signal numbers and return codes) exchanged on the socket.
_INT32_LE = struct.Struct('<i')


def itob(value: int) -> bytes:
    """Encode *value* as a 4-byte little-endian signed integer."""
    return _INT32_LE.pack(value)


def btoi(value: bytes) -> int:
    """Decode a 4-byte little-endian signed integer from *value*."""
    (decoded,) = _INT32_LE.unpack(value)
    return decoded
class Loop(UnixLoop, AbstractEventLoop):
    """Unix selector event loop with coroutine ``sendmsg``/``recvmsg`` support.

    In the style of ``loop.sock_recv``/``loop.sock_sendall``, the coroutines
    first try the non-blocking socket call. If it would block, they register
    a reader/writer callback that retries once the socket is ready. This
    module uses them to pass file descriptors via ``SCM_RIGHTS``.
    """

    async def sock_recvmsg(self, sock: socket, bufsize: int,
                           ancbufsize: int = 0, flags: int = 0) -> Msg:
        """Receive one message with ``sock.recvmsg(bufsize, ancbufsize, flags)``.

        Returns the ``(data, ancdata, msg_flags, address)`` tuple produced by
        :meth:`socket.socket.recvmsg`. In debug mode, a socket that is not
        non-blocking is rejected with :class:`ValueError`.
        """
        if self._debug and sock.gettimeout() != 0:
            raise ValueError("the socket must be non-blocking")
        future = self.create_future()
        self._sock_recvmsg(future, None, sock, bufsize, ancbufsize, flags)
        return await future

    async def sock_sendmsg(self, sock: socket,
                           data: Iterable[bytes],
                           ancdata: Iterable[AncData] = ()) -> int:
        """Send one message with ``sock.sendmsg(data, ancdata)``.

        Returns the number of bytes sent. In debug mode, a socket that is not
        non-blocking is rejected with :class:`ValueError`.
        """
        if self._debug and sock.gettimeout() != 0:
            raise ValueError("the socket must be non-blocking")
        # Materialize both iterables. The send may be retried after EAGAIN,
        # and a one-shot iterator would be empty on the second attempt.
        data = list(data)
        ancdata = list(ancdata)
        future = self.create_future()
        self._sock_sendmsg(future, None, sock, data, ancdata)
        return await future

    def _sock_recvmsg(self, future: Future, registered_fd: Optional[int],
                      sock: socket, bufsize: int, ancbufsize: int = 0,
                      flags: int = 0):
        # Used both as the first direct attempt (registered_fd is None) and
        # as the reader callback (registered_fd is the watched fd).
        if registered_fd is not None:
            self.remove_reader(registered_fd)
        if future.done():
            return
        try:
            debug('recvmsg data=%r ancdata=%r', bufsize, ancbufsize)
            result = sock.recvmsg(bufsize, ancbufsize, flags)
        except (BlockingIOError, InterruptedError):
            debug('wouldblock')
            fd = sock.fileno()
            self.add_reader(fd, self._sock_recvmsg, future, fd,
                            sock, bufsize, ancbufsize, flags)
            if registered_fd is None:
                # Unregister once the future is done for any reason (notably
                # cancellation). Otherwise a stale reader is left behind on an
                # fd number the OS may hand out again after the socket closes.
                future.add_done_callback(
                    lambda _fut, fd=fd: self.remove_reader(fd))
        except Exception as exc:
            future.set_exception(exc)
        else:
            future.set_result(result)

    def _sock_sendmsg(self, future: Future, registered_fd: Optional[int],
                      sock: socket, data: Iterable[bytes],
                      ancdata: Iterable[AncData]):
        # Used both as the first direct attempt (registered_fd is None) and
        # as the writer callback (registered_fd is the watched fd).
        if registered_fd is not None:
            self.remove_writer(registered_fd)
        if future.done():
            return
        try:
            debug('sendmsg data=%r ancdata=%r', data, ancdata)
            result = sock.sendmsg(data, ancdata)
        except (BlockingIOError, InterruptedError):
            debug('wouldblock')
            fd = sock.fileno()
            self.add_writer(fd, self._sock_sendmsg, future,
                            fd, sock, data, ancdata)
            if registered_fd is None:
                # Same cleanup as in _sock_recvmsg: never leave a writer
                # registered for a cancelled send.
                future.add_done_callback(
                    lambda _fut, fd=fd: self.remove_writer(fd))
        except Exception as exc:
            future.set_exception(exc)
        else:
            future.set_result(result)
def get_running_loop() -> Loop:
    """Return the currently running event loop, typed as :class:`Loop`.

    There is no runtime check. The running loop is assumed to be a
    :class:`Loop` (presumably the one created by ``run_forever``).
    """
    running = _get_running_loop()
    return running  # type: ignore
async def sendfd(loop: Loop, sock: socket,
                 msg: Iterable[bytes],
                 fds: Iterable[int]) -> int:
    """Send *msg* over *sock* with *fds* attached as ``SCM_RIGHTS`` data.

    :param loop: loop providing ``sock_sendmsg``.
    :param sock: non-blocking Unix socket.
    :param msg: message buffers, passed to a single ``sendmsg`` call.
    :param fds: file descriptors to pass. When empty, no ancillary data is
        attached.
    :returns: number of bytes sent, as reported by ``sendmsg``.
    """
    # Materialize first: a generator is always truthy, and the array
    # constructor would consume it anyway.
    fd_array = array('i', fds)
    ancdata = []
    if fd_array:
        ancdata.append((SOL_SOCKET, SCM_RIGHTS, fd_array.tobytes()))
    return await loop.sock_sendmsg(sock, msg, ancdata)
async def recvfd(loop: Loop, sock: socket, msglen: int,
                 maxfds: int = 3) -> Tuple[bytes, Sequence[int]]:
    """Receive one message plus any file descriptors passed via ``SCM_RIGHTS``.

    :param loop: loop providing ``sock_recvmsg``.
    :param sock: non-blocking Unix socket.
    :param msglen: receive buffer size for the message data.
    :param maxfds: number of descriptors to reserve ancillary space for.
        0 disables ancillary data.
    :returns: ``(data, fds)``. The received descriptors are new in this
        process, and the caller is responsible for closing them.
    """
    from socket import MSG_CTRUNC, MSG_TRUNC

    fds = array('i')
    ancbufsize = CMSG_LEN(maxfds * fds.itemsize) if maxfds else 0
    msg, ancdata, flags, addr = await loop.sock_recvmsg(
        sock, msglen, ancbufsize
    )
    debug('recvfd msg=%r ancdata=%r flags=%d addr=%s',
          msg, ancdata, flags, addr)
    if flags & MSG_TRUNC:
        _log.warning('recvfd: message truncated to %d bytes', msglen)
    if flags & MSG_CTRUNC:
        # Control data was discarded for lack of space, so fewer
        # descriptors than the sender attached will be returned.
        _log.warning('recvfd: ancillary data truncated (maxfds=%d)', maxfds)
    for level, msgtype, data in ancdata:
        if level == SOL_SOCKET and msgtype == SCM_RIGHTS:
            # Ignore a trailing partial int, as in the socket module docs.
            fds.frombytes(data[:len(data) - len(data) % fds.itemsize])
    return msg, list(fds)
def env_var_pair(value):
    """Parse a ``KEY=VAL`` command-line argument into a ``(KEY, VAL)`` tuple.

    Only the first ``=`` separates key from value, so the value itself may
    contain ``=``. The value may be empty (``KEY=``).

    :raises ArgumentTypeError: if *value* has no ``=`` or the key is empty.
    """
    name, sep, val = value.partition('=')
    if not sep or not name:
        raise ArgumentTypeError(
            '{} is not a valid env var pair'.format(value)
        )
    return (name, val)
def setup_exec_arguments(parser: ArgumentParser):
    """Register the options and positional arguments of the ``exec`` command."""
    specs = [
        (('-c', '--connect'),
         {'type': Path, 'metavar': 'PATH',
          'help': 'path to server socket'}),
        (('-e', '--env'),
         {'type': env_var_pair, 'metavar': 'KEY=VAL', 'action': 'append',
          'help': 'environment variable pairs'}),
        (('-w', '--workdir'),
         {'metavar': 'PATH', 'help': 'working directory'}),
        # '...' is argparse.REMAINDER: everything from the program name on is
        # collected verbatim, including arguments that look like options.
        (('program',),
         {'nargs': '...', 'help': 'program to execute'}),
    ]
    for flags, options in specs:
        parser.add_argument(*flags, **options)
def setup_serve_arguments(parser: ArgumentParser):
    """Register the ``serve`` command's single positional argument."""
    parser.add_argument(
        'path',
        help='path to unix socket',
        type=Path,
    )
def server_socket(path: Path) -> socket:
    """Create a non-blocking ``SOCK_SEQPACKET`` Unix socket bound to *path*.

    The socket is bound but not yet listening. If binding fails, the socket
    is closed. ``EADDRINUSE`` is re-raised as a fresh :class:`OSError` whose
    message names *path*.
    """
    listener = socket(family=AF_UNIX, type=SOCK_SEQPACKET)
    try:
        listener.bind(os.fspath(path))
    except OSError as exc:
        listener.close()
        if exc.errno != EADDRINUSE:
            raise
        raise OSError(
            EADDRINUSE, 'Address `{}` is already in use'.format(path)
        ) from None
    except Exception:
        listener.close()
        raise
    listener.setblocking(False)
    return listener
def client_socket(path: Path) -> socket:
    """Connect a new ``SOCK_SEQPACKET`` Unix socket to *path*.

    The connect itself runs in blocking mode. The socket is switched to
    non-blocking only afterwards. On any failure the socket is closed and
    the error re-raised.
    """
    conn = socket(family=AF_UNIX, type=SOCK_SEQPACKET)
    try:
        conn.connect(os.fspath(path))
        conn.setblocking(False)
    except Exception:
        conn.close()
        raise
    return conn
def gather_fds(fds: Sequence[int], items: Sequence[int]) -> Sequence[int]:
    """Map requested stream descriptors onto descriptors actually received.

    *items* has one entry per stdio stream, as sent by the client. A
    non-negative entry means the client attached a descriptor for that
    stream, and *fds* holds those descriptors in the same order. A negative
    entry is a ``subprocess`` sentinel (e.g. ``DEVNULL``) and is passed
    through unchanged, keeping its position.

    :returns: a list with the same length as *items*.
    :raises ValueError: if fewer descriptors were received than requested.
    """
    received = iter(fds)
    result = []
    for item in items:
        if item < 0:
            result.append(item)
            continue
        try:
            result.append(next(received))
        except StopIteration:
            raise ValueError(
                'fewer file descriptors received than requested'
            ) from None
    return result
async def handle_process(loop: Loop, sock: socket, process: Process) -> int:
    """Supervise *process* on behalf of the client connected on *sock*.

    Waits concurrently for the process to exit and for messages from the
    client:

    * a 4-byte message is decoded with ``btoi`` and delivered to the process
      as a signal;
    * an empty message (client disconnected) kills the process;
    * when the process exits, its return code is sent back with ``itob``.

    :returns: the process return code.
    :raises ValueError: on a client message that is neither empty nor 4 bytes.
    """
    pid = process.pid
    inp = loop.create_task(loop.sock_recvmsg(sock, 4096))
    finish = loop.create_task(process.wait())
    done: Set[Task]
    pending: Set[Task] = {inp, finish}
    try:
        while True:
            debug('waiting for process=%d to finish', pid)
            (done, pending) = await async_wait(pending,  # type: ignore
                                               return_when=FIRST_COMPLETED)
            debug('event occured')
            if finish in done:
                inp.cancel()
                retcode = finish.result()
                debug('process finished pid=%d ret=%d', pid, retcode)
                await loop.sock_sendmsg(sock, [itob(retcode)])
                return retcode
            if inp in done:
                debug('signal received')
                msg = inp.result()[0]
                if not msg:
                    debug('client disconnected')
                    try:
                        process.kill()
                    except ProcessLookupError:
                        # Already exited; ``finish`` completes on its own.
                        pass
                    return await finish
                if len(msg) != 4:
                    raise ValueError('wrong signal data: ' + repr(msg))
                sigval = btoi(msg)
                debug('process signal request pid=%d sig=%d', pid, sigval)
                try:
                    process.send_signal(sigval)
                except ProcessLookupError:
                    # The process exited meanwhile; ``finish`` will report it.
                    debug('process pid=%d already exited', pid)
                inp = loop.create_task(loop.sock_recvmsg(sock, 4096))
                pending.add(inp)
    finally:
        # Do not leave a receive pending on any exit path (incl. errors).
        inp.cancel()
async def handle_client(loop: Loop, sock: socket):
    """Serve one client connection: spawn the requested process and supervise it.

    The first message is a UTF-8 JSON request with the client's stdio
    descriptors attached via ``SCM_RIGHTS``. The request has a required
    ``argv`` entry and optional ``env``, ``cwd`` and ``io`` entries.
    The client receives a JSON reply with ``success`` and ``pid`` (plus
    ``message`` and ``errno`` on failure). On success the process is handed
    to ``handle_process``.

    ``env`` values are expanded with ``os.path.expandvars`` against the
    server's environment and overlaid on a copy of it.

    :returns: the process return code, or None if it could not be started.
    """
    msg, fds = await recvfd(loop, sock, DEFAULT_LIMIT, 3)
    try:
        request = json.loads(msg.decode('utf-8'))
        debug('request=%s fds=%s', request, fds)
        argv = request['argv']
        env = request.get('env', None)
        cwd = request.get('cwd', None)
        streams = request.get('io', {})
        debug('process argv=%s env=%s cwd=%s', argv, env, cwd)
        environ = None
        if env:
            environ = os.environ.copy()
            for name, value in env.items():
                environ[name] = os.path.expandvars(value)
        workdir = str(cwd) if cwd else None
        sin, sout, serr = gather_fds(fds, [streams.get(n, DEVNULL)
                                           for n in ('in', 'out', 'err')])
        process = await create_subprocess_exec(*argv,
                                               stdin=sin,
                                               stdout=sout,
                                               stderr=serr,
                                               env=environ,
                                               cwd=workdir)
    except Exception as err:
        # Covers malformed requests as well as spawn failures, so the client
        # always gets a reply instead of a silently dropped connection.
        debug('request failed: %r', err)
        result = {'success': False,
                  'message': str(err),
                  'errno': getattr(err, 'errno', None),
                  'pid': None}
        await loop.sock_sendmsg(
            sock, [json.dumps(result).encode('utf-8')]
        )
        return None
    finally:
        # The child (if any) holds its own copies. Close every descriptor
        # received, including any surplus ones the client attached.
        for fd in fds:
            os.close(fd)
    pid = process.pid
    debug('process pid=%d', pid)
    try:
        # Inside the try so that a failed reply does not orphan the process.
        await loop.sock_sendmsg(
            sock, [json.dumps({'success': True, 'pid': pid}).encode('utf-8')]
        )
        return await handle_process(loop, sock, process)
    except Exception:
        try:
            process.kill()
        except ProcessLookupError:
            pass  # already exited
        await process.wait()
        raise
def sock_close_cb(sock: socket):
    """Build a task done-callback that closes *sock*.

    :returns: a callable suitable for ``Task.add_done_callback``.
    """
    def close_after(task: Task):
        debug('closing socket sock=%r after task=%r', sock, task)
        sock.close()
    return close_after
async def accept(loop: Loop, server_sock: socket):
    """Listen on *server_sock* and serve each connection in its own task.

    Runs until cancelled. Each client socket is closed when its
    ``handle_client`` task finishes, whatever the outcome.
    """
    server_sock.listen()
    # The event loop keeps only weak references to tasks, so hold strong
    # ones until each client task is done.
    clients: Set[Task] = set()
    while True:
        client, addr = await loop.sock_accept(server_sock)
        debug('connected sock=%r addr=%r', client, addr)
        task = loop.create_task(handle_client(loop, client))
        task.add_done_callback(sock_close_cb(client))
        clients.add(task)
        task.add_done_callback(clients.discard)
def forwarded_signals() -> Iterable[int]:
    """Return the set of signals the client relays to the remote process.

    Names not defined on the current platform are skipped. SIGKILL and
    SIGSTOP are excluded because they cannot be caught. They are excluded
    by identity rather than by number because the numbers differ between
    platforms (SIGSTOP is 19 on Linux but 17 on macOS/BSD, where 19 is
    SIGCONT).
    """
    # NOTE(review): SIGSEGV/SIGBUS/SIGFPE/SIGILL are listed as well.
    # Catching them at Python level may misbehave if they are raised by a
    # genuine fault in this process; confirm this is intended.
    names = (
        'SIGABRT',
        'SIGALRM',
        'SIGBUS',
        'SIGCHLD',
        'SIGCLD',
        'SIGCONT',
        'SIGEMT',
        'SIGFPE',
        'SIGHUP',
        'SIGILL',
        'SIGINFO',
        'SIGINT',
        'SIGIO',
        'SIGIOT',
        'SIGKILL',
        'SIGLOST',
        'SIGPIPE',
        'SIGPOLL',
        'SIGPROF',
        'SIGPWR',
        'SIGQUIT',
        'SIGSEGV',
        'SIGSTKFLT',
        'SIGSTOP',
        'SIGTSTP',
        'SIGSYS',
        'SIGTERM',
        'SIGTRAP',
        'SIGTTIN',
        'SIGTTOU',
        'SIGUNUSED',
        'SIGURG',
        'SIGUSR1',
        'SIGUSR2',
        'SIGVTALRM',
        'SIGXCPU',
        'SIGXFSZ',
        'SIGWINCH',
    )
    uncatchable = {signal.SIGKILL, signal.SIGSTOP}
    candidates = (getattr(signal, name, None) for name in names)
    return frozenset(
        sig for sig in candidates
        if sig is not None and sig not in uncatchable
    )
async def send_signal(loop, sock, sig):
    """Send signal number *sig* over *sock* as one 4-byte message."""
    payload = itob(sig)
    await loop.sock_sendmsg(sock, [payload])
@contextmanager
def setup_signal_forwarding(loop: Loop, sock: socket):
    """Context manager that relays this process's signals over *sock*.

    While active, each signal from ``forwarded_signals()`` is caught by a
    loop signal handler and sent with ``send_signal`` instead of taking its
    default action. Installed handlers are removed on exit.
    """
    in_flight: Set[Task] = set()

    def forward_signal(sig):
        task = loop.create_task(send_signal(loop, sock, sig))
        # The loop keeps only weak references to tasks; hold one until done.
        in_flight.add(task)
        task.add_done_callback(in_flight.discard)

    installed = []
    try:
        for sig in forwarded_signals():
            loop.add_signal_handler(sig, forward_signal, sig)
            installed.append(sig)
        yield
    finally:
        # Remove only what was installed, so a failure part-way through
        # registration does not leave handlers behind.
        for sig in installed:
            loop.remove_signal_handler(sig)
async def connect(loop: Loop, client_sock: socket, argv: Sequence[str],
                  env: Mapping[str, str] = None, cwd: str = None) -> int:
    """Ask the server on *client_sock* to run *argv* with this process's stdio.

    Sends a JSON request with this process's stdin/stdout/stderr descriptors
    attached. It then waits, while forwarding signals, for the start reply
    and for the 4-byte return code.

    :returns: the remote process's return code, or 127 if the process could
        not be started or a server reply was missing or invalid.
    """
    streams = {
        'in': sys.stdin.fileno(),
        'out': sys.stdout.fileno(),
        'err': sys.stderr.fileno()
    }
    request = {
        'argv': argv,
        'env': env,
        'cwd': cwd,
        'io': streams,
    }
    debug('requesting process %s', request)
    # Descriptors are attached in the same order as the 'io' entries
    # (in, out, err), which is how the server matches them up.
    await sendfd(loop, client_sock,
                 [json.dumps(request).encode('utf-8')],
                 list(streams.values()))
    debug('waiting')
    with setup_signal_forwarding(loop, client_sock):
        msg = (await loop.sock_recvmsg(client_sock, 4096))[0]
        try:
            # An empty message (server closed the connection) fails here too.
            procstart = json.loads(msg.decode('utf-8'))
        except ValueError:
            procstart = None
        if not isinstance(procstart, dict):
            _log.error('invalid response=%r', msg)
            return 127
        debug('process response=%s', procstart)
        if not procstart.get('success'):
            _log.error('failed to start process: %s',
                       procstart.get('message'))
            return 127
        debug('waiting process to finish')
        msg = (await loop.sock_recvmsg(client_sock, 4096))[0]
        if len(msg) != 4:
            _log.error('invalid response=%r', msg)
            return 127
        retcode = btoi(msg)
        debug('return code %d', retcode)
        return retcode
async def create_server(path: Path, loop: Loop = None):
    """Serve requests on a Unix socket at *path* until cancelled.

    :param loop: loop to use; defaults to the running loop.

    The socket is closed and its file at *path* removed when the server
    stops.
    """
    sock = server_socket(path)
    try:
        if loop is None:
            loop = get_running_loop()
        await accept(loop, sock)
    finally:
        sock.close()
        try:
            os.unlink(path)
        except FileNotFoundError:
            # Already gone. Don't let cleanup mask the real exit reason.
            pass
async def create_client(path: Path, argv: Sequence[str],
                        env: Mapping[str, str] = None,
                        cwd: str = None,
                        loop: Loop = None) -> int:
    """Connect to the server at *path* and run *argv* there.

    :param loop: loop to use; defaults to the running loop.
    :returns: the value returned by ``connect`` (the remote return code).
    """
    conn = client_socket(path)
    try:
        active = get_running_loop() if loop is None else loop
        return await connect(active, conn, argv, env, cwd)
    finally:
        conn.close()
def serve(argv) -> int:
    """Run the ``serve`` command on the socket path in ``argv.path``.

    :returns: 0 if the server coroutine ever completes normally.
    """
    server = create_server(argv.path)
    run_forever(server)
    return 0
def execute(argv) -> int:
    """Run the ``exec`` command.

    :returns: 0 when no program was given, and 2 when ``--connect`` is
        missing. Otherwise, the remote program's exit status. A negative
        return code (killed by signal N) is mapped to the shell convention
        ``128 + N``.
    """
    if not argv.program:
        return 0
    if argv.connect is None:
        # Fail with a clear message instead of a TypeError from os.fspath.
        _log.error('exec: --connect PATH is required')
        return 2
    retcode = run_forever(create_client(
        argv.connect, argv.program,
        None if argv.env is None else dict(argv.env),
        argv.workdir
    ))
    # asyncio reports death-by-signal N as -N. sys.exit(-N) would yield the
    # meaningless status 256 - N, so use the shell's 128 + N instead.
    return 128 - retcode if retcode < 0 else retcode
def run_forever(main):
    """Run coroutine *main* to completion on a fresh :class:`Loop`.

    Debug mode is always enabled. Afterwards, remaining tasks are cancelled,
    async generators are shut down, and the loop is unset and closed.

    :returns: the result of *main*.
    """
    loop = Loop()

    def teardown():
        try:
            cancel_all_tasks(loop)
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            set_event_loop(None)
            loop.close()

    try:
        set_event_loop(loop)
        loop.set_debug(True)
        outcome = loop.run_until_complete(main)
    finally:
        teardown()
    return outcome
def main(argv: Sequence[str] = None):
    """Command-line entry point.

    :param argv: full argument vector *including* the program name, as in
        ``sys.argv`` (which is used when *argv* is None).
    :returns: exit status of the selected sub-command.
    """
    parser = ArgumentParser(NAME)
    parser.add_argument('-v', '--verbose', action='count',
                        default=0,
                        help='enable debug messages')
    # A named dest gives argparse something to show in its "required" error.
    commands = parser.add_subparsers(dest='subcommand', help='commands')
    # Without this, running with no sub-command leaves ``params.command``
    # unset and fails with AttributeError instead of a usage error.
    commands.required = True
    cmd_server = commands.add_parser('serve', help='launch server')
    cmd_server.set_defaults(command=serve)
    cmd_executor = commands.add_parser('exec', help='execute command')
    cmd_executor.set_defaults(command=execute)
    setup_serve_arguments(cmd_server)
    setup_exec_arguments(cmd_executor)
    full_argv = sys.argv if argv is None else argv
    # Skip argv[0], the program name.
    params = parser.parse_args(list(full_argv)[1:])
    level = logging.DEBUG if params.verbose > 0 else logging.WARN
    logging.basicConfig(level=level, stream=sys.stderr)
    debug('args=%s', params)
    return params.command(params)
if __name__ == '__main__':
    # The process exit status is whatever the selected command returns.
    raise SystemExit(main(sys.argv))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment