Skip to content

Instantly share code, notes, and snippets.

@ephemient
Forked from dirkakrid/nspawn-enter
Last active January 22, 2019 03:23
Show Gist options
  • Save ephemient/5351b1afa681ca67823fe2e11190e721 to your computer and use it in GitHub Desktop.
Save ephemient/5351b1afa681ca67823fe2e11190e721 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import contextlib
import ctypes
import dbus
import errno
import fcntl
import io
import os
import pickle
import pty
import select
import signal
import socket
import struct
import sys
import termios
import time
import tty
class Process(object):
    """A running process whose namespaces can be opened and joined.

    :param libc: ctypes handle to the C library; forwarded unchanged to the
        Namespace objects this class creates.
    :param pid: id of the process, used to locate ``/proc/<pid>/ns``.
    """

    def __init__(self, libc, pid):
        self._pid = pid
        self._libc = libc

    def namespaces(self, names=None, restore=True):
        """Open descriptors for some or all of this process's namespaces.

        :param names: iterable of entry names under ``/proc/<pid>/ns`` to
            open; ``None`` opens every entry in that directory.
        :param restore: forwarded to Namespace (whether leaving it returns
            the caller to its current namespaces).
        :returns: a Namespace that owns, and will close, the opened fds.
        """
        if names is None:
            names = os.listdir('/proc/{0}/ns/'.format(self._pid))
        descriptors = dict.fromkeys(names)
        with contextlib.ExitStack() as closer:
            for entry in names:
                descriptor = os.open(
                    '/proc/{0}/ns/{1}'.format(self._pid, entry),
                    os.O_RDONLY | os.O_CLOEXEC)
                # Registered immediately so a later failed open still
                # closes everything opened so far.
                closer.callback(os.close, descriptor)
                descriptors[entry] = descriptor
            # Success: hand ownership of the fds to the Namespace.
            owned = closer.pop_all()
        return Namespace(self._libc, descriptors, owned, restore=restore)
class Namespace(contextlib.ContextDecorator):
    """Context manager that moves the calling process into a set of namespaces.

    Entering calls setns(2) on every descriptor in ``fds``.  When ``restore``
    is true, descriptors for the caller's current namespaces are opened first
    so that exiting moves the process back.  The descriptors in ``fds`` are
    closed (via ``context``) as soon as ``__enter__`` finishes, so an instance
    is meant to be entered only once.

    :param libc: ctypes handle to the C library exposing ``setns``.
    :param fds: mapping of namespace name (e.g. ``'mnt'``) to an open fd.
    :param context: ExitStack that closes the descriptors in ``fds``.
    :param restore: whether ``__exit__`` returns to the original namespaces.
    """

    def __init__(self, libc, fds, context, restore=True):
        self._libc = libc
        self._fds = fds
        self._context = context
        self._restore = restore
        # Namespace for the caller's original namespaces (only if restoring).
        self._orig = None
        # True for the internal "restore" instance: it skips the immediate
        # chroot because its context chroots and fchdirs when it closes.
        self._nested = False

    def _enterns(self):
        """setns(2) into every namespace in ``self._fds``, in mapping order.

        :raises OSError: with the failing namespace name as ``filename``.
        """
        for name, fd in self._fds.items():
            if self._libc.setns(fd, 0) != 0:
                rc = ctypes.get_errno()
                # os.strerror works for any errno value (errno.errorcode
                # would raise KeyError for 0 or unknown codes).
                raise OSError(rc, os.strerror(rc), name)

    def __enter__(self):
        try:
            pwd = os.getcwdb()
            if self._restore:
                self._orig = Process(self._libc, os.getpid()).namespaces(
                    self._fds.keys(), restore=False)
                self._orig._nested = True
                if 'mnt' in self._fds:
                    # When the restore instance's context closes (LIFO):
                    # close its ns fds, chroot('/') in the original mount
                    # namespace, fchdir back to the saved cwd, close that fd.
                    stack = contextlib.ExitStack()
                    cwd = os.open('.', os.O_CLOEXEC | os.O_DIRECTORY)
                    stack.callback(os.close, cwd)
                    stack.callback(os.fchdir, cwd)
                    stack.callback(os.chroot, '/')
                    stack.push(self._orig._context)
                    self._orig._context = stack
            self._enterns()
            if not self._nested and 'mnt' in self._fds:
                # presumably so the root directory is the joined mount
                # namespace's root — TODO confirm against setns(2)
                os.chroot('/')
            # Keep the same working directory path if it exists here.
            try:
                os.chdir(pwd)
            except OSError:
                os.chdir('/')
            return self
        except BaseException:
            # Undo any partially-entered namespaces, then propagate.
            if not self.__exit__(*sys.exc_info()):
                raise
        finally:
            self._context.close()

    def __exit__(self, *exc):
        """Return to the original namespaces if ``restore`` was requested."""
        if self._orig:
            with self._orig:
                pass
class _Looper(object):
def __init__(self, master, fd0, fd1, pid):
self._master, self._fd0, self._fd1, self._pid = master, fd0, fd1, pid
self._escapes, self._escape_time = 3 if os.isatty(fd0) else 0, 1
def loop(self):
rlist = [self._master]
if self._fd0 is not None:
rlist.append(self._fd0)
escapes, escape_time = 0, 0
with contextlib.suppress(OSError):
while True:
rfds, wfds, xfds = select.select(rlist, (), ())
if self._master in rfds:
data = os.read(self._master, io.DEFAULT_BUFFER_SIZE)
while data:
count = os.write(self._fd1, data)
if count:
data = data[count:]
if self._fd0 in rfds:
data = os.read(self._fd0, io.DEFAULT_BUFFER_SIZE)
if data:
while data:
count = os.write(self._master, data)
if self._escapes:
new_escapes = data[:count].count(
b'\x1d') # <ESC>]
if new_escapes:
now = time.monotonic()
if escape_time + self._escape_time < now:
escape_time = now
escapes = new_escapes
else:
escapes += new_escapes
if escapes >= self._escapes:
raise OSError()
if count:
data = data[count:]
else:
rlist.remove(self._fd0)
def wait(self):
(pid, status) = os.waitpid(self._pid, os.WNOHANG)
if pid and os.WIFEXITED(status):
return os.WEXITSTATUS(status)
if pid and os.WIFSIGNALED(status):
return 128 | os.WTERMSIG(status)
return status
class Pty(object):
    """Run a callable in a child process attached to a new pseudo-terminal.

    Entering forks a child with :func:`pty.fork`.  The parent puts the
    user's terminal into raw mode and mirrors window-size changes onto the
    pty.  The child then runs ``callback(*args, **kwargs)``, e.g. an
    ``os.exec*`` call.  If the callback raises in the child, the exception is
    sent back over a socketpair and re-raised from ``__enter__``.  On success
    ``__enter__`` returns a :class:`_Looper`.  Exiting restores the terminal
    and the SIGWINCH handler, closes opened fds and the pty master.
    """

    def __init__(self, callback, *args, **kwargs):
        assert callable(callback)
        self._callback, self._args, self._kwargs = callback, args, kwargs
        self._context = None

    def _fds(self, stdin=None, stdout=None):
        """Pick the input/output fds to relay, preferring terminals.

        Falls back to ``/dev/tty`` when stdin/stdout is not a terminal.
        ``stdin``/``stdout`` default to the current ``sys.stdin`` and
        ``sys.stdout``, resolved at call time.

        :returns: ``(fd0, fd1, stack)``; ``stack`` closes any fd opened here.
        """
        if stdin is None:
            stdin = sys.stdin
        if stdout is None:
            stdout = sys.stdout
        with contextlib.ExitStack() as stack:
            fd0 = stdin.fileno()
            if not os.isatty(fd0):
                try:
                    fd0 = os.open('/dev/tty', os.O_RDWR)
                except OSError:
                    pass
                else:
                    stack.callback(os.close, fd0)
            fd1 = stdout.fileno()
            if os.isatty(fd0) and not os.isatty(fd1):
                fd1 = fd0
            elif not os.isatty(fd1):
                try:
                    fd1 = os.open('/dev/tty', os.O_RDWR)
                except OSError:
                    pass
                else:
                    stack.callback(os.close, fd1)
            return fd0, fd1, stack.pop_all()

    def _run_child(self, sock):
        """Child side of ``__enter__``; never returns.

        Waits until the parent has finished terminal setup (it shuts down
        its end of the socketpair, making ``sock`` readable), then runs the
        callback.  Any exception is pickled to ``sock`` for the parent to
        re-raise, and the child exits with status 255 (``os._exit(-1)``).
        """
        try:
            select.select((sock, ), (), ())
            self._callback(*self._args, **self._kwargs)
        except BaseException:
            try:
                exc_type, exc_value = sys.exc_info()[:2]
                try:
                    # The instance keeps errno/message/filename.
                    payload = pickle.dumps(exc_value)
                except Exception:
                    # Not every exception instance pickles; send the type.
                    payload = pickle.dumps(exc_type)
                with sock.makefile('wb') as f:
                    f.write(payload)
            finally:
                os._exit(-1)
        os._exit(0)

    def __enter__(self):
        assert not self._context
        (s1, s2) = socket.socketpair()
        with contextlib.ExitStack() as stack:
            stack.callback(s1.close)
            with contextlib.ExitStack() as stack2:
                stack2.callback(s2.close)
                # Non-inheritable: an exec in the child closes s2, which the
                # parent then observes as EOF (= callback did not raise).
                os.set_inheritable(s1.fileno(), False)
                os.set_inheritable(s2.fileno(), False)
                (pid, master) = pty.fork()
                if pid == 0:
                    stack.close()
                    self._run_child(s2)
            with contextlib.ExitStack() as stack2:
                # Registered first so it runs last, after the SIGWINCH
                # handler that uses it has been restored.
                stack2.callback(os.close, master)
                # Kill the child if setup fails; disarmed on success below.
                die = contextlib.ExitStack()
                stack2.push(die)
                die.callback(os.kill, pid, signal.SIGKILL)
                fd0, fd1, stack3 = self._fds()
                stack2.push(stack3)
                if os.isatty(fd0):
                    saved_attrs = termios.tcgetattr(fd0)
                    stack2.callback(termios.tcsetattr, fd0, termios.TCSANOW,
                                    saved_attrs)
                    tty.setraw(fd0, when=termios.TCSANOW)

                def winch(signo, frame):
                    # Copy the user's window size onto the pty.
                    size = fcntl.ioctl(fd1, termios.TIOCGWINSZ,
                                       struct.pack('HHHH', 0, 0, 0, 0))
                    fcntl.ioctl(master, termios.TIOCSWINSZ, size)

                if os.isatty(fd1):
                    stack2.callback(signal.signal, signal.SIGWINCH,
                                    signal.getsignal(signal.SIGWINCH))
                    signal.signal(signal.SIGWINCH, winch)
                    winch(0, None)
                # Release the child, then wait for either EOF (successful
                # exec or normal return) or a pickled exception.  The data
                # comes only from our own forked child before exec.
                s1.shutdown(socket.SHUT_WR)
                with s1.makefile('rb') as f:
                    try:
                        ex = pickle.load(f)
                    except EOFError:
                        ex = None
                if ex is not None:
                    raise ex
                looper = _Looper(master, fd0, fd1, pid)
                self._context = stack2.pop_all()
                die.pop_all()
                return looper

    def __exit__(self, *exc):
        self._context.close()
        self._context = None
def get_pid(cgroup, root='/sys/fs/cgroup'):
    """Return the first process id listed in a cgroup's ``tasks`` file.

    :param cgroup: cgroup path relative to ``root``
        (e.g. ``'systemd/lxc/default'``).
    :param root: mount point of the cgroup hierarchy.
    :raises OSError: if the tasks file cannot be read.
    :raises ValueError: if the cgroup lists no processes.
    """
    path = '{0}/{1}/tasks'.format(root, cgroup)
    with open(path) as tasks:
        first = tasks.readline().strip()
    if not first:
        raise ValueError(
            'no processes in cgroup {0!r} ({1})'.format(cgroup, path))
    return int(first)
def get_env(pid):
    """Return the environment of process ``pid`` as a dict of bytes.

    Parses ``/proc/<pid>/environ``, a NUL-separated list of ``KEY=VALUE``
    entries.  Keys and values are ``bytes``; an entry without ``=`` maps to
    ``b''``; later duplicates win.

    :returns: the parsed mapping, or ``{}`` if the file cannot be read.
    """
    try:
        with open('/proc/{0}/environ'.format(pid), 'rb') as f:
            raw = f.read()
    except OSError:
        return {}
    env = {}
    for entry in raw.split(b'\0'):
        if not entry:
            continue  # the trailing NUL terminator yields an empty entry
        (key, _, value) = entry.partition(b'=')
        env[key] = value
    return env
def _run_forked(command, argv, env):
    """Run ``command`` in a forked child and return its exit status.

    setns(2) on a PID namespace only affects the caller's future children,
    so exec'ing without a fork would leave the program outside the target's
    PID namespace.  SIGINT/SIGQUIT are ignored in the parent while waiting,
    so a terminal interrupt only affects the child.

    :returns: the exit code, ``128 | signal`` if killed, or 127 if the
        command could not be executed.
    """
    pid = os.fork()
    if pid == 0:
        try:
            os.execvpe(command, argv, env)
        except OSError as e:
            os.write(2, '{0}: {1}\n'.format(command, e.strerror).encode(
                'utf-8', 'replace'))
        finally:
            os._exit(127)
    saved = [(signo, signal.signal(signo, signal.SIG_IGN))
             for signo in (signal.SIGINT, signal.SIGQUIT)]
    try:
        (_, status) = os.waitpid(pid, 0)
    finally:
        for signo, handler in saved:
            if handler is not None:
                signal.signal(signo, handler)
    if os.WIFEXITED(status):
        return os.WEXITSTATUS(status)
    if os.WIFSIGNALED(status):
        return 128 | os.WTERMSIG(status)
    return status


def main():
    """Parse the command line and run a command inside the container.

    Joins the namespaces of the container's process (``--pid`` or the first
    task of ``--cgroup``) and runs the command either on a new pty or
    directly on the current stdio.

    :returns: the command's exit status.
    """
    parser = argparse.ArgumentParser(
        description='Join a running Anbox instance')
    parser.add_argument(
        '--tty',
        '-t',
        action='store_true',
        help='force pseudo-terminal allocation')
    parser.add_argument(
        '--notty',
        '-T',
        action='store_false',
        help='disable pseudo-terminal allocation')
    parser.add_argument(
        '--ns', '-n', nargs='*', help='namespaces to join (default all)')
    parser.add_argument(
        '--cgroup', help='container path', default='systemd/lxc/default')
    parser.add_argument(
        '--pid', '-p', help='process id in container', type=int)
    parser.add_argument(
        '--preserve-env',
        '-E',
        help='use environment',
        action='store_true',
        dest='env')
    parser.add_argument(
        '--no-preserve-env',
        '-e',
        help='do not use environment',
        action='store_false',
        dest='env')
    parser.add_argument('command', nargs='?', default='/system/bin/sh')
    parser.add_argument('args', nargs='*')
    args = parser.parse_args()
    cmd = [args.command] + args.args
    # -t forces a pty; otherwise use one only if stdin is a terminal and
    # -T was not given (args.notty is False when -T is passed).
    do_tty = args.tty or (args.notty and os.isatty(sys.stdin.fileno()))
    libc = ctypes.CDLL('libc.so.6', use_errno=True)
    pid = args.pid or get_pid(args.cgroup)
    env = os.environ if args.env else get_env(pid)
    with Process(libc, pid).namespaces(names=args.ns or None):
        if do_tty:
            with Pty(os.execvpe, args.command, cmd, env) as looper:
                looper.loop()
                return looper.wait()
        else:
            # Fork like the pty path does, so a joined PID namespace applies.
            return _run_forked(args.command, cmd, env)
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment