Last active
March 21, 2017 19:59
-
-
Save thehesiod/5dbd7f2bffbe0b850980e865f5649338 to your computer and use it in GitHub Desktop.
Python Requests HTTPS leak
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests | |
import multiprocessing | |
import setproctitle | |
import os | |
import threading | |
import tracemalloc | |
import linecache | |
import traceback | |
import gc | |
import logging | |
import time | |
import psutil | |
import sys | |
import socket | |
from enum import IntEnum | |
from calib.manager import CertificateManager # you'll need this fix for py3: git+git://github.com/thehesiod/pyca.git@fix-py3#egg=calib | |
# Host identity and on-disk certificate layout, shared by the optional
# local TLS server and the requesting client.
HOSTNAME = socket.gethostname()
CERT_ROOT = '/tmp/certs'
CA_CERT_PATH = os.path.join(CERT_ROOT, 'certs', 'ca.cert.pem')
CA_KEY_PATH = os.path.join(CERT_ROOT, 'private', 'ca.key.pem')
SERVER_CERT_PATH = os.path.join(CERT_ROOT, 'certs', HOSTNAME + '.cert.pem')
SERVER_KEY_PATH = os.path.join(CERT_ROOT, 'private', HOSTNAME + '.key.pem')
# HAProxy expects its PEM bundle ordered: server key, server cert, CA cert.
COMBINED_CERT_PATH = os.path.join(CERT_ROOT, 'haproxy.pem')
COMBINED_CERT_FILES = [SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH]
SSL_PORT = 4443
class ServerType(IntEnum):
    """Which endpoint the fetcher threads should hammer."""

    # A real external TLS endpoint on the public internet.
    EXTERNAL_HTTPS = 0
    # A locally spawned HTTPS server (value 2 preserved as-is).
    LOCAL_HTTPS = 2
# Edit this constant to pick which leak scenario to exercise.
SERVER_TYPE = ServerType.EXTERNAL_HTTPS

if SERVER_TYPE is ServerType.EXTERNAL_HTTPS:
    # Real endpoint: the system CA bundle suffices.
    URL = 'https://google.com'
    VERIFY = None  # none needed
    num_fetchers = 2
elif SERVER_TYPE is ServerType.LOCAL_HTTPS:
    # Local server signed by our throwaway CA.
    URL = 'https://{}:{}/'.format(HOSTNAME, SSL_PORT)
    VERIFY = CA_CERT_PATH
    num_fetchers = 1
else:
    assert False

# Cross-process counter of completed request attempts, shared with the
# monitor process so samples can be correlated with request volume.
_req_num = multiprocessing.Value('i', 0)
def memleak_checker(pid):
    """Monitor-process body: once per second, sample process *pid*'s RSS,
    open-fd count and connection count, appending one line per sample to
    three memory_profiler-style ``.dat`` files.

    Each line pairs the sampled value with the shared request counter so
    the plots can be correlated with request volume.  Runs until killed.
    """
    global _req_num
    setproctitle.setproctitle('leak reporter')
    _proc = psutil.Process(pid)
    # One fresh set of output files per run, stamped with the start time.
    suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
    mprofile_output = "mprofile_%s.dat" % suffix
    mprofile_fd_output = "fprofile_%s.dat" % suffix
    mprofile_net_output = "nprofile_%s.dat" % suffix
    files = [mprofile_output, mprofile_fd_output, mprofile_net_output]
    for fp in files:
        if os.path.exists(fp):
            os.unlink(fp)
    # Handles deliberately stay open for the life of this process; the
    # infinite loop below never returns.
    _mem_file = open(mprofile_output, 'w')
    _fd_file = open(mprofile_fd_output, 'w')
    _conn_file = open(mprofile_net_output, 'w')
    while True:
        if _req_num.value < 30:
            # NOTE(review): this only *slows* sampling during the first 30
            # requests; sampling still proceeds below.  Possibly a
            # `continue` was intended -- confirm against the plots.
            time.sleep(1)  # warm up
        # instead of calling: memory_profiler.memory_usage(-1, interval=.3, timestamps=True, timeout=100000, stream=f)
        # we do this manually so we do a timestamp per request...also memory_usage keeps accumulating results for some reason
        proc_rss = _proc.memory_info()[0] / float(2 ** 20)  # RSS bytes -> MiB
        _mem_file.write("MEM {0:.6f} {1:.4f}\n".format(proc_rss, _req_num.value))
        proc_fds = _proc.num_fds()
        _fd_file.write("MEM {0:.6f} {1:.4f}\n".format(proc_fds, _req_num.value))
        proc_conns = _proc.connections()
        _conn_file.write("MEM {0:.6f} {1:.4f}\n".format(len(proc_conns), _req_num.value))
        # Flush so the .dat files can be tailed/plotted live.
        _mem_file.flush()
        _fd_file.flush()
        _conn_file.flush()
        time.sleep(1)
def do_one_req():
    """Issue one HEAD request to ``URL`` and close the response.

    Always increments the shared ``_req_num`` counter, success or not.
    Returns the closed ``requests.Response`` on success, or ``None`` when
    the request raised (the traceback is printed; best-effort by design,
    since fetcher threads must keep running through transient failures).
    """
    global _req_num
    try:
        response = requests.head(URL, verify=VERIFY, allow_redirects=False, timeout=1)
        # Close every redirect hop as well as the final response so the
        # underlying connections are released -- the leak under test.
        for r in response.history:
            r.close()
        response.close()
        return response
    except Exception:
        # BUGFIX: narrowed from a bare ``except:`` so KeyboardInterrupt
        # and SystemExit still propagate; request failures remain
        # best-effort as before.
        traceback.print_exc()
    finally:
        _req_num.value += 1
def do_reqs():
    """Fetcher-thread body: issue requests back to back forever,
    emitting one progress dot per attempt as a liveness indicator."""
    while True:
        do_one_req()
        print('.', end='', flush=True)
# Frames to exclude from tracemalloc snapshots: the import machinery,
# tracemalloc itself and linecache are pure noise in leak diffs.
_TRACE_FILTERS = tuple(
    tracemalloc.Filter(False, pattern, all_frames=True)
    for pattern in (
        "<frozen importlib._bootstrap>",
        tracemalloc.__file__,
        linecache.__file__,
    )
)
def tracemalloc_checker():
    """Daemon-thread body: every 60s take a gc-collected tracemalloc
    snapshot, diff it against the first (baseline) snapshot, and append
    the top 20 allocation differences to ``LEAK_DUMP_PATH``.

    Requires ``tracemalloc.start()`` to have been called beforehand
    (see the commented-out wiring in ``main``).  Never returns.
    """
    _first_snap = None
    LEAK_DUMP_PATH = '/tmp/mem_log.txt'
    # Start each run from a clean log.
    if os.path.exists(LEAK_DUMP_PATH):
        os.unlink(LEAK_DUMP_PATH)
    while True:
        time.sleep(60)
        start = time.time()
        if _first_snap is None:
            # Baseline taken after the first minute of warm-up.
            gc.collect()
            _first_snap = tracemalloc.take_snapshot().filter_traces(_TRACE_FILTERS)
            continue
        # print memory diff
        gc.collect()
        snap = tracemalloc.take_snapshot().filter_traces(_TRACE_FILTERS)
        top_stats = snap.compare_to(_first_snap, 'traceback')
        top_stats = sorted(top_stats, key=lambda x: x.size, reverse=True)
        # BUGFIX: the original rebound ``f = sys.stdout`` immediately
        # after opening the file (a debug leftover), so nothing was ever
        # written to LEAK_DUMP_PATH; it also opened with 'w+', truncating
        # earlier reports each iteration.  Append each report instead.
        with open(LEAK_DUMP_PATH, 'a') as f:
            f.write('===============================' + os.linesep)
            f.write("[Top 20 differences elapsed: {}]".format(round(time.time() - start)) + os.linesep)
            for stat in top_stats[:20]:
                f.write(str(stat) + os.linesep)
                for line in stat.traceback.format():
                    f.write('\t' + str(line) + os.linesep)
            f.write('===============================' + os.linesep)
def wait_server_up():
    """Block until one request to the target server completes.

    NOTE(review): ``do_one_req`` catches all exceptions internally, so the
    ConnectionError handler below never actually fires and the first call
    returns regardless -- confirm this matches the intended startup wait.
    """
    while True:
        try:
            do_one_req()
        except requests.exceptions.ConnectionError:
            continue
        else:
            return
if SERVER_TYPE == ServerType.LOCAL_HTTPS:
    def ssl_server():
        """Minimal HTTPS server: 301-redirects every HEAD request back to
        URL and closes the connection, exercising the client's
        redirect-and-close path.  Runs forever in its own process."""
        import http.server, ssl
        setproctitle.setproctitle('server')

        class SimpleHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
            protocol_version = "HTTP/1.1"

            def do_HEAD(self):
                self.send_response_only(http.HTTPStatus.MOVED_PERMANENTLY, "Moved Permanently")
                self.send_header("Content-Length", 0)
                self.send_header("Location", URL)
                # Force a fresh connection per request on the client side.
                self.send_header("Connection", "close")
                self.end_headers()

        server_address = (HOSTNAME, SSL_PORT)
        httpd = http.server.HTTPServer(server_address, SimpleHTTPRequestHandler)
        # BUGFIX/modernization: ssl.wrap_socket() was deprecated in 3.7 and
        # removed in Python 3.12; use an SSLContext instead.
        # PROTOCOL_TLS_SERVER negotiates the best mutually supported TLS
        # version rather than pinning the long-obsolete TLSv1 (which modern
        # clients may refuse outright).
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile=SERVER_CERT_PATH, keyfile=SERVER_KEY_PATH)
        httpd.socket = context.wrap_socket(httpd.socket, server_side=True)
        httpd.serve_forever()
def create_certs():
    """Generate a throwaway CA plus a server certificate for HOSTNAME
    under CERT_ROOT, then concatenate key/cert/CA into the HAProxy-style
    bundle at COMBINED_CERT_PATH."""
    if not os.path.exists(CERT_ROOT):
        os.mkdir(CERT_ROOT)
    manager = CertificateManager(CERT_ROOT)
    manager.init([HOSTNAME])
    print("Creating Root Cert, enter {} for commonName".format(HOSTNAME))
    manager.createRootCertificate(noPass=True, keyLength=2048)
    print("Creating Server Cert, enter {} for commonName".format(HOSTNAME))
    manager.createServerCertificate(HOSTNAME)
    # HAProxy wants key, cert and CA concatenated in exactly this order.
    with open(COMBINED_CERT_PATH, 'w') as bundle:
        for part_path in COMBINED_CERT_FILES:
            with open(part_path, 'r') as part:
                bundle.write(part.read() + os.linesep)
def periodic_gc():
    """Daemon loop: force a full collection every five seconds so any
    observed RSS growth is a real leak, not just uncollected cycles."""
    interval_s = 5
    while True:
        time.sleep(interval_s)
        gc.collect()
def main():
    """Entry point: start the monitor process, optionally the local TLS
    server, and the fetcher threads, then block forever."""
    logging.basicConfig(level=logging.WARNING)
    setproctitle.setproctitle('leaker')
    # periodic_gc keeps uncollected cycles from masquerading as leaks.
    threads = [threading.Thread(target=periodic_gc, daemon=True)]
    if SERVER_TYPE in {ServerType.LOCAL_HTTPS}:
        create_certs()
        server_proc = multiprocessing.Process(target=ssl_server)
        server_proc.start()
        wait_server_up()
    # Separate process samples this process's RSS/fds/connections.
    mem_proc = multiprocessing.Process(target=memleak_checker, args=[os.getpid()])
    mem_proc.start()
    # leak_thread = threading.Thread(target=tracemalloc_checker, daemon=True)
    # tracemalloc.start(15)
    # leak_thread.start()
    threads += [threading.Thread(target=do_reqs, daemon=True) for _ in range(num_fetchers)]
    for t in threads:
        if not t.is_alive():
            t.start()
    # Fetcher threads never exit, so this joins (blocks) indefinitely.
    for t in threads:
        t.join()

if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment