|
""" |
|
Gist assert_queries_model_and_num.py |
|
https://gist.github.com/yardensachs/62b563a41e2761473e6c857d348eca6e |
|
|
|
Django TestCase assert method for counting executed queries with a specified model. |
|
|
|
Built using example from the django assertNumQueries method. |
|
|
|
Usage: |
|
with self.assertModelAndNumQueries([ |
|
(<model>, <model_query_count>), |
|
], |
|
<total_query_count>): |
|
|
|
Example: |
|
    # In this example you could use this check on an endpoint like "GET /api/blogs/1/posts/1".
|
# So the assert will make sure that User and Post models were each queried once and that |
|
# all in all there were only 2 queries executed (including the model queries and any other queries). |
|
|
|
with self.assertModelAndNumQueries([(User, 1), (Post, 1)], 2): |
|
        self.client.get('/api/blogs/1/posts/1')
|
""" |
|
|
|
import re |
|
|
|
import sqlparse |
|
from django.db import DEFAULT_DB_ALIAS, connections |
|
from django.test import TestCase |
|
from django.test.utils import CaptureQueriesContext |
|
|
|
|
|
class _AssertQueriesModelAndNumContext(CaptureQueriesContext):
    """Context manager that checks per-model and total query counts on exit.

    Queries executed on ``connection`` inside the ``with`` block are captured
    by Django's ``CaptureQueriesContext``. On a clean exit:

    * if ``total_num`` is not None, the number of captured queries must equal
      it (checked with ``test_case.assertEqual``);
    * for each ``(model, expected_num)`` pair, the queries whose first
      ``FROM <table>`` names ``model._meta.db_table`` are counted, and every
      mismatch is reported together in one ``test_case.failureException``.

    Queries with no ``FROM `` clause (e.g. ``INSERT``, ``UPDATE``,
    ``SAVEPOINT``) count toward the total but are not attributed to a model.
    """

    # Table name after the first "FROM ", optionally wrapped in double quotes,
    # single quotes or backticks to tolerate different identifier quoting.
    TABLE_NAME_PATTERN = re.compile(r'FROM ["\'\`]?([a-zA-Z\d\_]+)["\'\`]?')

    def __init__(self, test_case, models_and_num, total_num, connection):
        """
        :param test_case: TestCase whose ``assertEqual`` / ``failureException``
            are used to report failures.
        :param models_and_num: iterable of ``(model, expected_query_count)`` pairs.
        :param total_num: expected total query count, or None to skip that check.
        :param connection: database connection whose queries are captured.
        """
        self.test_case = test_case
        self.total_num = total_num
        # Materialise once: the pairs are iterated several times in __exit__,
        # and a generator would be empty after the first pass (silently passing).
        self.models_and_num = list(models_and_num)
        super().__init__(connection)

    def _format_captured_queries(self):
        """Return all captured SQL, numbered from 1 and pretty-printed by sqlparse."""
        return '\n\n'.join(
            '%d. %s' % (i, sqlparse.format(query['sql'], reindent=True, keyword_case='upper'))
            for i, query in enumerate(self.captured_queries, start=1)
        )

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Let an exception from the block propagate instead of masking it
        # with a query-count failure.
        if exc_type is not None:
            return

        if self.total_num is not None:
            executed = len(self)
            self.test_case.assertEqual(
                executed, self.total_num,
                "%d queries executed, %d expected\nCaptured queries were:\n%s" % (
                    executed, self.total_num, self._format_captured_queries(),
                ),
            )

        # One counter per distinct model, so a model listed twice is not
        # counted twice for the same query.
        model_table_count = {model: 0 for model, _ in self.models_and_num}

        for query in self.captured_queries:
            match = self.TABLE_NAME_PATTERN.search(query['sql'])
            if match is None:
                # No "FROM " clause (INSERT/UPDATE/SAVEPOINT, ...): not a model read.
                continue
            table_name = match.group(1)
            # Several models may share one table (e.g. proxy models), so check all.
            for model in model_table_count:
                if table_name == model._meta.db_table:
                    model_table_count[model] += 1

        mismatches = [
            '{} queries executed, {} expected using {} model'.format(
                model_table_count[model], expected_num, model.__name__)
            for model, expected_num in self.models_and_num
            if model_table_count[model] != expected_num
        ]

        if mismatches:
            raise self.test_case.failureException(
                '\n'.join(mismatches) + '\n\n' + self._format_captured_queries()
            )
|
|
|
|
|
class SubTestCase(TestCase):
    """Django ``TestCase`` extended with :meth:`assertModelAndNumQueries`."""

    def assertModelAndNumQueries(self, models_and_num, total_num=None, func=None, *args, using=DEFAULT_DB_ALIAS, **kwargs):
        """Assert how many queries hit each given model, and optionally in total.

        Follows the calling convention of Django's ``assertNumQueries``.
        Without ``func``, it returns a context manager to use in a ``with``
        statement. With ``func``, it runs ``func(*args, **kwargs)`` inside
        that context manager and returns None.

        :param models_and_num: ``(model, expected_query_count)`` pairs.
        :param total_num: expected total query count, or None to skip that check.
        :param func: optional callable to execute under the assertion.
        :param using: alias of the database connection to monitor.
        """
        checker = _AssertQueriesModelAndNumContext(
            self, models_and_num, total_num, connections[using],
        )
        if func is not None:
            with checker:
                func(*args, **kwargs)
            return None
        return checker