
@jonashaag
Created April 14, 2022 20:45
Simple caching Conda proxy

Two files follow: the caching proxy script itself, and the patched proxy2 module it imports (taken from https://github.com/inaz2/proxy2/pull/6).
import atexit
import base64
import logging
import os
import pickle

import diskcache

import proxy2

logger = logging.getLogger("conda_proxy")
class DiskCache:
    """Size-limited on-disk blob cache keyed by request path."""

    def __init__(self, path, size: int):
        self._cache = diskcache.Cache(
            path, eviction_policy="least-frequently-used", size_limit=size
        )

    def _key(self, path: str) -> str:
        # Base64-encode the request path so it is always a safe cache key.
        return base64.b64encode(path.encode()).decode()

    def is_cached(self, path: str) -> bool:
        return self._key(path) in self._cache

    def read_bytes(self, path: str) -> bytes:
        return self._cache.get(self._key(path))

    def write_bytes(self, path: str, blob: bytes) -> None:
        # add() only inserts if the key is not already present.
        self._cache.add(self._key(path), blob)

    def close(self) -> None:
        self._cache.close()
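
# A minimal round-trip sketch (not part of the original gist) showing how the
# wrapper above is used; the cache path and payload are made up for illustration.
def _diskcache_demo():
    cache = DiskCache("/tmp/demo-cache", size=10_000_000)
    try:
        cache.write_bytes("/conda-forge/noarch/foo-1.0.tar.bz2", b"payload")
        assert cache.is_cached("/conda-forge/noarch/foo-1.0.tar.bz2")
        assert cache.read_bytes("/conda-forge/noarch/foo-1.0.tar.bz2") == b"payload"
    finally:
        cache.close()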
class CondaProxyRequestHandler(proxy2.ProxyRequestHandler):
    # Contract (see proxy2.ProxyRequestHandler.do_GET below): returning a
    # (response, body) tuple from request_handler serves it without contacting
    # the upstream server; returning None forwards the request unchanged.
    def request_handler(self, req_body):
        if self.path in repodata_cache:
            logger.info(f"Found {self.path} in cache.")
            self.cache_miss = False
            return repodata_cache[self.path]
        elif packages_cache.is_cached(self.path):
            logger.info(f"Found {self.path} in cache.")
            self.cache_miss = False
            return pickle.loads(packages_cache.read_bytes(self.path))
        else:
            self.cache_miss = True

    def response_handler(self, req_body, res, res_body):
        if self.cache_miss:
            if self.path.endswith("repodata.json"):
                logger.info(f"Adding {self.path} to cache.")
                # repodata.json changes over time, so cache it with a TTL.
                repodata_cache.set(self.path, (res, res_body), expire=repodata_ttl)
            else:
                logger.info(f"Adding {self.path} to cache.")
                # Package archives are immutable; cache them until evicted.
                packages_cache.write_bytes(self.path, pickle.dumps((res, res_body)))
# Configuration via environment variables.
persistence_path = os.environ.get("CONDA_PROXY_CACHE_PATH", "/tmp/uvproxy")
cache_max_size = int(os.environ.get("CONDA_PROXY_CACHE_SIZE", 1e9))
repodata_ttl = int(os.environ.get("CONDA_PROXY_REPODATA_TTL", 3600))
timeout = int(os.environ.get("CONDA_PROXY_HTTP_TIMEOUT", "300"))
# Apply the timeout to the handler; without this the setting above is unused.
CondaProxyRequestHandler.timeout = timeout

# We need two different caches with their own strategies: size-limited LFU
# eviction for packages, TTL-based expiry for repodata.
packages_cache = DiskCache(persistence_path, cache_max_size)
repodata_cache = diskcache.Cache(persistence_path)
atexit.register(packages_cache.close)
atexit.register(repodata_cache.close)
def main():
    import http.server
    import socket
    import socketserver

    class ThreadingHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
        address_family = socket.AF_INET6
        daemon_threads = True

    logging.basicConfig(level="DEBUG")
    httpd = ThreadingHTTPServer(("localhost", 8080), CondaProxyRequestHandler)
    httpd.serve_forever()


if __name__ == "__main__":
    main()
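
# --- Usage sketch (not part of the original gist): exercising the running proxy
# over plain HTTP from a separate process. HTTPS additionally requires the
# client to trust the generated CA certificate (ca.crt). The URL and port are
# illustrative assumptions.
def _proxy_smoke_test():
    import urllib.request

    opener = urllib.request.build_opener(
        urllib.request.ProxyHandler({"http": "http://localhost:8080"})
    )
    url = "http://conda.anaconda.org/conda-forge/noarch/repodata.json"
    # The first fetch populates the cache; the second should be answered from it.
    for _ in range(2):
        with opener.open(url, timeout=30) as resp:
            resp.read()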
# The remainder is the patched proxy2 module imported above.
# From: https://github.com/inaz2/proxy2/pull/6
# BSD 3 clause
import http.client
import http.server
import os
import re
import ssl
import string
import threading
import time
import urllib.parse

import OpenSSL

import ssl_wrapper


def join_with_script_dir(path):
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)
class ProxyRequestHandler(http.server.BaseHTTPRequestHandler):
    cakey = join_with_script_dir('ca.key')
    cacert = join_with_script_dir('ca.crt')
    certkey = join_with_script_dir('cert.key')
    certdir = join_with_script_dir('certs/')
    timeout = 10
    lock = threading.Lock()

    def __init__(self, *args, **kwargs):
        # One upstream connection pool per thread.
        self.tls = threading.local()
        self.tls.conns = {}
        super().__init__(*args, **kwargs)
    def do_CONNECT(self):
        # Generate (and cache on disk) a leaf certificate for the requested
        # host, signed by the local CA, then upgrade the connection to TLS so
        # the proxy can read the decrypted traffic.
        hostname = self.path.split(':')[0]
        ippat = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
        cert_category = "IP" if ippat.match(hostname) else "DNS"
        certpath = "%s/%s.crt" % (ssl_wrapper.cert_dir.rstrip('/'), hostname)
        with self.lock:
            if not os.path.isfile(certpath):
                x509_serial = int("%d" % (time.time() * 1000))
                valid_time_interval = (0, 60 * 60 * 24 * 365)
                cert_request = ssl_wrapper.create_cert_request(ssl_wrapper.cert_key_obj, CN=hostname)
                cert = ssl_wrapper.create_certificate(
                    cert_request, (ssl_wrapper.ca_crt_obj, ssl_wrapper.ca_key_obj), x509_serial,
                    valid_time_interval,
                    subject_alt_names=[
                        string.Template("${category}:${hostname}").substitute(hostname=hostname, category=cert_category)
                    ]
                )
                with open(certpath, 'wb+') as f:
                    f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert))

        self.wfile.write("HTTP/1.1 {} {}\r\n".format(200, 'Connection Established').encode('latin-1'))
        self.wfile.write(b'\r\n')

        # ssl.wrap_socket() was removed in Python 3.12; use an SSLContext instead.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(certfile=certpath, keyfile=ssl_wrapper.cert_key)
        self.connection = ctx.wrap_socket(self.connection, server_side=True)
        self.rfile = self.connection.makefile("rb", self.rbufsize)
        self.wfile = self.connection.makefile("wb", self.wbufsize)

        conntype = self.headers.get('Proxy-Connection', '')
        if conntype.lower() != 'close':
            self.close_connection = False
    def do_GET(self):
        content_length = int(self.headers.get('Content-Length', 0))
        req_body = self.rfile.read(content_length) if content_length else None

        # Relative paths arrive on MITM'd connections; rebuild the absolute URL.
        if self.path[0] == '/':
            if isinstance(self.connection, ssl.SSLSocket):
                self.path = "https://{}{}".format(self.headers['Host'], self.path)
            else:
                self.path = "http://{}{}".format(self.headers['Host'], self.path)

        req_body_modified = self.request_handler(req_body)
        if req_body_modified is False:
            self.send_error(403)
            return
        if isinstance(req_body_modified, tuple):
            # The handler produced a cached response; skip the upstream request.
            res, res_body = req_body_modified
        else:
            if req_body_modified is not None:
                req_body = req_body_modified
                if 'Content-Length' in self.headers:
                    del self.headers['Content-Length']
                self.headers['Content-Length'] = str(len(req_body_modified))
            res, res_body = self._make_req(req_body)

        if 'Content-Length' not in res.msg:
            res.msg['Content-Length'] = str(len(res_body))

        setattr(res, 'headers', self.filter_headers(res.msg))

        self.wfile.write(f"HTTP/1.1 {res.status} {res.reason}\r\n".encode("ascii"))
        for k, v in res.msg.items():
            self.send_header(k, v)
        self.end_headers()
        if res_body:
            self.wfile.write(res_body)
        self.wfile.flush()
    def _make_req(self, req_body):
        url = urllib.parse.urlsplit(self.path)
        scheme, netloc, path = url.scheme, url.netloc, (url.path + '?' + url.query if url.query else url.path)
        assert scheme in ('http', 'https')
        origin = (scheme, netloc)
        if netloc:
            if 'Host' in self.headers:
                del self.headers['Host']
            self.headers['Host'] = netloc
        setattr(self, 'headers', self.filter_headers(self.headers))

        # Reuse (or open) a per-thread connection to the upstream server.
        conn = self.tls.conns.get(origin)
        if conn is None:
            conn = self.tls.conns[origin] = {
                "https": http.client.HTTPSConnection,
                "http": http.client.HTTPConnection,
            }[scheme](netloc, timeout=self.timeout)
        try:
            conn.request(self.command, path, req_body, dict(self.headers))
            res = conn.getresponse()
            res_body = res.read()
            self.response_handler(req_body, res, res_body)
            return res, res_body
        except Exception:
            # Drop the (possibly broken) connection so it is reopened next time.
            self.tls.conns.pop(origin, None)
            raise

    do_HEAD = do_GET
    do_POST = do_GET
    do_PUT = do_GET
    do_DELETE = do_GET
    do_OPTIONS = do_GET
    def filter_headers(self, headers):
        # http://tools.ietf.org/html/rfc2616#section-13.5.1
        hop_by_hop = (
            'connection',
            'keep-alive',
            'proxy-authenticate',
            'proxy-authorization',
            'te',
            'trailers',
            'transfer-encoding',
            'upgrade'
        )
        for k in hop_by_hop:
            del headers[k]

        # Accept only supported encodings.
        if 'Accept-Encoding' in headers:
            ae = headers['Accept-Encoding']
            filtered_encodings = [x for x in re.split(r',\s*', ae) if x in ('identity', 'gzip', 'x-gzip', 'deflate')]
            del headers['Accept-Encoding']
            headers['Accept-Encoding'] = ', '.join(filtered_encodings)
        return headers

    # Hooks for subclasses (e.g. CondaProxyRequestHandler above).
    def request_handler(self, req_body):
        pass

    def response_handler(self, req_body, res, res_body):
        pass
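
# A small illustrative check (not from the gist) of what filter_headers does:
# hop-by-hop headers are stripped and Accept-Encoding is reduced to encodings
# the proxy passes through. filter_headers never touches self, so passing None
# in its place works for this sketch.
def _filter_headers_demo():
    import email.message

    h = email.message.Message()
    h["Connection"] = "keep-alive"
    h["Accept-Encoding"] = "gzip, br, zstd"
    h["Host"] = "example.org"
    filtered = ProxyRequestHandler.filter_headers(None, h)
    assert "Connection" not in filtered
    assert filtered["Accept-Encoding"] == "gzip"
    assert filtered["Host"] == "example.org"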