-
-
Save thehesiod/ef79dd77e2df7a0a7893dfea6325d30a to your computer and use it in GitHub Desktop.
requests leak checker
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import multiprocessing | |
import setproctitle | |
import os | |
import threading | |
import tracemalloc | |
import linecache | |
import ssl | |
import traceback | |
from urllib.parse import urlparse | |
import gc | |
import logging | |
import time | |
import psutil | |
import sys | |
import socket | |
from enum import IntEnum | |
# TLS material layout under CERT_ROOT; key/cert files are named after this host.
HOSTNAME = socket.gethostname()
CERT_ROOT = '/tmp/certs'
CA_CERT_PATH = os.path.join(CERT_ROOT, 'certs/ca.cert.pem')
CA_KEY_PATH = os.path.join(CERT_ROOT, 'private/ca.key.pem')
SERVER_CERT_PATH = os.path.join(CERT_ROOT, f'certs/{HOSTNAME}.cert.pem')
SERVER_KEY_PATH = os.path.join(CERT_ROOT, f'private/{HOSTNAME}.key.pem')

# HA_PROXY cert ordering: server_key, server_cert, ca_cert
COMBINED_CERT_PATH = os.path.join(CERT_ROOT, "haproxy.pem")
COMBINED_CERT_FILES = [SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH]

SSL_PORT = 4443
class ServerType(IntEnum):
    """Which kind of HTTPS endpoint the fetchers will hit."""
    EXTERNAL_HTTPS = 0   # a public site (google.com)
    LOCAL_HTTPS = 2      # local http.server wrapped in TLS; value 1 unused
    RAW_SSL_SOCKET = 3
class ClientType(IntEnum):
    """Which HTTP client implementation issues the requests."""
    REQUESTS = 0  # requests library
    RAW = 1       # hand-rolled HEAD over an ssl-wrapped socket
    AIOHTTP = 2   # aiohttp via a per-thread event loop
# Active configuration: raw SSL sockets against an external HTTPS endpoint.
SERVER_TYPE = ServerType.EXTERNAL_HTTPS
CLIENT_TYPE = ClientType.RAW

# Import only the client library the selected client actually needs.
if CLIENT_TYPE == ClientType.AIOHTTP:
    import aiohttp, asyncio
elif CLIENT_TYPE == ClientType.REQUESTS:
    import requests

if SERVER_TYPE == ServerType.EXTERNAL_HTTPS:
    URL = 'https://google.com'
    VERIFY = None  # none needed
    num_fetchers = 2
elif SERVER_TYPE == ServerType.LOCAL_HTTPS:
    URL = 'https://{}:{}/'.format(HOSTNAME, SSL_PORT)
    VERIFY = CA_CERT_PATH  # trust our own CA for the local server
    num_fetchers = 1
else:
    # Was `assert False`, which silently disappears under `python -O`;
    # raise so an unsupported SERVER_TYPE always fails loudly at import time.
    raise ValueError('unsupported SERVER_TYPE: {!r}'.format(SERVER_TYPE))
# Shared request counter: bumped by the fetcher after every request and read
# by the sampler process as the x-axis value for each profile line.
_req_num = multiprocessing.Value('i', 0)
def memleak_checker(pid):
    """Sample RSS / fd-count / connection-count of process *pid* once a second.

    Each sample is one mprofile-style ``MEM <value> <request#>`` line, written
    to three timestamped .dat files (memory, fds, connections) so growth can
    be plotted against the number of requests issued so far.
    """
    global _req_num
    setproctitle.setproctitle('leak reporter')
    target = psutil.Process(pid)

    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    out_paths = ["mprofile_%s.dat" % stamp,
                 "fprofile_%s.dat" % stamp,
                 "nprofile_%s.dat" % stamp]
    for path in out_paths:
        if os.path.exists(path):
            os.unlink(path)
    mem_f, fd_f, conn_f = (open(path, 'w') for path in out_paths)

    while True:
        if _req_num.value < 30:
            time.sleep(1)  # warm up
        # Sampled by hand instead of memory_profiler.memory_usage() so there is
        # one timestamp per request; memory_usage also keeps accumulating
        # results internally for some reason.
        rss_mib = target.memory_info()[0] / float(2 ** 20)
        mem_f.write("MEM {0:.6f} {1:.4f}\n".format(rss_mib, _req_num.value))
        fd_f.write("MEM {0:.6f} {1:.4f}\n".format(target.num_fds(), _req_num.value))
        conn_f.write("MEM {0:.6f} {1:.4f}\n".format(len(target.connections()), _req_num.value))
        for out_f in (mem_f, fd_f, conn_f):
            out_f.flush()
        time.sleep(1)
def do_one_req():
    """Issue one HEAD request to URL using the configured CLIENT_TYPE.

    Failures are printed and swallowed (this is a best-effort probe loop);
    the shared request counter is incremented whether or not the request
    succeeded. Returns the response object for the REQUESTS client, else None.
    """
    global _req_num
    try:
        if CLIENT_TYPE == ClientType.REQUESTS:
            response = requests.head(URL, verify=VERIFY, allow_redirects=False, timeout=1)
            # Close redirect-history responses too so their sockets are released.
            for r in response.history:
                r.close()
            response.close()
            return response
        elif CLIENT_TYPE == ClientType.RAW:
            parts = urlparse(URL)
            # An SSLContext holds no OS resources, so it needs no close.
            context = ssl.create_default_context()
            # `with` guarantees the fd is closed even when connect/send raises
            # (the original only closed on the success path and leaked on error).
            # shutdown(SHUT_RDWR) vs plain close() made no observed difference.
            with context.wrap_socket(socket.socket(socket.AF_INET), server_hostname=parts.hostname) as conn:
                conn.connect((parts.hostname, parts.port or 443))
                conn.sendall("HEAD {} HTTP/1.1\r\nHost: {}\r\nAccept: */*\r\nConnection: keep-alive\r\nAccept-Encoding: gzip, deflate\r\nUser-Agent: python\r\n\r\n".format(parts.path or '/', parts.hostname).encode('utf-8'))
                data = conn.recv(10000)  # drain status line + headers; content unused
        elif CLIENT_TYPE == ClientType.AIOHTTP:
            loop = asyncio.get_event_loop()

            async def doit():
                async with aiohttp.ClientSession() as session:
                    async with session.head(URL, allow_redirects=False) as response:
                        pass

            loop.run_until_complete(doit())
        else:
            raise AssertionError('unknown CLIENT_TYPE: {!r}'.format(CLIENT_TYPE))
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit; Exception keeps best-effort behavior without doing that.
        traceback.print_exc()
    finally:
        _req_num.value += 1
def do_reqs():
    """Issue requests sequentially forever, printing a dot per request."""
    if CLIENT_TYPE == ClientType.AIOHTTP:
        # Each fetcher thread needs its own asyncio event loop.
        asyncio.set_event_loop(asyncio.new_event_loop())
    while True:
        do_one_req()
        print('.', end='', flush=True)
# Exclude allocations made by the import machinery and by tracemalloc /
# linecache themselves, so snapshot diffs only show this program's own growth.
_TRACE_FILTERS = (
    tracemalloc.Filter(False, "<frozen importlib._bootstrap>", all_frames=True),
    tracemalloc.Filter(False, tracemalloc.__file__, all_frames=True),
    tracemalloc.Filter(False, linecache.__file__, all_frames=True)
)
def tracemalloc_checker():
    """Every 60s, diff the current tracemalloc snapshot against the first one
    taken and dump the top-20 growing tracebacks to LEAK_DUMP_PATH."""
    _first_snap = None
    LEAK_DUMP_PATH = '/tmp/mem_log.txt'
    if os.path.exists(LEAK_DUMP_PATH):
        os.unlink(LEAK_DUMP_PATH)
    while True:
        time.sleep(60)
        start = time.time()
        if _first_snap is None:
            # Collect first so the baseline excludes collectable garbage.
            gc.collect()
            _first_snap = tracemalloc.take_snapshot().filter_traces(_TRACE_FILTERS)
            continue
        # print memory diff
        gc.collect()
        snap = tracemalloc.take_snapshot().filter_traces(_TRACE_FILTERS)
        top_stats = snap.compare_to(_first_snap, 'traceback')
        top_stats = sorted(top_stats, key=lambda x: x.size, reverse=True)
        # BUG FIX: the original rebound `f = sys.stdout` right after opening,
        # so LEAK_DUMP_PATH was truncated each cycle but never written.
        # Write the report to the file as intended ('w+' keeps only the
        # latest report each cycle).
        with open(LEAK_DUMP_PATH, 'w+') as f:
            f.write('===============================' + os.linesep)
            f.write("[Top 20 differences elapsed: {}]".format(round(time.time() - start)) + os.linesep)
            for stat in top_stats[:20]:
                f.write(str(stat) + os.linesep)
                for line in stat.traceback.format():
                    f.write('\t' + str(line) + os.linesep)
            f.write('===============================' + os.linesep)
def periodic_gc():
    """Force a full garbage collection every 5 seconds, forever (daemon thread)."""
    while True:
        time.sleep(5)
        gc.collect()
        # print(get_max_rss())
def wait_server_up():
    """Block until a request to the (local) server succeeds."""
    # NOTE(review): do_one_req() catches and prints all exceptions internally,
    # so the ConnectionError handler below looks unreachable — this loop
    # appears to return after the first attempt whether or not the server is
    # actually up. Confirm against do_one_req.
    # NOTE(review): `requests` is only imported when CLIENT_TYPE == REQUESTS;
    # if the except line were ever reached with another client type it would
    # raise NameError.
    while True:
        try:
            do_one_req()
            return
        except requests.exceptions.ConnectionError:
            pass
if SERVER_TYPE == ServerType.LOCAL_HTTPS:
    def ssl_server():
        """Minimal local HTTPS server: answers every HEAD with a 301 to URL
        and closes the connection."""
        import http.server, ssl
        setproctitle.setproctitle('server')

        class SimpleHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
            protocol_version = "HTTP/1.1"

            def do_HEAD(self):
                self.send_response_only(http.HTTPStatus.MOVED_PERMANENTLY, "Moved Permanently")
                self.send_header("Content-Length", 0)
                self.send_header("Location", URL)
                self.send_header("Connection", "close")
                self.end_headers()

        server_address = (HOSTNAME, SSL_PORT)
        httpd = http.server.HTTPServer(server_address, SimpleHTTPRequestHandler)
        # ssl.wrap_socket() was deprecated in 3.7 and removed in 3.12; use an
        # SSLContext instead. TLSv1 is pinned (min == max) to match the
        # original ssl_version=PROTOCOL_TLSv1 behavior exactly.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.minimum_version = ssl.TLSVersion.TLSv1
        context.maximum_version = ssl.TLSVersion.TLSv1
        context.load_cert_chain(certfile=SERVER_CERT_PATH, keyfile=SERVER_KEY_PATH)
        httpd.socket = context.wrap_socket(httpd.socket, server_side=True)
        httpd.serve_forever()
def create_certs():
    """Generate a root CA plus a server certificate for this host, then
    concatenate key/cert/CA into the single PEM file HAProxy expects."""
    # py3 fix needed: git+git://github.com/thehesiod/pyca.git@fix-py3#egg=calib
    from calib.manager import CertificateManager

    # Create certs
    if not os.path.exists(CERT_ROOT):
        os.mkdir(CERT_ROOT)

    manager = CertificateManager(CERT_ROOT)
    manager.init([HOSTNAME])
    print("Creating Root Cert, enter {} for commonName".format(HOSTNAME))
    manager.createRootCertificate(noPass=True, keyLength=2048)
    print("Creating Server Cert, enter {} for commonName".format(HOSTNAME))
    manager.createServerCertificate(HOSTNAME)

    # Combined file order is fixed by COMBINED_CERT_FILES: key, cert, CA.
    with open(COMBINED_CERT_PATH, 'w') as combined:
        for part_path in COMBINED_CERT_FILES:
            with open(part_path, 'r') as part:
                combined.write(part.read() + os.linesep)
def main():
    """Start the sampler process (plus the local TLS server when configured)
    and run the fetcher / GC threads until killed."""
    logging.basicConfig(level=logging.WARNING)
    setproctitle.setproctitle('leaker')

    threads = [threading.Thread(target=periodic_gc, daemon=True)]

    if SERVER_TYPE in {ServerType.LOCAL_HTTPS}:
        create_certs()
        server_proc = multiprocessing.Process(target=ssl_server)
        server_proc.start()
        wait_server_up()

    # Sampler runs in its own process so its allocations don't pollute ours.
    mem_proc = multiprocessing.Process(target=memleak_checker, args=[os.getpid()])
    mem_proc.start()

    # leak_thread = threading.Thread(target=tracemalloc_checker, daemon=True)
    # tracemalloc.start(15)
    # leak_thread.start()

    threads += [threading.Thread(target=do_reqs, daemon=True)
                for _ in range(num_fetchers)]
    for worker in threads:
        if not worker.is_alive():
            worker.start()
    for worker in threads:
        worker.join()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment