Skip to content

Instantly share code, notes, and snippets.

@bimusiek
Last active January 4, 2018 15:53
Show Gist options
  • Save bimusiek/d5f2a78979ea2a4fbe93558acbbff8ab to your computer and use it in GitHub Desktop.
Save bimusiek/d5f2a78979ea2a4fbe93558acbbff8ab to your computer and use it in GitHub Desktop.
Race-condition Django test case
# encoding: utf-8
from __future__ import absolute_import, unicode_literals

import logging
import os
import sys

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import connections
from django.test import TransactionTestCase, override_settings
from django.test.testcases import _StaticFilesHandler, LiveServerThread
from django.utils import six
from django.utils.decorators import classproperty
log = logging.getLogger(__name__)
class RaceConditionLiveServerTestCase(TransactionTestCase):
    """Run several live server threads at once to reproduce race conditions.

    Modelled on Django's ``LiveServerTestCase`` (the ``possible_ports``
    flavour of ``LiveServerThread``), but starts one server thread per
    configured address instead of a single one, so a test can send
    concurrent requests handled by different threads. Their base URLs are
    exposed through ``live_server_urls``.

    Addresses are read from the environment variable named by
    ``live_server_addresses_env`` as a whitespace-separated list of
    ``host:ports`` entries, where ``ports`` is a comma-separated list of
    ports or port ranges (e.g. ``localhost:8000-8010,8080``). When the
    variable is unset or empty, ``default_live_server_addresses`` is used.
    """

    static_handler = _StaticFilesHandler

    # Environment variable that overrides ``default_live_server_addresses``.
    live_server_addresses_env = 'RACECONDITIONS_LIVE_TEST_SERVER_ADDRESSES'

    # One live server thread is started per entry.
    default_live_server_addresses = (
        'localhost:8081-8179',
        'localhost:9081-9179',
        'localhost:10081-10179',
        'localhost:11081-11179',
        'localhost:12081-12179',
        'localhost:13081-13179',
    )

    # Settings overridden for the lifetime of the test class. Changing
    # settings is only for internal purposes; subclasses may set this to {}.
    live_server_settings = {
        'DEBUG': True,
        'CACHES': {
            'default': {
                "BACKEND": "django_redis.cache.RedisCache",
                "LOCATION": "redis://redis:6379/0",
                "OPTIONS": {
                    "CLIENT_CLASS": "django_redis.client.DefaultClient",
                    "SOCKET_CONNECT_TIMEOUT": 5,  # in seconds
                    "SOCKET_TIMEOUT": 5,  # in seconds
                    "IGNORE_EXCEPTIONS": False,
                },
                "KEY_PREFIX": "tests_search"
            }
        },
    }

    @classproperty
    def live_server_urls(cls):
        """Base URLs (``http://host:port``) of the started server threads."""
        return ['http://%s:%s' % (st.host, st.port) for st in cls.server_threads]

    @classmethod
    def setUpClass(cls):
        cls.server_threads = []
        # Enabled before super().setUpClass() and disabled after
        # super().tearDownClass(), so nested override_settings (e.g. a
        # class-level decorator on a subclass) unwind in LIFO order.
        cls._live_server_settings_override = override_settings(**cls.live_server_settings)
        cls._live_server_settings_override.enable()
        try:
            super(RaceConditionLiveServerTestCase, cls).setUpClass()
        except Exception:
            exc_info = sys.exc_info()
            cls._restore_live_server_settings()
            six.reraise(*exc_info)
        try:
            cls._start_server_threads()
        except Exception:
            # tearDownClass() won't be called when setUpClass() raises, so
            # clean up behind ourselves before propagating the error.
            exc_info = sys.exc_info()
            cls._tearDownAll()
            six.reraise(*exc_info)

    @classmethod
    def _get_live_server_addresses(cls):
        """Return the list of ``host:ports`` strings to start servers on."""
        # Entries are whitespace-separated because commas are already used
        # inside an entry to separate port ranges.
        specified = os.environ.get(cls.live_server_addresses_env, '').split()
        return specified or list(cls.default_live_server_addresses)

    @staticmethod
    def _parse_live_server_address(specified_address):
        """Split ``'host:8000-8010,8080'`` into ``(host, [8000, ..., 8010, 8080])``.

        Raises:
            ImproperlyConfigured: if ``specified_address`` is malformed.
        """
        try:
            host, port_ranges = specified_address.split(':')
            possible_ports = []
            for port_range in port_ranges.split(','):
                # A port range can be of either form: '8000' or '8000-8010'.
                extremes = [int(bound) for bound in port_range.split('-')]
                if len(extremes) == 1:
                    possible_ports.append(extremes[0])
                elif len(extremes) == 2:
                    possible_ports.extend(range(extremes[0], extremes[1] + 1))
                else:
                    # Explicit check instead of `assert`, which -O strips.
                    raise ValueError(port_range)
        except ValueError:
            msg = 'Invalid address ("%s") for live server.' % specified_address
            six.reraise(ImproperlyConfigured, ImproperlyConfigured(msg), sys.exc_info()[2])
        return host, possible_ports

    @staticmethod
    def _in_memory_sqlite_connections():
        """Yield the configured connections that use an in-memory SQLite DB."""
        for conn in connections.all():
            if conn.vendor == 'sqlite' and conn.is_in_memory_db(conn.settings_dict['NAME']):
                yield conn

    @classmethod
    def _start_server_threads(cls):
        """Start, and wait for, one live server thread per configured address.

        Raises ImproperlyConfigured for a malformed address (before any
        thread is started), or the error a server thread reported on startup.
        """
        # Parse every address up front so a bad entry fails before any
        # server thread has been started.
        addresses = [
            cls._parse_live_server_address(address)
            for address in cls._get_live_server_addresses()
        ]
        connections_override = {}
        for conn in cls._in_memory_sqlite_connections():
            # In-memory sqlite databases are per-connection, so pass the
            # connections to the server threads and explicitly enable
            # thread-shareability for them.
            conn.allow_thread_sharing = True
            connections_override[conn.alias] = conn
        for host, possible_ports in addresses:
            server_thread = cls._create_server_thread(host, possible_ports, connections_override)
            server_thread.daemon = True
            server_thread.start()
            # Record the thread before waiting so teardown also stops it.
            cls.server_threads.append(server_thread)
            # Wait for the live server to be ready
            server_thread.is_ready.wait()
            if server_thread.error:
                raise server_thread.error

    @classmethod
    def _create_server_thread(cls, host, possible_ports, connections_override):
        return LiveServerThread(
            host,
            possible_ports,
            cls.static_handler,
            connections_override=connections_override,
        )

    @classmethod
    def _tearDownClassInternal(cls):
        """Stop every live server thread and restore sqlite non-shareability."""
        # server_threads may be missing if setUpClass() never ran for cls.
        for server_thread in getattr(cls, 'server_threads', ()):
            # Terminate the live server's thread
            server_thread.terminate()
            server_thread.join()
        # Restore sqlite in-memory database connections' non-shareability
        for conn in cls._in_memory_sqlite_connections():
            conn.allow_thread_sharing = False

    @classmethod
    def _restore_live_server_settings(cls):
        """Disable the settings override enabled in setUpClass(), once."""
        override = getattr(cls, '_live_server_settings_override', None)
        if override is not None:
            override.disable()
            cls._live_server_settings_override = None

    @classmethod
    def _tearDownAll(cls):
        """Undo everything setUpClass() did, in reverse order."""
        try:
            cls._tearDownClassInternal()
            super(RaceConditionLiveServerTestCase, cls).tearDownClass()
        finally:
            cls._restore_live_server_settings()

    @classmethod
    def tearDownClass(cls):
        cls._tearDownAll()
# encoding: utf-8
from __future__ import absolute_import, unicode_literals
import logging
from requests_futures.sessions import FuturesSession
from ahoy.lib.tests.racecondition_testcase import RaceConditionLiveServerTestCase
from ahoy.search.apps.users.models import UserModel
log = logging.getLogger(__name__)
class TestSearchAuthRaceCondition(RaceConditionLiveServerTestCase):
    """Concurrent authenticated requests must not create duplicate users.

    The same ``Authorization`` header is sent at the same time to every live
    server thread. The race condition this guards against was producing 6
    ``UserModel`` rows instead of a single one.
    """

    def test_it(self):
        urls = self.live_server_urls
        # One FuturesSession (and so one executor) per server so that all
        # requests are in flight at the same time.
        sessions = [FuturesSession() for _ in urls]
        try:
            futures = [
                session.get(url + '/', headers={
                    'Authorization': "JWT abc"
                })
                for session, url in zip(sessions, urls)
            ]
            # Wait for all now
            all_results = [future.result() for future in futures]
        finally:
            # Release each session's connections and worker threads.
            for session in sessions:
                session.close()
        log.debug("All results: %s", all_results)
        self.assertGreater(len(all_results), 0)
        # Log every body before asserting so a failure shows all responses.
        for result in all_results:
            log.debug('Result: %s', result.text)
        for result in all_results:
            self.assertEqual(result.status_code, 200, result.text)
        self.assertEqual(UserModel.objects.count(), 1)  # Race condition was producing 6 users...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment