Skip to content

Instantly share code, notes, and snippets.

@yardensachs
Last active October 8, 2018 10:55
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save yardensachs/62b563a41e2761473e6c857d348eca6e to your computer and use it in GitHub Desktop.
Save yardensachs/62b563a41e2761473e6c857d348eca6e to your computer and use it in GitHub Desktop.
Django TestCase assert method for counting executed queries with a specified model.

Specified model Django TestCase assertion

Django TestCase assert method for counting executed queries with a specified model.

Built using example from the django assertNumQueries method.

Usage:

with self.assertModelAndNumQueries([
                                    (<model>, <model_query_count>),
                                    ],
                                    <total_query_count>):

Example:

In this example you could use this check on an endpoint like "GET /api/blogs/1/posts/1". The assert makes sure that the User and Post models were each queried once and that, in total, only 2 queries were executed (including the model queries and any other queries).

with self.assertModelAndNumQueries([(User, 1), (Post, 1)], 2):
    Client.get('/api/blogs/1/posts/1')
"""
Gist assert_queries_model_and_num.py
https://gist.github.com/yardensachs/62b563a41e2761473e6c857d348eca6e
Django TestCase assert method for counting executed queries with a specified model.
Built using example from the django assertNumQueries method.
Usage:
with self.assertModelAndNumQueries([
(<model>, <model_query_count>),
],
<total_query_count>):
Example:
# In this example you could use this check on an endpoint like "GET /api/blogs/1/posts/1".
# So the assert will make sure that User and Post models were each queried once and that
# all in all there were only 2 queries executed (including the model queries and any other queries).
with self.assertModelAndNumQueries([(User, 1), (Post, 1)], 2):
Client.get('/api/blogs/1/posts/1')
"""
import re
import sqlparse
from django.db import DEFAULT_DB_ALIAS, connections
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
class _AssertQueriesModelAndNumContext(CaptureQueriesContext):
    """Capture queries and, on exit, check total and per-model query counts.

    On a clean exit (no exception raised inside the ``with`` body) two
    checks run:

    1. If ``total_num`` is not ``None``, the number of captured queries must
       equal it.
    2. For every ``(model, expected)`` pair in ``models_and_num``, the number
       of captured queries whose first ``FROM`` table equals
       ``model._meta.db_table`` must equal ``expected``.

    Args:
        test_case: The test case used for ``assertEqual`` and whose
            ``failureException`` is raised on a per-model mismatch.
        models_and_num: Iterable of ``(model, expected_query_count)`` pairs.
        total_num: Expected total number of queries, or ``None`` to skip the
            total check.
        connection: Database connection whose queries are captured.
    """

    # Captures the identifier that follows the first ``FROM`` keyword,
    # optionally wrapped in double quotes, single quotes or backticks.
    TABLE_NAME_PATTERN = re.compile(r'FROM ["\'\`]?([a-zA-Z\d\_]+)["\'\`]?')

    def __init__(self, test_case, models_and_num, total_num, connection):
        self.test_case = test_case
        self.total_num = total_num
        self.models_and_num = models_and_num
        super().__init__(connection)

    def _format_captured_queries(self):
        """Return the captured queries as a numbered, pretty-printed listing."""
        return '\n\n'.join(
            '%d. %s' % (
                index,
                sqlparse.format(query['sql'], reindent=True, keyword_case='upper'),
            )
            for index, query in enumerate(self.captured_queries, start=1)
        )

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop capturing and run the count checks.

        Raises:
            test_case.failureException: If the total or any per-model query
                count differs from the expected value.
        """
        super().__exit__(exc_type, exc_value, traceback)
        # An exception from the ``with`` body takes precedence over count checks.
        if exc_type is not None:
            return

        if self.total_num is not None:
            executed = len(self)
            self.test_case.assertEqual(
                executed, self.total_num,
                "%d queries executed, %d expected\nCaptured queries were:\n%s" % (
                    executed, self.total_num, self._format_captured_queries(),
                ),
            )

        model_table_count = {model: 0 for model, _ in self.models_and_num}
        for query in self.captured_queries:
            match = self.TABLE_NAME_PATTERN.search(query['sql'])
            if match is None:
                # Statements without a FROM clause (INSERT, UPDATE, SAVEPOINT,
                # ...) cannot be attributed to a model by this pattern; they
                # still count towards the total checked above.
                continue
            table_name = match.group(1)
            for model, _ in self.models_and_num:
                if table_name == model._meta.db_table:
                    model_table_count[model] += 1

        mismatches = [
            '{} queries executed, {} expected using {} model'.format(
                model_table_count[model], expected_num, model.__name__,
            )
            for model, expected_num in self.models_and_num
            if expected_num != model_table_count[model]
        ]
        if mismatches:
            raise self.test_case.failureException(
                '\n'.join(mismatches) + '\n\n' + self._format_captured_queries()
            )
class SubTestCase(TestCase):
    """TestCase base class adding a per-model query-count assertion."""

    def assertModelAndNumQueries(self, models_and_num, total_num=None, func=None, *args, using=DEFAULT_DB_ALIAS, **kwargs):
        """Assert per-model (and optionally total) query counts.

        With ``func`` omitted, returns a context manager to be used in a
        ``with`` statement. With ``func`` given, calls
        ``func(*args, **kwargs)`` inside that context manager and returns
        ``None``.

        Args:
            models_and_num: Iterable of ``(model, expected_query_count)`` pairs.
            total_num: Expected total number of queries; ``None`` skips the
                total check.
            func: Optional callable to run while queries are captured.
            using: Alias of the database connection to capture queries on.
        """
        checker = _AssertQueriesModelAndNumContext(
            self, models_and_num, total_num, connections[using],
        )
        if func is not None:
            with checker:
                func(*args, **kwargs)
            return None
        return checker
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment