Skip to content

Instantly share code, notes, and snippets.

@mentha
Last active November 27, 2022 11:12
Show Gist options
  • Save mentha/69bef8d1f53e627b590c8a80c05e285d to your computer and use it in GitHub Desktop.
Save mentha/69bef8d1f53e627b590c8a80c05e285d to your computer and use it in GitHub Desktop.
proxy protocol v2 universal adapter
#!/usr/bin/env python3
from argparse import ArgumentParser
from collections import namedtuple
from contextlib import AsyncExitStack
from ipaddress import IPv4Address, IPv6Address, IPv4Network, IPv6Network
from signal import signal, SIGHUP, SIGINT, SIGTERM
from socket import socket, AF_INET, AF_INET6, SOCK_STREAM, IPPROTO_IP, IP_TRANSPARENT, SHUT_WR
import asyncio as aio
import logging
import struct
import subprocess as sp
logger = logging.getLogger(__name__)


class Ppua:
    """Proxy Protocol v2 universal adapter.

    Accepts connections on unix sockets, reads a PROXY protocol v2 header
    from each one, and relays the stream to a TCP destination through a
    socket that has ``IP_TRANSPARENT`` set and is bound to the client
    address carried in the header.
    """

    # Defaults for --auto-route-table-id / --auto-route-ip4.
    TABLE_ID = 17241
    LOOPBACK_IP = '169.254.118.65'

    # Fixed 12-byte signature that starts every PROXY protocol v2 header.
    _PP2_SIGNATURE = b'\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a'

    def __init__(self):
        self.opt = None      # parsed command line, set by parse_args()
        self.servers = None  # unix socket servers, set by create_servers()

    class Forwarding(namedtuple('Forwarding', ['local', 'host', 'port'])):
        """One ``unix-sock-path:host:port`` forwarding from the command line.

        The path ends at the first ``:`` and the port starts after the last
        one, so a bare IPv6 literal such as ``::1`` can be used as the host.
        The port is kept as a string.
        """

        def __new__(cls, s):
            local, rest = s.split(':', 1)
            host, port = rest.rsplit(':', 1)
            return super().__new__(cls, local, host, port)

        @property
        def dest(self):
            """Destination as a ``(host, port)`` tuple."""
            return self.host, self.port

    def parse_args(self):
        """Parse the command line into ``self.opt``."""
        a = ArgumentParser(description='Proxy Protocol v2 universal adapter')
        a.add_argument('--verbose', '-v', action='count', default=0, help='increase verbosity')
        a.add_argument('--forward', '-L', dest='forward', metavar='unix-sock-path:host:port', action='append',
                       type=self.Forwarding, default=[],
                       help='add forwarding, use "localhost" as host to forward to local ports bound to all addresses')
        a.add_argument('--allow-local', action='store_true', help='allow proxy from local ip')
        a.add_argument('--allow-private', action='store_true', help='allow proxy from non globally routable ip')
        a.add_argument('--auto-route', '-a', action='store_true', help='set up policy routing automatically')
        # Defaults come from the class constants so the two cannot drift apart.
        # The table id stays a string because it is passed straight to `ip`.
        a.add_argument('--auto-route-table-id', action='store', default=str(self.TABLE_ID), help='routing table id')
        a.add_argument('--auto-route-ip4', action='store', default=self.LOOPBACK_IP, help='special loopback ip')
        a.add_argument('--auto-route-ip6', action='store', default='fd35:ebbf:d8a6:6768:8c6c:ad1b:05d9:32b2', help='special loopback ip')
        self.opt = a.parse_args()

    def run_checked(self, *cmd):
        """Run *cmd*; raise ``CalledProcessError`` if it exits non-zero."""
        logger.debug('running %r', cmd)
        sp.run(cmd, stdin=sp.DEVNULL, check=True)

    def run_nocheck(self, *cmd):
        """Run *cmd* with all output discarded and return its exit status."""
        logger.debug('running %r', cmd)
        return sp.run(cmd, stdin=sp.DEVNULL, stdout=sp.DEVNULL, stderr=sp.DEVNULL, check=False).returncode

    def route_setup(self):
        """Install the special loopback addresses, routes and policy rules.

        Traffic sourced from the special loopback addresses is routed via
        ``lo`` through the dedicated routing table.
        """
        self.route_cleanup()
        self.run_checked('ip', '-4', 'address', 'add', f'{self.opt.auto_route_ip4}/32', 'dev', 'lo')
        self.run_checked('ip', '-6', 'address', 'add', f'{self.opt.auto_route_ip6}/128', 'dev', 'lo')
        self.run_checked('ip', '-4', 'route', 'add', '0.0.0.0/0', 'table', self.opt.auto_route_table_id, 'dev', 'lo')
        self.run_checked('ip', '-6', 'route', 'add', '::/0', 'table', self.opt.auto_route_table_id, 'dev', 'lo')
        self.run_checked('ip', '-4', 'rule', 'add', 'from', self.opt.auto_route_ip4, 'lookup', self.opt.auto_route_table_id)
        self.run_checked('ip', '-6', 'rule', 'add', 'from', self.opt.auto_route_ip6, 'lookup', self.opt.auto_route_table_id)

    def route_cleanup(self):
        """Remove everything ``route_setup`` adds, ignoring failures.

        Address and rule deletions are repeated until ``ip`` fails, so that
        duplicates left behind by earlier runs are removed too.
        """
        while self.run_nocheck('ip', '-4', 'address', 'del', f'{self.opt.auto_route_ip4}/32', 'dev', 'lo') == 0:
            pass
        while self.run_nocheck('ip', '-6', 'address', 'del', f'{self.opt.auto_route_ip6}/128', 'dev', 'lo') == 0:
            pass
        self.run_nocheck('ip', '-4', 'route', 'flush', 'table', self.opt.auto_route_table_id)
        self.run_nocheck('ip', '-6', 'route', 'flush', 'table', self.opt.auto_route_table_id)
        while self.run_nocheck('ip', '-4', 'rule', 'del', 'from', self.opt.auto_route_ip4, 'lookup', self.opt.auto_route_table_id) == 0:
            pass
        while self.run_nocheck('ip', '-6', 'rule', 'del', 'from', self.opt.auto_route_ip6, 'lookup', self.opt.auto_route_table_id) == 0:
            pass

    def signal_setup(self):
        """Make SIGHUP, SIGINT and SIGTERM raise ``SystemExit(0)``.

        Raising instead of dying outright lets ``finally`` blocks, such as
        the routing cleanup in ``main``, run on the way out.
        """
        def h(*a):
            raise SystemExit(0)
        for s in (SIGHUP, SIGINT, SIGTERM):
            signal(s, h)

    async def main(self):
        """Parse arguments, start the servers and serve until stopped.

        Policy routing, when enabled, is torn down again on exit.
        """
        self.parse_args()
        v = min(self.opt.verbose, 2)
        logging.basicConfig(level=(
            logging.WARN,
            logging.INFO,
            logging.DEBUG,
        )[v])
        self.signal_setup()
        await self.create_servers()
        try:
            if self.opt.auto_route:
                self.route_setup()
            await self.serve_forever()
        finally:
            if self.opt.auto_route:
                self.route_cleanup()

    def run(self):
        """Synchronous entry point: run ``main`` on a fresh event loop."""
        aio.run(self.main())

    async def create_servers(self):
        """Start one unix socket server per ``--forward`` option.

        Raises:
            RuntimeError: if no forwarding was given.
        """
        if not self.opt.forward:
            raise RuntimeError('no forwardings specified')
        pending = []
        for f in self.opt.forward:
            # Bind f as a default so each handler keeps its own forwarding.
            async def handle(r, w, f=f):
                await self.handle_proxy(r, w, forwarding=f)
            pending.append(aio.start_unix_server(handle, f.local))
        self.servers = await aio.gather(*pending)

    async def serve_forever(self):
        """Serve on every server until cancelled; close them all on exit."""
        forever = []
        async with AsyncExitStack() as es:
            for s in self.servers:
                await es.enter_async_context(s)
                forever.append(s.serve_forever())
            await aio.gather(*forever)

    async def open_connection_as(self, host, port, /, family, local_addr, *a, **ka):
        """Open a TCP stream to ``(host, port)`` whose source is *local_addr*.

        ``IP_TRANSPARENT`` allows binding to an address that is not local to
        this host. Extra arguments are passed on to ``asyncio.open_connection``.
        The socket is closed if any step fails.

        Returns:
            The ``(StreamReader, StreamWriter)`` pair.
        """
        loop = aio.get_running_loop()
        s = socket(family, SOCK_STREAM)
        try:
            s.setsockopt(IPPROTO_IP, IP_TRANSPARENT, 1)
            s.bind(local_addr)
            # loop.sock_connect needs a non-blocking socket; a blocking one
            # would stall the whole event loop until the connect finished.
            s.setblocking(False)
            await loop.sock_connect(s, (host, port))
            ka['sock'] = s
            streams = await aio.open_connection(*a, **ka)
        except BaseException:
            s.close()
            raise
        return streams

    private4 = (
        IPv4Network('0.0.0.0/8'),
        IPv4Network('10.0.0.0/8'),
        IPv4Network('100.64.0.0/10'),
        IPv4Network('127.0.0.0/8'),
        IPv4Network('169.254.0.0/16'),
        IPv4Network('172.16.0.0/12'),
        IPv4Network('192.0.0.0/24'),
        IPv4Network('192.0.2.0/24'),
        IPv4Network('192.88.99.0/24'),
        IPv4Network('192.168.0.0/16'),
        IPv4Network('198.18.0.0/15'),
        IPv4Network('198.51.100.0/24'),
        IPv4Network('203.0.113.0/24'),
        IPv4Network('224.0.0.0/3'),
    )
    global6 = IPv6Network('2000::/3')
    local4 = IPv4Network('127.0.0.0/8')
    local6 = IPv6Network('::1/128')

    def check_acl(self, family, srcip):
        """Return whether a connection from *srcip* may be proxied.

        Without --allow-private, IPv4 sources in ``private4`` and IPv6
        sources outside 2000::/3 are rejected. Without --allow-local,
        loopback sources are rejected. ``private4`` also contains
        127.0.0.0/8, so IPv4 loopback needs both flags.
        """
        if not self.opt.allow_private:
            if family == AF_INET and any(srcip in n for n in self.private4):
                return False
            if family == AF_INET6 and srcip not in self.global6:
                return False
        if not self.opt.allow_local:
            if ((family == AF_INET and srcip in self.local4) or
                    (family == AF_INET6 and srcip in self.local6)):
                return False
        return True

    async def _read_proxy_header(self, reader):
        """Read a PROXY v2 header and return ``(family, srcip, (host, port))``.

        Only version 2 with the PROXY command, over TCP/IPv4 or TCP/IPv6,
        is accepted. Any TLVs after the address block are read and ignored.
        """
        if (await reader.readexactly(12) != self._PP2_SIGNATURE or
                await reader.readexactly(1) != b'\x21'):
            raise RuntimeError('protocol error')
        proto = await reader.readexactly(1)
        if proto == b'\x11':
            family, addr_cls, addr_len = AF_INET, IPv4Address, 4
        elif proto == b'\x21':
            family, addr_cls, addr_len = AF_INET6, IPv6Address, 16
        else:
            raise RuntimeError('unsupported protocol')
        size = struct.unpack('>H', await reader.readexactly(2))[0]
        rest = await reader.readexactly(size)
        # Address block layout: src addr, dst addr, src port, dst port.
        if size < 2 * addr_len + 4:
            raise RuntimeError('truncated address block')
        srcip = addr_cls(rest[:addr_len])
        port_off = 2 * addr_len
        src = (srcip.compressed, struct.unpack('>H', rest[port_off:port_off + 2])[0])
        return family, srcip, src

    def _target_for(self, forwarding, family):
        """Return the ``(host, port)`` to connect to for *forwarding*.

        With --auto-route, ``localhost`` is replaced by the special loopback
        address of the client's address family.
        """
        fhost, fport = forwarding.dest
        if self.opt.auto_route and fhost == 'localhost':
            fhost = self.opt.auto_route_ip4 if family == AF_INET else self.opt.auto_route_ip6
        return fhost, fport

    async def _relay(self, reader, writer, cr, cw):
        """Copy data in both directions until both reach EOF.

        If one direction fails, the other is cancelled and the error is
        re-raised.
        """
        tasks = [
            aio.ensure_future(self.proxy_single(cr, writer)),
            aio.ensure_future(self.proxy_single(reader, cw)),
        ]
        try:
            await aio.gather(*tasks)
        finally:
            # gather() leaves the surviving task running when one fails.
            for t in tasks:
                t.cancel()
            await aio.gather(*tasks, return_exceptions=True)

    @staticmethod
    async def _close_quietly(w):
        """Close stream writer *w*; errors from a dead connection are only logged."""
        w.close()
        try:
            await w.wait_closed()
        except OSError as e:
            logger.debug('error while closing stream: %s', e)

    async def handle_proxy(self, reader, writer, forwarding):
        """Serve one accepted unix socket connection.

        Reads the PROXY v2 header, applies ``check_acl``, opens the
        transparent upstream connection and relays data both ways. Setup
        failures are logged and the connection is dropped. Both streams are
        closed on exit.
        """
        cw = None
        src = None
        established = False
        try:
            try:
                family, srcip, src = await self._read_proxy_header(reader)
                if not self.check_acl(family, srcip):
                    raise RuntimeError('rejected')
                fhost, fport = self._target_for(forwarding, family)
                logger.debug('proxy from %r to %r (%r)', src, forwarding.dest, (fhost, fport))
                cr, cw = await self.open_connection_as(fhost, fport, family=family, local_addr=src)
            except Exception as e:
                logger.error('as %s to %s error: %s', src, forwarding.dest, e)
                return
            logger.info('as %s to %s established', src, forwarding.dest)
            established = True
            try:
                await self._relay(reader, writer, cr, cw)
            except OSError as e:
                # Resets and broken pipes are routine for proxied traffic.
                logger.info('as %s to %s aborted: %s', src, forwarding.dest, e)
        finally:
            await self._close_quietly(writer)
            if cw is not None:
                await self._close_quietly(cw)
            if established:
                logger.info('as %s to %s closed', src, forwarding.dest)

    async def proxy_single(self, r, w):
        """Copy everything from reader *r* to writer *w*, then half-close *w*.

        ``write_eof`` defers the shutdown until the transport's write buffer
        has been flushed. A raw ``shutdown(SHUT_WR)`` could cut off data that
        ``drain()`` (which only waits for the high-water mark) left buffered.
        """
        while True:
            b = await r.read(0x10000)
            if not b:
                break
            w.write(b)
            await w.drain()
        w.write_eof()
# Script entry point: run the adapter with the command-line arguments.
if __name__ == '__main__':
    Ppua().run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment