Skip to content

Instantly share code, notes, and snippets.

@jourdanrodrigues
Last active May 22, 2017 16:41
Show Gist options
  • Save jourdanrodrigues/61c6cac5a4a0b7dab4a2bf20908390d0 to your computer and use it in GitHub Desktop.
Save jourdanrodrigues/61c6cac5a4a0b7dab4a2bf20908390d0 to your computer and use it in GitHub Desktop.
Classes for better testing w/ Django 1.11 & DRF 3.6.
from copy import deepcopy
from django.contrib.admin import AdminSite
from django.test import TestCase
from django.utils.translation import ugettext_lazy as _
from rest_framework import status
from rest_framework.response import Response
from rest_framework.test import APITestCase
def _get_not_implemented_message(attribute, class_name):
    """
    Build the message raised when a subclass does not define a required attribute.

    The literal template is handed to the translation function *before* it is
    formatted, so the msgid is the constant template string that a translation
    catalog can contain. Formatting first would produce a different msgid for
    every attribute/class pair, and none of them could ever be translated.

    :param attribute: name of the attribute the subclass must set
    :type attribute: str
    :param class_name: name of the class that is missing the attribute
    :type class_name: str
    :rtype: str
    """
    return _('Must set the "{}" attribute on "{}"').format(attribute, class_name)
class MockRequest(object):
    """
    Bare placeholder handed to admin methods that take a ``request`` argument.

    It defines no attributes or behaviour of its own.
    """
class MockSuperUser(object):
    """Stand-in for a superuser: every permission check succeeds."""

    @staticmethod
    def has_perm(perm, obj=None):
        """
        Grant any permission.

        Accepts the same ``(perm, obj=None)`` arguments as Django's user
        ``has_perm``, so callers that pass an object don't get a ``TypeError``.
        Always returns a real ``bool`` rather than echoing ``perm`` back.

        :param perm: permission name, e.g. ``"app_label.change_model"`` (ignored)
        :type perm: str
        :param obj: optional object the permission is checked against (ignored)
        :rtype: bool
        """
        return True
class AdminTestCase(TestCase):
    """
    Base class for checking the form fields and fieldsets a ``ModelAdmin`` exposes.

    Subclasses must define ``model``, ``admin_class``, ``form_fields``,
    ``admin_fields`` and ``admin_fieldsets``, and may override
    ``admin_excludes``. This class defines no ``test_*`` methods of its own,
    so call the runners from test methods in the subclass::

        def test_fields(self):
            self.run_test_fields()

        def test_fieldsets(self):
            self.run_test_fieldsets()
    """
    # Class-level stand-ins, shared by every test of every subclass.
    mock_request = MockRequest()
    # NOTE(review): not attached to ``mock_request`` in this class; presumably
    # used by subclasses whose admin needs a user. Confirm before removing.
    mock_super_user = MockSuperUser()
    # Expected result of ``admin.get_exclude(request, obj)``.
    admin_excludes = None

    @property
    def model(self):
        """Model class the admin under test is instantiated for."""
        raise NotImplementedError(_get_not_implemented_message('model', self.__class__.__name__))

    @property
    def form_fields(self):
        """Expected list of field names in ``admin.get_form(request).base_fields``."""
        raise NotImplementedError(_get_not_implemented_message('form_fields', self.__class__.__name__))

    @property
    def admin_fields(self):
        """Expected list returned by ``admin.get_fields``, with and without an object."""
        raise NotImplementedError(_get_not_implemented_message('admin_fields', self.__class__.__name__))

    @property
    def admin_fieldsets(self):
        """Expected value of ``admin.get_fieldsets``, with and without an object."""
        raise NotImplementedError(_get_not_implemented_message('admin_fieldsets', self.__class__.__name__))

    @property
    def admin_class(self):
        """``ModelAdmin`` subclass under test."""
        raise NotImplementedError(_get_not_implemented_message('admin_class', self.__class__.__name__))

    def setUp(self):
        super(AdminTestCase, self).setUp()
        # May be None on an empty table; the object-level assertions then pass obj=None.
        self.object = self.model.objects.first()
        self.site = AdminSite()

    def _get_admin(self):
        """Instantiate ``admin_class`` for ``model`` on ``self.site``."""
        return self.admin_class(self.model, self.site)

    def run_test_fields(self):
        """Assert the admin's form fields, fields and excludes match the expected values."""
        admin = self._get_admin()
        self.assertEqual(list(admin.get_form(self.mock_request).base_fields), self.form_fields)
        self.assertEqual(list(admin.get_fields(self.mock_request)), self.admin_fields)
        self.assertEqual(list(admin.get_fields(self.mock_request, self.object)), self.admin_fields)
        self.assertEqual(admin.get_exclude(self.mock_request, self.object), self.admin_excludes)

    def run_test_fieldsets(self):
        """Assert the admin's fieldsets match ``admin_fieldsets``, with and without an object."""
        admin = self._get_admin()
        self.assertEqual(admin.get_fieldsets(self.mock_request), self.admin_fieldsets)
        self.assertEqual(admin.get_fieldsets(self.mock_request, self.object), self.admin_fieldsets)
class BaseAPITestCase(APITestCase):
    """
    ``APITestCase`` with helpers for common DRF error responses and for
    checking the key layout of (possibly nested) response payloads.
    """
    # NOTE(review): not used in this class; presumably passed to the test
    # client by subclasses so request bodies are sent as JSON.
    request_kwargs = {'format': 'json'}
    # Pattern subclasses can use (e.g. as a ``regex`` entry in ``bulkAssertIn``)
    # to check that a value looks like an http(s)/ftp URL.
    url_regex = r'^(https?|ftp):\/\/[^:\/\s]+(\/\w+)*\/[\w\-\.]+[^#?\s]+(.*)?(#[\w\-]+)?$'

    def assertUnauthorizedResponse(self, response, msg=None):
        """
        Assert ``response`` is a 401 carrying the "credentials not provided" detail.

        :type response: Response
        """
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED, msg)
        self.assertDictEqual(response.data, {'detail': _('Authentication credentials were not provided.')}, msg)

    def assertForbiddenResponse(self, response, msg=None):
        """
        Assert ``response`` is a 403 carrying the "no permission" detail.

        :type response: Response
        """
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN, msg)
        self.assertDictEqual(response.data, {'detail': _('You do not have permission to perform this action.')}, msg)

    def assertNotFoundResponse(self, response, msg=None):
        """
        Assert ``response`` is a 404 carrying the "Not found." detail.

        :type response: Response
        """
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND, msg)
        self.assertDictEqual(response.data, {'detail': _('Not found.')}, msg)

    def bulkAssertIn(self, items, data, is_list=False, list_length=None, nullable=False):
        """
        Assert that ``data`` has exactly the keys described by ``items``.

        Each element of ``items`` is either a key name, or a dict mapping a key
        name to either a nested ``items`` list or a configuration dict with:

        * ``nullable`` -- the value may be ``None`` (then nothing else is checked);
        * ``is_list`` -- the value must be a non-empty list; only its first
          element is checked;
        * one of ``entries`` (nested ``items`` list), ``value`` (expected
          value) or ``regex`` (pattern that ``str()`` of the value must match).

        At every level, the number of keys in the data must equal the number of
        keys described, so unexpected extra keys also fail the assertion.

        :param items: list of keys that must be in the "data"
        :type items: list of str | list of dict
        :param data: dictionary (or a list of dictionaries) to be tested
        :type data: dict | list of dict
        :param is_list: Applies additional tests for the list if True; only the
            first item of the list is checked against ``items``
        :type is_list: bool
        :param list_length: Defines a number of items to be in the list
            (``0`` asserts the list is empty; ``None`` requires a non-empty list)
        :type list_length: int
        :param nullable: Defines if the data sent (or, for lists, its first
            item) can be null
        :type nullable: bool
        """
        # Work on a copy: nested lists are replaced by their first item below.
        _data = deepcopy(data)

        # Check nullability before the list checks, the same way nested
        # ``nullable`` entries are checked before ``is_list``.
        if nullable and _data is None:
            return

        if is_list:
            self.assertIsInstance(_data, list)
            if list_length is not None:
                self.assertEqual(len(_data), list_length)
            if not _data:
                if list_length is None:
                    self.fail('Data is a list and cannot be empty.')
                # list_length == 0 was asserted above; there is no item to inspect.
                return
            # Only the first item is checked against ``items``.
            _data = _data[0]
            if nullable and _data is None:
                return

        def flatten_keys(attributes):
            """Return the key names in ``attributes``, expanding dict entries into their keys."""
            keys = []
            for attribute in attributes:
                if isinstance(attribute, dict):
                    keys.extend(attribute.keys())
                else:
                    keys.append(attribute)
            return keys

        def which_params(missing, containing):
            """Comma-separated names in ``missing`` that do not appear in ``containing``."""
            containing_keys = flatten_keys(containing)
            return ', '.join(x for x in flatten_keys(missing) if x not in containing_keys)

        def bulk_test(entries, target):
            len_entries = 0
            for entry in entries:
                if not isinstance(entry, dict):
                    # Plain key name: it only has to be present.
                    len_entries += 1
                    self.assertIn(entry, target)
                    continue
                # Dict entry: each key maps to nested entries or a configuration dict.
                for key, sub_entries in entry.items():
                    len_entries += 1
                    # Check presence first so a missing key fails the test with
                    # a clear message instead of raising KeyError below.
                    self.assertIn(key, target, msg=u'"{}" key missing in "{}".'.format(key, target))
                    if isinstance(sub_entries, dict):  # Has a configuration
                        if sub_entries.get('nullable') and target[key] is None:
                            continue
                        if sub_entries.get('is_list'):
                            self.assertIsInstance(target[key], list)
                            if not target[key]:
                                raise AssertionError('"{}" is a list and cannot be empty'.format(key))
                            # Only the first item is checked against the nested entries.
                            target[key] = target[key][0]
                        if 'entries' in sub_entries:
                            sub_entries = sub_entries['entries']  # List of entries to test
                        elif 'value' in sub_entries:
                            self.assertEqual(
                                sub_entries['value'], target[key],  # Test specific value
                                msg=u'"{key}": "{0} != {1}"'.format(
                                    sub_entries['value'], target[key], key=key
                                )
                            )
                            continue
                        elif 'regex' in sub_entries:
                            self.assertRegex(
                                str(target[key]), sub_entries['regex'],
                                msg=u'"{key}": "{1} did not match `{0}`"'.format(
                                    sub_entries['regex'], target[key], key=key
                                )
                            )
                            continue
                        else:
                            raise AssertionError('"{}" missing "entries", "regex" or "value" key.'.format(key))
                    bulk_test(sub_entries, target[key])

            # Every described key was asserted present above, so a count mismatch
            # almost always means the target has keys that were not described.
            if len_entries > len(target):
                label, keys = 'Missing keys', which_params(entries, target)
            else:
                label, keys = 'Unexpected keys', which_params(target, entries)
            self.assertEqual(
                len_entries, len(target),
                msg='{entries} != {target}: {label}: {keys}.'.format(
                    entries=len_entries, target=len(target), label=label, keys=keys
                )
            )

        bulk_test(items, _data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment