Created
November 15, 2017 16:16
-
-
Save jesseops/d7214c13861b762f856c9232b5b63482 to your computer and use it in GitHub Desktop.
Global requests/socket timeouts + retries
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Monkey-Patch Requests for automatic timeouts and retries | |
Retries for rrDNS, timeouts because you should always do timeouts. | |
If you want these features, just do ```from $MYPACKAGE.requests import requests```. | |
""""" | |
import requests | |
# Seconds; applied at import time both to Session.request's `timeout`
# keyword and to socket.setdefaulttimeout().
DEFAULT_REQUEST_TIMEOUT = 5
# Total retry budget (urllib3 `Retry(total=...)`) given to HTTPAdapters
# created after this module is imported.
DEFAULT_REQUEST_RETRIES = 2
def _enable_retries(max_retries): | |
""" | |
Decorator for `requests.adapters.HTTPAdapter.__init__`. | |
Enables retries in the requests module. | |
Due to behavior I found when trying to simply set the default | |
argument for retries in the HTTPAdapter.__init__ method (hitting | |
max recursion depth), we simply wrap the __init__ method & | |
manually set the Retry object on the HTTPAdapter class. | |
""" | |
def outer(adapter_init): | |
retry_obj = requests.packages.urllib3.util.retry.Retry(total=max_retries, backoff_factor=0.1) | |
def inner(self, *args, **kwargs): | |
adapter_init(self, *args, **kwargs) | |
# Self is the adapter instance passed in above | |
self.max_retries = retry_obj | |
return inner | |
return outer | |
def set_request_timeout(timeout):
    """
    Set the global socket timeout and the default requests timeout.

    Side effects (process-wide):
      * `socket.setdefaulttimeout(timeout)` affects every new socket.
      * `requests.sessions.Session.request` is replaced by a
        `partialmethod` that supplies `timeout=timeout`. A `timeout`
        keyword passed explicitly by the caller still takes precedence.

    Calling this repeatedly re-wraps the original, unpatched
    `Session.request` each time, so patches do not stack.

    :param timeout: timeout in seconds (or None for no timeout).
    """
    import socket  # keep reference local
    from functools import partialmethod

    session_cls = requests.sessions.Session
    current = vars(session_cls).get('request', session_cls.request)
    # A previous call stored the pristine method on its partialmethod; fall
    # back to the current attribute when Session.request is still unpatched.
    original = getattr(current, '_untimed_request', current)

    socket.setdefaulttimeout(timeout)
    patched = partialmethod(original, timeout=timeout)
    patched._untimed_request = original
    session_cls.request = patched
def set_request_retries(retries):
    """
    Set the default max number of retries for requests (the package).

    Wraps `requests.adapters.HTTPAdapter.__init__` so every adapter created
    afterwards gets a `Retry(total=retries)` policy (unless the caller passes
    its own `max_retries`). Adapters that already exist are not modified.

    Calling this repeatedly always wraps the original, unpatched `__init__`,
    so wrappers do not stack.

    :param retries: total retry budget for new adapters.
    """
    adapter_cls = requests.adapters.HTTPAdapter
    current_init = adapter_cls.__init__
    # A previous call recorded the pristine __init__ on its wrapper.
    original_init = getattr(current_init, '_pre_retry_init', current_init)

    # Get func with reference to max_retries `Retry` object, then wrap the
    # original __init__ with it.
    new_init = _enable_retries(retries)(original_init)
    new_init._pre_retry_init = original_init

    # Apply wrapper __init__ back to HTTPAdapter
    adapter_cls.__init__ = new_init
# Apply the defaults as soon as this module is imported, so callers get
# timeouts and retries without having to remember to opt in.
# For finer control, drop these two calls and invoke set_request_timeout()
# / set_request_retries() yourself when your program starts.
set_request_timeout(DEFAULT_REQUEST_TIMEOUT)
set_request_retries(DEFAULT_REQUEST_RETRIES)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest # todo: swap these over to py.test | |
from time import time | |
from .requests import requests, DEFAULT_REQUEST_TIMEOUT, DEFAULT_REQUEST_RETRIES | |
class test_Requests(unittest.TestCase):
    """
    Tests for the global timeout/retry monkey-patches.

    test_003 and test_004 make real HTTP requests to httpbin.org, so they
    need network access and are timing-sensitive.
    """

    def tearDown(self):
        # The patches are process-global. Restore the import-time defaults
        # so every test starts from the same state regardless of run order.
        from .requests import set_request_retries, set_request_timeout
        set_request_timeout(DEFAULT_REQUEST_TIMEOUT)
        set_request_retries(DEFAULT_REQUEST_RETRIES)

    def test_001_set_request_timeout(self):
        from .requests import set_request_timeout
        self.assertEqual(requests.Session().request.keywords['timeout'], DEFAULT_REQUEST_TIMEOUT)
        set_request_timeout(1)
        self.assertEqual(requests.Session().request.keywords['timeout'], 1)

    def test_002_set_request_retries(self):
        from .requests import set_request_retries
        for adapter in requests.Session().adapters.values():
            self.assertEqual(adapter.max_retries.total, DEFAULT_REQUEST_RETRIES)
        set_request_retries(0)
        for adapter in requests.Session().adapters.values():
            self.assertEqual(adapter.max_retries.total, 0)

    def test_003_timeout(self):
        from .requests import set_request_retries, set_request_timeout
        set_request_retries(0)
        set_request_timeout(1)

        # 1s global timeout, no retries: a 5s-delay endpoint must fail fast.
        # NOTE(review): relies on requests surfacing the exhausted Retry as
        # ConnectionError; confirm when upgrading requests/urllib3.
        start = time()
        with self.assertRaises(requests.ConnectionError):
            requests.get('https://httpbin.org/delay/5')
        self.assertLess((time() - start), 2)

        # An explicit per-call timeout overrides the global default.
        start = time()
        requests.get('https://httpbin.org/delay/2', timeout=4)
        elapsed = time() - start
        self.assertLess(elapsed, 4)
        self.assertGreater(elapsed, 2)

    def test_004_retries(self):
        from .requests import set_request_retries, set_request_timeout
        set_request_timeout(0.5)
        set_request_retries(2)

        # Up to three 0.5s attempts plus small backoff should finish in < 3s.
        start = time()
        with self.assertRaises(requests.ConnectionError):
            requests.get('https://httpbin.org/delay/5')
        self.assertLess((time() - start), 3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment