caching http proxy
#!/usr/bin/env python3
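"""Caching HTTP proxy: a thin wrapper that writes a throwaway squid
configuration, optionally creates a self-signed CA for TLS bumping, and runs
squid in the foreground (inside bwrap unless --no-sandbox is given).

A minimal sketch of an invocation; the script name, listen address and cache
directory are placeholders, the flags are the ones defined below:

    ./cacheproxy.py -l 127.0.0.1:3128 -d ~/.cache/cacheproxy
"""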
from argparse import Action, ArgumentParser
from contextlib import suppress
from fcntl import flock, LOCK_SH
from glob import glob
from signal import signal, SIGHUP, SIGINT, SIGTERM
from tempfile import TemporaryDirectory
import json
import os
import re
import shlex
import shutil
import subprocess as sp
import sys
class IncludeAction(Action):
    def __call__(self, parser, namespace, values, option_string):
        # self.dest doubles as an include-nesting counter; inside nested
        # includes the file arguments are treated as glob patterns
        # (include_hidden requires Python 3.11+)
        dest = getattr(namespace, self.dest) or 0
        if dest > 0:
            nv = []
            for v in values:
                nv.extend(glob(v, recursive=True, include_hidden=True))
            values = nv
        setattr(namespace, self.dest, dest + 1)
        try:
            for v in values:
                args = []
                with open(v) as f:
                    for l in f:
                        # comments=True allows '#' comments in arguments files
                        args.extend(shlex.split(l, comments=True))
                parser.parse_args(args, namespace)
        finally:
            setattr(namespace, self.dest, dest)
class CacheProxy:
    @classmethod
    def addrport(self, s):
        m = re.match(r'^(?:\[?(\d+\.\d+\.\d+\.\d+|[0-9a-fA-F:]+)\]?:)?(\d+)$', s)
        if m is None:
            raise TypeError()
        return (m.group(1) or '::', int(m.group(2)))

    @classmethod
    def httpserv(self, s):
        if s.startswith('http://'):
            s = s[7:]
        m = re.match(r'^([^\s/]+):(\d+)$', s)
        if m is None:
            raise TypeError()
        a, p = m.groups()
        if a.startswith('['):
            a = a[1:-1]
        return a, p

    @classmethod
    def squidpath(self, s):
        if re.match(r'^\S+$', s) is None:
            raise TypeError()
        if s.startswith('/'):
            return s
        return os.path.abspath(s)

    @classmethod
    def cmdline(self, s):
        l = shlex.split(s)
        for f in l:
            if re.match(r'''^[^\s'"]+$''', f) is None:
                raise TypeError()
        return ' '.join(l)
    @classmethod
    def main(self):
        a = ArgumentParser(description='Caching proxy server')
        a.add_argument('--include', '-i', metavar='FILE', action=IncludeAction, nargs='+',
                       help='include arguments from the specified file')
        a.add_argument('--no-sandbox', action='store_true', help='do not spawn processes in bwrap')
        a.add_argument('--access-log', metavar='DEST',
                       help='access log destination: stdout, syslog or a file path')
        a.add_argument('--store-log', metavar='DEST',
                       help='cache store log destination: stdout, syslog or a file path')
        a.add_argument('--invisible', action='store_true', help='avoid modifying http requests unless necessary')
        a.add_argument('--upstream-http', metavar='PROXY', type=self.httpserv,
                       help='forward requests through an upstream http proxy')
        a.add_argument('--url-rewrite', metavar=('REGEX', 'DEST[#OPTION=VALUE]*'), nargs=2, action='append', default=[],
                       help=r'add a rewrite rule; \N in DEST references REGEX capture groups; append #method=VALUE and/or #term to DEST to set the rewrite method and terminal flag')
        a.add_argument('--url-rewrite-program', metavar='CMDLINE', type=self.cmdline,
                       help='enable url rewriting using the specified external program')
        a.add_argument('--http-proxy', '-l', metavar='ADDR:PORT', type=self.addrport,
                       action='append', default=[], help='start an http proxy on the specified address')
        a.add_argument('--redir-proxy', metavar='ADDR:PORT', type=self.addrport,
                       action='append', default=[], help='start an ip intercept proxy on the specified address')
        a.add_argument('--tproxy-proxy', metavar='ADDR:PORT', type=self.addrport,
                       action='append', default=[], help='start a tproxy listener on the specified address')
        a.add_argument('--cache-tls', action='store_true', help='enable tls decryption and caching')
        a.add_argument('--cache-tls-auto', action='store_true',
                       help='create a self-signed ca in the cache dir (or at the given cert/key paths) if it does not exist')
        a.add_argument('--cache-tls-cacert', metavar='PATH', type=self.squidpath, help='tls ca cert file')
        a.add_argument('--cache-tls-cakey', metavar='PATH', type=self.squidpath, help='tls ca key file')
        a.add_argument('--cache-tls-sslcrtd', metavar='CMDLINE', type=self.cmdline,
                       help='override sslcrtd_program')
        a.add_argument('--cache-dir', '-d', metavar='PATH', type=self.squidpath, help='persistent cache dir')
        a.add_argument('--cache-size', '-s', metavar='SIZE', type=int, default=1024,
                       help='persistent cache size in MiB')
        self(**a.parse_args().__dict__).run()
    def __init__(self, **params):
        self.__dict__.update(params)

    @staticmethod
    def handle_sig(*a):
        sys.exit(0)

    @staticmethod
    def reprip(s):
        if ':' in s:
            s = f'[{s}]'
        return s

    def run(self):
        for sig in (SIGHUP, SIGINT, SIGTERM):
            signal(sig, self.handle_sig)
        with TemporaryDirectory(prefix='cacheproxy-') as td:
            tdfd = os.open(td, os.O_RDONLY | os.O_DIRECTORY)
            flock(tdfd, LOCK_SH)
            try:
                self.run_with_td(td)
            finally:
                os.close(tdfd)
    def run_with_td(self, td):
        if self.cache_dir is None:
            self.cache_dir = os.path.join(td, 'cache')
        squidcache = os.path.join(self.cache_dir, 'aufs')
        os.makedirs(squidcache, exist_ok=True)
        sandbox = ['bwrap',
                   '--unshare-all',
                   '--share-net',
                   '--hostname', 'cacheproxy',
                   '--chdir', '/',
                   '--ro-bind', '/', '/',
                   '--proc', '/proc',
                   '--dev', '/dev',
                   '--die-with-parent',
                   '--bind', squidcache, squidcache]
        if self.cache_tls_auto or self.cache_tls_cacert is not None:
            self.cache_tls = True
        if self.cache_tls_auto:
            if self.cache_tls_cacert is None:
                self.cache_tls_cacert = os.path.join(self.cache_dir, 'ca.crt')
                self.cache_tls_cakey = os.path.join(self.cache_dir, 'ca.key')
        if self.cache_tls_auto and not os.path.exists(self.cache_tls_cacert):
            # generate the CA: the certificate goes to the cert path, the
            # private key to the key path (or to the cert path when no key
            # path is configured)
            sp.run(['openssl', 'req',
                    '-new',
                    '-newkey', 'rsa:2048',
                    '-sha256',
                    '-days', '3650',
                    '-nodes',
                    '-x509',
                    '-extensions', 'v3_ca',
                    '-subj', '/CN=cacheproxy/',
                    '-keyout', self.cache_tls_cakey or self.cache_tls_cacert,
                    '-out', self.cache_tls_cacert,
                    ], stdin=sp.DEVNULL, check=True)
        squidconf = os.path.join(td, 'squid.conf')
        with open(squidconf, 'w') as cfg:
            cfg.write('shutdown_lifetime 1 seconds\n'
                      'visible_hostname cacheproxy\n'
                      'pid_filename none\n'
                      'http_access deny manager\n'
                      'http_access allow all\n'
                      'cache_log /dev/stderr\n'
                      f'maximum_object_size {self.cache_size} MB\n'
                      f'cache_dir aufs {squidcache} {self.cache_size} 256 256\n')
            for logtype, logmethod in (
                ('access_log', self.access_log),
                ('cache_store_log', self.store_log),
            ):
                if logmethod is None:
                    cfg.write(f'{logtype} none\n')
                elif logmethod == 'stdout':
                    cfg.write(f'{logtype} stdio:/dev/stdout logformat=squid rotate=0\n')
                elif logmethod == 'syslog':
                    cfg.write(f'{logtype} syslog:daemon.info logformat=squid rotate=0\n')
                else:
                    logfile = self.squidpath(logmethod)
                    # touch the log file so it exists and can be bind-mounted into the sandbox
                    with open(logfile, 'a'):
                        pass
                    cfg.write(f'{logtype} stdio:{logfile} logformat=squid rotate=0\n')
                    sandbox += ['--bind', logfile, logfile]
            if self.invisible:
                cfg.write('via off\n'
                          'forwarded_for transparent\n')
            if self.upstream_http is not None:
                cfg.write(f'cache_peer {self.upstream_http[0]} parent {self.upstream_http[1]} 0 default\n'
                          'never_direct allow all\n')
            if len(self.url_rewrite) > 0:
                if self.url_rewrite_program is not None:
                    raise RuntimeError('cannot add custom rules while using an external url rewrite program')
                rulesjson = os.path.join(td, 'rewrite.json')
                with open(rulesjson, 'w') as rulesf:
                    rules = []
                    # rules given later on the command line are matched first
                    for p, d in reversed(self.url_rewrite):
                        d, *opts = d.split('#')
                        rule = {'pattern': p, 'dest': d}
                        for o in opts:
                            k = o
                            v = True
                            if '=' in o:
                                k, v = o.split('=', 1)
                            if k == 'method':
                                rule['method'] = v
                            elif k == 'term':
                                rule['terminal'] = v
                            else:
                                raise RuntimeError(f'unknown rewrite option {k}')
                        rules.append(rule)
                    json.dump(rules, rulesf)
                # re-exec this script as the squid url_rewrite helper
                self.url_rewrite_program = shlex.join([sys.executable,
                                                       os.path.abspath(shutil.which(sys.argv[0]) or sys.argv[0]),
                                                       '--internal-as-rewriter', rulesjson])
            if self.url_rewrite_program is not None:
                cfg.write(f'url_rewrite_program {self.url_rewrite_program}\n')
            if len(self.http_proxy) < 1:
                raise ValueError('at least one http proxy must be configured')
            tlsparm = ''
            if self.cache_tls:
                tlsparm = ' ssl-bump generate-host-certificates=on'
                if self.cache_tls_cakey is None:
                    tlsparm += ' tls-cert=' + self.cache_tls_cacert
                else:
                    # combine certificate and key into a single PEM for tls-cert=
                    combined = os.path.join(td, 'ca.pem')
                    with open(combined, 'wb') as caout:
                        for infile in (self.cache_tls_cacert, self.cache_tls_cakey):
                            with open(infile, 'rb') as cain:
                                caout.write(cain.read())
                    tlsparm += ' tls-cert=' + combined
            for h, p in self.http_proxy:
                cfg.write(f'http_port {self.reprip(h)}:{p}{tlsparm}\n')
            for h, p in self.redir_proxy:
                cfg.write(f'http_port {self.reprip(h)}:{p} intercept{tlsparm}\n')
            for h, p in self.tproxy_proxy:
                cfg.write(f'http_port {self.reprip(h)}:{p} tproxy{tlsparm}\n')
            if self.cache_tls:
                sslcrtd = self.cache_tls_sslcrtd
                if sslcrtd is None:
                    # locate squid's certificate generator helper
                    certgen = None
                    for pattern in (
                        '/usr/lib*/squid/security_file_certgen',
                        '/usr/lib*/*/security_file_certgen',
                        '/usr/**/security_file_certgen',
                    ):
                        r = glob(pattern, recursive=True, include_hidden=True)
                        if len(r) > 0:
                            certgen = r[0]
                            break
                    if not certgen:
                        raise RuntimeError('security_file_certgen not found in /usr')
                    sslcrtd = certgen
                cfg.write(f'sslcrtd_program {sslcrtd}\n'
                          'acl step1 at_step SslBump1\n'
                          'ssl_bump peek step1\n'
                          'ssl_bump bump all\n')
        if self.no_sandbox:
            sandbox = []
        # run squid -z first to initialize the cache directories, then run it
        # in the foreground (-N) with the generated config
        try:
            sp.run(sandbox + ['squid', '-Nzf', squidconf], stdin=sp.DEVNULL, stdout=sp.PIPE,
                   stderr=sp.STDOUT, universal_newlines=True, check=True)
        except sp.CalledProcessError as e:
            sys.stderr.write(e.stdout)
            raise
        with sp.Popen(sandbox + ['squid', '-Nf', squidconf], stdin=sp.DEVNULL) as p:
            try:
                p.wait()
                sys.exit(1)
            finally:
                p.terminate()
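# For reference, a minimal invocation such as `-l 127.0.0.1:3128 -d
# /var/cache/proxy` (paths illustrative) produces a squid.conf roughly like:
#
#   shutdown_lifetime 1 seconds
#   visible_hostname cacheproxy
#   pid_filename none
#   http_access deny manager
#   http_access allow all
#   cache_log /dev/stderr
#   maximum_object_size 1024 MB
#   cache_dir aufs /var/cache/proxy/aufs 1024 256 256
#   access_log none
#   cache_store_log none
#   http_port 127.0.0.1:3128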
class Rewriter:
    def __init__(self, rulepath):
        self.rules = []
        with open(rulepath) as f:
            for r in json.load(f):
                self.rules.append((re.compile(r['pattern']), r['dest'],
                                   r.get('method', 'rewrite'), r.get('terminal', False)))

    def main(self):
        while True:
            l = sys.stdin.readline()
            if not l:
                break
            m = re.match(r'^(?:(\d+)\s+)?(\S+)(?:\s.*)?$', l)
            if m is None:
                break
            channel, url = m.groups()
            r = 'ERR'
            with suppress(Exception):
                r = self.handle(url)
            if channel:
                sys.stdout.write(channel)
                sys.stdout.write(' ')
            sys.stdout.write(r)
            sys.stdout.write('\n')
            sys.stdout.flush()

    def handle(self, url):
        oldurl = url
        method = None
        while True:
            newurl, newmethod, term = self.rewrite(url)
            if newurl is None or newurl == url:
                break
            url, method = newurl, newmethod
            if term:
                break
        if method is None or oldurl == url:
            return 'OK'
        if method == 'rewrite':
            return f'OK rewrite-url="{url}"'
        if re.match(r'^30\d$', method):
            return f'OK status={method} url="{url}"'
        raise RuntimeError(f'unsupported method {method}')

    def rewrite(self, url):
        for p, dest, method, terminal in self.rules:
            m = p.search(url)
            if m is None:
                continue
            desturl = re.sub(r'\\(\d)', lambda sm: m.group(int(sm.group(1))), dest)
            return desturl, method, terminal
        return None, None, True
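# Example exchange over squid's url_rewrite helper interface, assuming the
# sample mirror rules from the arguments file at the end of this gist and a
# hypothetical mirror host (a leading numeric channel-ID, present only when
# squid is configured with helper concurrency, is echoed back; extra request
# fields after the URL are ignored):
#
#   stdin : http://mirror.example.org/alpine/MIRRORS.txt ...
#   stdout: OK status=307 url="http://dl-cdn.alpinelinux.org/alpine/MIRRORS.txt"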
if __name__ == '__main__':
    def start():
        if len(sys.argv) >= 2:
            if sys.argv[1] == '--internal-as-rewriter':
                # running as squid's url_rewrite helper
                Rewriter(sys.argv[2]).main()
                return
        CacheProxy.main()
    start()
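# Example --include arguments file: each pair of --url-rewrite rules below maps
# a distro repository path seen on any mirror onto one canonical upstream via a
# 307 redirect, so identical packages converge on a single cached URL. Usage
# sketch (file name is illustrative):
#
#   ./cacheproxy.py -l 127.0.0.1:3128 -d /var/cache/proxy -i mirrors.args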
--url-rewrite
'/alpine/(MIRRORS.txt|last-updated|(edge|latest-stable|v\d+\.\d+)/((community|main|testing)/[\w_-]+/(APKINDEX.tar\.*|[^/]*\.apk)|releases/[\w_-]+/.*))$'
'mirror:alpine:\1'
--url-rewrite '^mirror:alpine:(.*)$' 'http://dl-cdn.alpinelinux.org/alpine/\1#method=307#term'
--url-rewrite
'/[Aa]rch(?:[Ll]inux)?/(lastsync|lastupdate|iso/(latest|\d+\.\d{2}\.\d{2})/arch/(boot|x86_64|version)|(community|core|extra|multilib|testing)/os/x86_64/[^/]+\.(pkg\.tar.*|db(\.tar.*)?|files(\.tar.*)?)(\.sig)?)'
'mirror:archlinux:\1'
--url-rewrite '^mirror:archlinux:(.*)$' 'http://geo.mirror.pkgbuild.com/\1#method=307#term'
--url-rewrite
'/(debian(?:|-security|-backports))/(dists/[^/]+/(ChangeLog|InRelease|Release(\.gpg)?|((contrib|main|non-free(-firmware)?)/.*))|pool/(contrib|main|non-free(-firmware)?)/\w+/[^/]+/[^/]+\.(dsc|deb))$'
'mirror:\1:\2'
--url-rewrite '^mirror:(debian[^:]*):(.*)$' 'http://cdn-fastly.deb.debian.org/\1/\2#method=307#term'
--url-rewrite
'/ubuntu/(dists/[^/]+/(Contents-[^/]+\.gz|InRelease|Release(\.gpg)?|((main|multiverse|restricted|universe)/.*))|pool/(main|multiverse|restricted|universe)/\w+/[^/]+/[^/]+\.(dsc|deb))$'
'mirror:ubuntu:\1'
--url-rewrite '^mirror:ubuntu:(.*)$' 'http://archive.ubuntu.com/ubuntu/\1#method=307#term'