# requests leak checker (@thehesiod, last active April 10, 2017)
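#
# Hammers a single URL from a few fetcher threads using the selected client
# (requests, a raw ssl-wrapped socket, or aiohttp), while a separate process
# samples the parent's RSS, open-fd count, and socket connection count via
# psutil and writes one mprof-style .dat file per metric, so a leak shows up
# as a line that keeps rising with the request count.
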
import multiprocessing
import setproctitle
import os
import threading
import tracemalloc
import linecache
import ssl
import traceback
from urllib.parse import urlparse
import gc
import logging
import time
import psutil
import sys
import socket
from enum import IntEnum

HOSTNAME = socket.gethostname()
CERT_ROOT = '/tmp/certs'
CA_CERT_PATH = os.path.join(CERT_ROOT, 'certs/ca.cert.pem')
CA_KEY_PATH = os.path.join(CERT_ROOT, 'private/ca.key.pem')
SERVER_CERT_PATH = os.path.join(CERT_ROOT, 'certs/{}.cert.pem'.format(HOSTNAME))
SERVER_KEY_PATH = os.path.join(CERT_ROOT, 'private/{}.key.pem'.format(HOSTNAME))
# HA_PROXY cert ordering: server_key, server_cert, ca_cert
COMBINED_CERT_PATH = os.path.join(CERT_ROOT, "haproxy.pem")
COMBINED_CERT_FILES = [SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH]
SSL_PORT = 4443


class ServerType(IntEnum):
    EXTERNAL_HTTPS = 0
    LOCAL_HTTPS = 2
    RAW_SSL_SOCKET = 3


class ClientType(IntEnum):
    REQUESTS = 0
    RAW = 1
    AIOHTTP = 2

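# edit the two assignments below to pick the scenario under test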
SERVER_TYPE = ServerType.EXTERNAL_HTTPS
CLIENT_TYPE = ClientType.RAW

if CLIENT_TYPE == ClientType.AIOHTTP:
    import aiohttp, asyncio
elif CLIENT_TYPE == ClientType.REQUESTS:
    import requests

if SERVER_TYPE == ServerType.EXTERNAL_HTTPS:
    URL = 'https://google.com'
    VERIFY = None  # none needed
    num_fetchers = 2
elif SERVER_TYPE == ServerType.LOCAL_HTTPS:
    URL = 'https://{}:{}/'.format(HOSTNAME, SSL_PORT)
    VERIFY = CA_CERT_PATH
    num_fetchers = 1
else:
    assert False  # no URL mapping for ServerType.RAW_SSL_SOCKET

# request counter, shared (via fork) with the reporter process
_req_num = multiprocessing.Value('i', 0)


def memleak_checker(pid):
    global _req_num
    setproctitle.setproctitle('leak reporter')
    _proc = psutil.Process(pid)

    suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
    mprofile_output = "mprofile_%s.dat" % suffix
    mprofile_fd_output = "fprofile_%s.dat" % suffix
    mprofile_net_output = "nprofile_%s.dat" % suffix
    files = [mprofile_output, mprofile_fd_output, mprofile_net_output]
    for fp in files:
        if os.path.exists(fp):
            os.unlink(fp)

    _mem_file = open(mprofile_output, 'w')
    _fd_file = open(mprofile_fd_output, 'w')
    _conn_file = open(mprofile_net_output, 'w')

    while True:
        if _req_num.value < 30:
            time.sleep(1)  # warm up

        # Instead of calling memory_profiler.memory_usage(-1, interval=.3, timestamps=True, timeout=100000, stream=f)
        # we sample manually so we get one data point per request count; memory_usage also keeps
        # accumulating results internally for some reason.
        proc_rss = _proc.memory_info()[0] / float(2 ** 20)  # RSS in MiB
        _mem_file.write("MEM {0:.6f} {1:.4f}\n".format(proc_rss, _req_num.value))

        proc_fds = _proc.num_fds()
        _fd_file.write("MEM {0:.6f} {1:.4f}\n".format(proc_fds, _req_num.value))

        proc_conns = _proc.connections()
        _conn_file.write("MEM {0:.6f} {1:.4f}\n".format(len(proc_conns), _req_num.value))

        _mem_file.flush()
        _fd_file.flush()
        _conn_file.flush()
        time.sleep(1)
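
# The .dat files above reuse memory_profiler's "MEM <value> <timestamp>" line
# format (with the request count standing in for the timestamp), so they should
# be plottable with memory_profiler's CLI, e.g. `mprof plot mprofile_<suffix>.dat`;
# that invocation is an assumption, nothing in this script depends on it.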


def do_one_req():
    global _req_num

    try:
        if CLIENT_TYPE == ClientType.REQUESTS:
            response = requests.head(URL, verify=VERIFY, allow_redirects=False, timeout=1)
            for r in response.history:
                r.close()
            response.close()
            return response
        elif CLIENT_TYPE == ClientType.RAW:
            parts = urlparse(URL)
            context = ssl.create_default_context()  # close??
            conn = context.wrap_socket(socket.socket(socket.AF_INET), server_hostname=parts.hostname)
            conn.connect((parts.hostname, parts.port or 443))
            conn.sendall("HEAD {} HTTP/1.1\r\nHost: {}\r\nAccept: */*\r\nConnection: keep-alive\r\nAccept-Encoding: gzip, deflate\r\nUser-Agent: python\r\n\r\n".format(parts.path or '/', parts.hostname).encode('utf-8'))
            data = conn.recv(10000)
            conn.close()  # conn.shutdown(socket.SHUT_RDWR) vs close() seems to make no difference
        elif CLIENT_TYPE == ClientType.AIOHTTP:
            loop = asyncio.get_event_loop()

            async def doit():
                async with aiohttp.ClientSession() as session:
                    async with session.head(URL, allow_redirects=False) as response:
                        pass

            loop.run_until_complete(doit())
        else:
            assert False
    except Exception:  # a bare except here would also swallow KeyboardInterrupt
        traceback.print_exc()
    finally:
        _req_num.value += 1


def do_reqs():
    if CLIENT_TYPE == ClientType.AIOHTTP:
        # each fetcher thread needs its own event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    while True:
        do_one_req()
        print('.', end='', flush=True)


_TRACE_FILTERS = (
    tracemalloc.Filter(False, "<frozen importlib._bootstrap>", all_frames=True),
    tracemalloc.Filter(False, tracemalloc.__file__, all_frames=True),
    tracemalloc.Filter(False, linecache.__file__, all_frames=True)
)


def tracemalloc_checker():
    _first_snap = None
    LEAK_DUMP_PATH = '/tmp/mem_log.txt'
    if os.path.exists(LEAK_DUMP_PATH):
        os.unlink(LEAK_DUMP_PATH)

    while True:
        time.sleep(60)
        start = time.time()

        if _first_snap is None:
            gc.collect()
            _first_snap = tracemalloc.take_snapshot().filter_traces(_TRACE_FILTERS)
            continue

        # print memory diff against the first snapshot
        gc.collect()
        snap = tracemalloc.take_snapshot().filter_traces(_TRACE_FILTERS)
        top_stats = snap.compare_to(_first_snap, 'traceback')
        top_stats = sorted(top_stats, key=lambda x: x.size, reverse=True)

        with open(LEAK_DUMP_PATH, 'w+') as f:
            f = sys.stdout  # NOTE: debug override; delete this line to actually write to LEAK_DUMP_PATH
            f.write('===============================' + os.linesep)
            f.write("[Top 20 differences elapsed: {}]".format(round(time.time() - start)) + os.linesep)
            for stat in top_stats[:20]:
                f.write(str(stat) + os.linesep)
                for line in stat.traceback.format():
                    f.write('\t' + str(line) + os.linesep)
            f.write('===============================' + os.linesep)
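
# This reporter is opt-in: it only runs if the leak_thread lines in main() are
# uncommented. The tracemalloc.start(15) call there sets how many stack frames
# each allocation records, which is what makes the 'traceback' comparison useful.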


def periodic_gc():
    while True:
        time.sleep(5)
        gc.collect()
        # print(get_max_rss())


def wait_server_up():
    while True:
        try:
            do_one_req()
            return
        except requests.exceptions.ConnectionError:
            # NOTE: do_one_req() catches its own exceptions, so this handler is
            # effectively dead code and the first call returns regardless
            pass


if SERVER_TYPE == ServerType.LOCAL_HTTPS:
    def ssl_server():
        import http.server, ssl
        setproctitle.setproctitle('server')

        class SimpleHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
            protocol_version = "HTTP/1.1"

            def do_HEAD(self):
                self.send_response_only(http.HTTPStatus.MOVED_PERMANENTLY, "Moved Permanently")
                self.send_header("Content-Length", 0)
                self.send_header("Location", URL)
                self.send_header("Connection", "close")
                self.end_headers()

        server_address = (HOSTNAME, SSL_PORT)
        httpd = http.server.HTTPServer(server_address, SimpleHTTPRequestHandler)
        httpd.socket = ssl.wrap_socket(httpd.socket,
                                       server_side=True,
                                       keyfile=SERVER_KEY_PATH,
                                       certfile=SERVER_CERT_PATH,
                                       ssl_version=ssl.PROTOCOL_TLSv1)
        httpd.serve_forever()

    def create_certs():
        # you'll need this fix for py3: git+git://github.com/thehesiod/pyca.git@fix-py3#egg=calib
        from calib.manager import CertificateManager

        # Create certs
        if not os.path.exists(CERT_ROOT):
            os.mkdir(CERT_ROOT)

        mgr = CertificateManager(CERT_ROOT)
        mgr.init([HOSTNAME])

        print("Creating Root Cert, enter {} for commonName".format(HOSTNAME))
        mgr.createRootCertificate(noPass=True, keyLength=2048)
        print("Creating Server Cert, enter {} for commonName".format(HOSTNAME))
        mgr.createServerCertificate(HOSTNAME)

        with open(COMBINED_CERT_PATH, 'w') as proxy_f:
            for fp in COMBINED_CERT_FILES:
                with open(fp, 'r') as f:
                    proxy_f.write(f.read() + os.linesep)


def main():
    logging.basicConfig(level=logging.WARNING)
    setproctitle.setproctitle('leaker')
    threads = [threading.Thread(target=periodic_gc, daemon=True)]

    if SERVER_TYPE in {ServerType.LOCAL_HTTPS}:
        create_certs()
        server_proc = multiprocessing.Process(target=ssl_server)
        server_proc.start()
        wait_server_up()

    mem_proc = multiprocessing.Process(target=memleak_checker, args=[os.getpid()])
    mem_proc.start()

    # uncomment to enable the tracemalloc diff reporter
    # leak_thread = threading.Thread(target=tracemalloc_checker, daemon=True)
    # tracemalloc.start(15)
    # leak_thread.start()

    threads += [threading.Thread(target=do_reqs, daemon=True) for _ in range(num_fetchers)]
    for t in threads:
        if not t.is_alive():
            t.start()

    for t in threads:
        t.join()


if __name__ == '__main__':
    main()
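
# Rough usage, assuming Python 3.5+ and a Unix-ish OS (psutil.Process.num_fds()
# is not available on Windows):
#   pip install psutil setproctitle          # plus requests or aiohttp per CLIENT_TYPE
#   python3 <this file>
# Pick the scenario by editing SERVER_TYPE / CLIENT_TYPE near the top; the
# ServerType.LOCAL_HTTPS path additionally needs the patched pyca noted in create_certs().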