@beeftornado
Last active April 21, 2020 22:55
Python: simulate broken DNS for unit tests
# # Broken Socket
""" A broken socket implementation.

Credits:

* [A. Jesse Jiryu Davis](http://emptysqua.re/blog/undoing-gevents-monkey-patching/)

Monkey patches the DNS lookup functions of the built-in socket module
(getaddrinfo and gethostbyname) so that every lookup raises a randomly chosen
socket exception (socket.error, socket.gaierror, socket.timeout or
socket.herror). Useful for running unit tests to validate behavior when
connections fail.

**Usage**:

If you have a directory structure like this for your tests:

.. sourcecode::

    src/
        ..project files..
    tests/
        __init__.py
        test_pkg.py
        utils/
            __init__.py
            broken_socket.py (this file)

Then in the test code you can use the broken socket implementation like this:

.. sourcecode:: python

    import unittest

    from nose.tools import timed

    from .utils import broken_socket

    # new-style: plain nose test class
    class TestMyApp(object):
        def __init__(self):
            super(TestMyApp, self).__init__()

        @timed(10)
        def test_third_party_api_call(self):
            old_dns_attrs = broken_socket.patch_dns()
            try:
                pass  # do test ...
            finally:
                broken_socket.unpatch_dns(old_dns_attrs)

    # old-style: unittest.TestCase subclass
    class Test(unittest.TestCase):
        def test(self):
            old_dns_attrs = broken_socket.patch_dns()
            try:
                pass  # do test ...
            finally:
                broken_socket.unpatch_dns(old_dns_attrs)
"""
import random
import socket


def patch_dns():
    """ Patches the socket module's DNS lookup functions so that they fail.

    Returns the original attributes so that unpatch_dns() can restore them.
    """
    old_attrs = {}
    old_attrs['getaddrinfo'] = socket.getaddrinfo
    socket.getaddrinfo = getaddrinfo
    old_attrs['gethostbyname'] = socket.gethostbyname
    socket.gethostbyname = gethostbyname
    return old_attrs


def unpatch_dns(old_attrs):
    """ Take output of patch_dns() and undo patching. """
    for attr, value in old_attrs.items():
        setattr(socket, attr, value)


def raise_random_socket_error():
    """ Raises a randomly chosen socket exception. """
    errs = [
        socket.error,
        socket.gaierror,
        socket.timeout,
        socket.herror,
    ]
    raise random.choice(errs)()


def gethostbyname(*args, **kwargs):
    """ Broken version of socket.gethostbyname: always raises. """
    raise_random_socket_error()


def getaddrinfo(*args, **kwargs):
    """ Broken version of socket.getaddrinfo: always raises. """
    raise_random_socket_error()
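The patch_dns()/unpatch_dns() pair relies on every test remembering its `try`/`finally` block. A minimal sketch, assuming you are free to add a helper next to these functions: wrap the same save-and-restore logic in a context manager. The name `broken_dns` is hypothetical, and to keep the sketch self-contained it uses a stand-in `_failing_lookup` that always raises `socket.gaierror` instead of a random error.

```python
import contextlib
import socket


def _failing_lookup(*args, **kwargs):
    # Hypothetical stand-in for the broken getaddrinfo/gethostbyname above;
    # raises gaierror every time instead of picking a random socket error.
    raise socket.gaierror("simulated DNS failure")


@contextlib.contextmanager
def broken_dns():
    """Hypothetical helper: break DNS lookups for the duration of a with-block."""
    names = ("getaddrinfo", "gethostbyname")
    # Save the real functions before replacing them, as patch_dns() does.
    old_attrs = {name: getattr(socket, name) for name in names}
    for name in names:
        setattr(socket, name, _failing_lookup)
    try:
        yield
    finally:
        # Restore the originals even if the with-body raised,
        # as unpatch_dns() does in the try/finally usage pattern.
        for name, value in old_attrs.items():
            setattr(socket, name, value)


if __name__ == "__main__":
    with broken_dns():
        try:
            socket.getaddrinfo("example.com", 80)
        except OSError as exc:
            print("lookup failed:", type(exc).__name__)
```

A test body then reads `with broken_dns(): ...`. Like patch_dns(), this only affects code that looks the function up as `socket.getaddrinfo` at call time; a module that ran `from socket import getaddrinfo` before patching keeps a reference to the real function.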
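If adding a helper module is not desirable, the standard library's `unittest.mock.patch` can do the same replace-and-restore around a single test. This is a different technique from the hand-rolled monkey patching in this gist, shown as a sketch: setting `side_effect` to an exception class makes the mock raise it on every call, and `patch` puts the original attribute back when the decorated test returns. The class and method names are hypothetical.

```python
import socket
import unittest
from unittest import mock


class TestDnsFailure(unittest.TestCase):
    # Each decorator replaces one socket attribute with a Mock for the
    # duration of the test and restores the real function afterwards.
    # Decorators apply bottom-up, so the getaddrinfo mock is the first argument.
    @mock.patch("socket.gethostbyname", side_effect=socket.herror)
    @mock.patch("socket.getaddrinfo", side_effect=socket.gaierror)
    def test_lookup_fails(self, mock_getaddrinfo, mock_gethostbyname):
        with self.assertRaises(socket.gaierror):
            socket.getaddrinfo("example.com", 80)
        with self.assertRaises(socket.herror):
            socket.gethostbyname("example.com")
        # The mock also records how the code under test called it.
        mock_getaddrinfo.assert_called_once_with("example.com", 80)
        mock_gethostbyname.assert_called_once_with("example.com")
```

Run it with `python -m unittest` as usual. Unlike raise_random_socket_error(), each patched function raises one fixed exception type, which makes assertions deterministic; choose whichever matches the failure you want to exercise.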