Python Requests HTTPS leak
import requests
import multiprocessing
import setproctitle
import os
import threading
import tracemalloc
import linecache
import traceback
import gc
import logging
import time
import psutil
import sys
import socket
from enum import IntEnum
from calib.manager import CertificateManager # you'll need this fix for py3: git+git://github.com/thehesiod/pyca.git@fix-py3#egg=calib
HOSTNAME = socket.gethostname()
CERT_ROOT = '/tmp/certs'
CA_CERT_PATH = os.path.join(CERT_ROOT, 'certs/ca.cert.pem')
CA_KEY_PATH = os.path.join(CERT_ROOT, 'private/ca.key.pem')
SERVER_CERT_PATH = os.path.join(CERT_ROOT, 'certs/{}.cert.pem'.format(HOSTNAME))
SERVER_KEY_PATH = os.path.join(CERT_ROOT, 'private/{}.key.pem'.format(HOSTNAME))
# HA_PROXY cert ordering: server_key, server_cert, ca_cert
COMBINED_CERT_PATH = os.path.join(CERT_ROOT, "haproxy.pem")
COMBINED_CERT_FILES = [SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH]
SSL_PORT = 4443
class ServerType(IntEnum):
    EXTERNAL_HTTPS = 0
    LOCAL_HTTPS = 2


SERVER_TYPE = ServerType.EXTERNAL_HTTPS

if SERVER_TYPE == ServerType.EXTERNAL_HTTPS:
    URL = 'https://google.com'
    VERIFY = None  # none needed
    num_fetchers = 2
elif SERVER_TYPE == ServerType.LOCAL_HTTPS:
    URL = 'https://{}:{}/'.format(HOSTNAME, SSL_PORT)
    VERIFY = CA_CERT_PATH
    num_fetchers = 1
else:
    assert False
_req_num = multiprocessing.Value('i', 0)
def memleak_checker(pid):
    global _req_num
    setproctitle.setproctitle('leak reporter')
    _proc = psutil.Process(pid)

    suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
    mprofile_output = "mprofile_%s.dat" % suffix
    mprofile_fd_output = "fprofile_%s.dat" % suffix
    mprofile_net_output = "nprofile_%s.dat" % suffix
    files = [mprofile_output, mprofile_fd_output, mprofile_net_output]
    for fp in files:
        if os.path.exists(fp):
            os.unlink(fp)

    _mem_file = open(mprofile_output, 'w')
    _fd_file = open(mprofile_fd_output, 'w')
    _conn_file = open(mprofile_net_output, 'w')

    while True:
        if _req_num.value < 30:
            time.sleep(1)  # warm up

        # instead of calling: memory_profiler.memory_usage(-1, interval=.3, timestamps=True, timeout=100000, stream=f)
        # we do this manually so we do a timestamp per request...also memory_usage keeps accumulating results for some reason
        proc_rss = _proc.memory_info()[0] / float(2 ** 20)
        _mem_file.write("MEM {0:.6f} {1:.4f}\n".format(proc_rss, _req_num.value))

        proc_fds = _proc.num_fds()
        _fd_file.write("MEM {0:.6f} {1:.4f}\n".format(proc_fds, _req_num.value))

        proc_conns = _proc.connections()
        _conn_file.write("MEM {0:.6f} {1:.4f}\n".format(len(proc_conns), _req_num.value))

        _mem_file.flush()
        _fd_file.flush()
        _conn_file.flush()
        time.sleep(1)

def do_one_req():
    global _req_num
    try:
        response = requests.head(URL, verify=VERIFY, allow_redirects=False, timeout=1)
        for r in response.history:
            r.close()
        response.close()
        return response
    except:
        traceback.print_exc()
    finally:
        with _req_num.get_lock():  # += on a multiprocessing.Value is not atomic across fetcher threads
            _req_num.value += 1


def do_reqs():
    while True:
        do_one_req()
        print('.', end='', flush=True)

_TRACE_FILTERS = (
    tracemalloc.Filter(False, "<frozen importlib._bootstrap>", all_frames=True),
    tracemalloc.Filter(False, tracemalloc.__file__, all_frames=True),
    tracemalloc.Filter(False, linecache.__file__, all_frames=True)
)

def tracemalloc_checker():
    _first_snap = None
    LEAK_DUMP_PATH = '/tmp/mem_log.txt'
    if os.path.exists(LEAK_DUMP_PATH):
        os.unlink(LEAK_DUMP_PATH)

    while True:
        time.sleep(60)
        start = time.time()

        if _first_snap is None:
            gc.collect()
            _first_snap = tracemalloc.take_snapshot().filter_traces(_TRACE_FILTERS)
            continue

        # print memory diff
        gc.collect()
        snap = tracemalloc.take_snapshot().filter_traces(_TRACE_FILTERS)
        top_stats = snap.compare_to(_first_snap, 'traceback')
        top_stats = sorted(top_stats, key=lambda x: x.size, reverse=True)

        with open(LEAK_DUMP_PATH, 'w+') as f:
            f = sys.stdout  # NOTE: rebinding f here sends the report to stdout instead of LEAK_DUMP_PATH
            f.write('===============================' + os.linesep)
            f.write("[Top 20 differences elapsed: {}]".format(round(time.time() - start)) + os.linesep)
            for stat in top_stats[:20]:
                f.write(str(stat) + os.linesep)
                for line in stat.traceback.format():
                    f.write('\t' + str(line) + os.linesep)
            f.write('===============================' + os.linesep)

def wait_server_up():
    # do_one_req() swallows its own exceptions (including ConnectionError),
    # so poll its return value until the local server actually answers
    while True:
        if do_one_req() is not None:
            return
        time.sleep(0.5)

if SERVER_TYPE == ServerType.LOCAL_HTTPS:
    def ssl_server():
        import http.server, ssl
        setproctitle.setproctitle('server')

        class SimpleHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
            protocol_version = "HTTP/1.1"

            def do_HEAD(self):
                self.send_response_only(http.HTTPStatus.MOVED_PERMANENTLY, "Moved Permanently")
                self.send_header("Content-Length", 0)
                self.send_header("Location", URL)
                self.send_header("Connection", "close")
                self.end_headers()

        server_address = (HOSTNAME, SSL_PORT)
        httpd = http.server.HTTPServer(server_address, SimpleHTTPRequestHandler)
        httpd.socket = ssl.wrap_socket(httpd.socket,
                                       server_side=True,
                                       keyfile=SERVER_KEY_PATH,
                                       certfile=SERVER_CERT_PATH,
                                       ssl_version=ssl.PROTOCOL_TLSv1)
        httpd.serve_forever()

def create_certs():
    # Create certs
    if not os.path.exists(CERT_ROOT):
        os.mkdir(CERT_ROOT)

    mgr = CertificateManager(CERT_ROOT)
    mgr.init([HOSTNAME])

    print("Creating Root Cert, enter {} for commonName".format(HOSTNAME))
    mgr.createRootCertificate(noPass=True, keyLength=2048)

    print("Creating Server Cert, enter {} for commonName".format(HOSTNAME))
    mgr.createServerCertificate(HOSTNAME)

    with open(COMBINED_CERT_PATH, 'w') as proxy_f:
        for fp in COMBINED_CERT_FILES:
            with open(fp, 'r') as f:
                proxy_f.write(f.read() + os.linesep)

def periodic_gc():
    while True:
        time.sleep(5)
        gc.collect()

def main():
    logging.basicConfig(level=logging.WARNING)
    setproctitle.setproctitle('leaker')

    threads = [threading.Thread(target=periodic_gc, daemon=True)]

    if SERVER_TYPE in {ServerType.LOCAL_HTTPS}:
        create_certs()
        server_proc = multiprocessing.Process(target=ssl_server)
        server_proc.start()
        wait_server_up()

    mem_proc = multiprocessing.Process(target=memleak_checker, args=[os.getpid()])
    mem_proc.start()

    # leak_thread = threading.Thread(target=tracemalloc_checker, daemon=True)
    # tracemalloc.start(15)
    # leak_thread.start()

    threads += [threading.Thread(target=do_reqs, daemon=True) for _ in range(num_fetchers)]

    for t in threads:
        if not t.is_alive():
            t.start()

    for t in threads:
        t.join()


if __name__ == '__main__':
    main()
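
The monitor process writes one "MEM <value> <request count>" sample per second into mprofile_*.dat (RSS in MiB), fprofile_*.dat (open file descriptors) and nprofile_*.dat (socket connections). Below is a minimal sketch for graphing one of those files, assuming matplotlib is installed; the helper name and its command-line interface are illustrative, not part of the gist. If the sampled value keeps climbing as the request count grows, the leak is visible at a glance.

# plot_profile.py -- hypothetical helper, not part of the original gist
import sys
import matplotlib.pyplot as plt


def load_samples(path):
    """Parse the 'MEM <value> <request count>' lines written by memleak_checker."""
    req_counts, values = [], []
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) == 3 and parts[0] == 'MEM':
                values.append(float(parts[1]))
                req_counts.append(float(parts[2]))
    return req_counts, values


if __name__ == '__main__':
    x, y = load_samples(sys.argv[1])
    plt.plot(x, y)
    plt.xlabel('requests completed')
    plt.ylabel('sampled value (RSS MiB / fds / connections)')
    plt.title(sys.argv[1])
    plt.show()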