Last active
August 29, 2015 14:22
-
-
Save gergelypolonkai/498a32297f39b4960ad7 to your computer and use it in GitHub Desktop.
@ParamConverter à la Django
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
import functools
import re

from django.db import models
from django.shortcuts import get_object_or_404
# Two-pass CamelCase -> snake_case conversion (e.g. HTTPResponse -> http_response).
_FIRST_CAP_RE = re.compile(r'(.)([A-Z][a-z]+)')
_ALL_CAP_RE = re.compile(r'([a-z0-9])([A-Z])')


def _camel_to_snake(name):
    """Convert a CamelCase name to snake_case, e.g. ``MyModel`` -> ``my_model``."""
    return _ALL_CAP_RE.sub(r'\1_\2', _FIRST_CAP_RE.sub(r'\1_\2', name)).lower()


def _is_model(obj):
    """Return True if *obj* is a Django model class (its metaclass is ModelBase)."""
    return isinstance(obj, models.base.ModelBase)


def _normalize_spec(pspec):
    """Expand one parameter spec into a full 4-tuple.

    :param pspec: a model class, or a tuple
        ``(model[, param_name[, field_name[, obj_param_name]]])``.
        Elements beyond the fourth are ignored.
    :returns: ``(model, param_name, field_name, obj_param_name)`` with every
        missing or ``None`` member replaced by its default.
    :raises ValueError: if the first member is not a model class.
    """
    if not isinstance(pspec, tuple):
        pspec = (pspec,)
    model = pspec[0] if pspec else None
    if not _is_model(model):
        raise ValueError("First value in pspec must be a Model subclass!")

    # Pad with None so that "not given" and "given as None" behave the same.
    param_name, field_name, obj_param_name = (
        tuple(pspec[1:4]) + (None, None, None))[:3]
    snake_name = _camel_to_snake(model.__name__)

    if param_name is None:
        param_name = snake_name + '_id'
    if field_name is None:
        field_name = 'id'
    if obj_param_name is None:
        obj_param_name = snake_name
    return model, param_name, field_name, obj_param_name


def convert_params(*params_to_convert, **options):
    """
    Convert keyword parameters of a view into model objects.

    Each positional argument to this decorator must be a model class
    (a subclass of django.db.models.Model) or a tuple with the following
    members:

    * model: a Model subclass (required)
    * param_name: the name of the keyword parameter that holds the value
      to be matched. If missing or None, the model's class name is
      converted from ModelName to model_name form and suffixed with
      "_id". E.g. for MyModel, the default is my_model_id.
    * field_name: the model field against which the value in param_name
      is matched. If missing or None, the default is "id".
    * obj_param_name: the name of the keyword parameter that will hold
      the resolved object. If missing or None, the default is the
      model's class name in model_name form, e.g. my_model for MyModel.

    The raw ``param_name`` keyword is removed from the call and replaced
    by ``obj_param_name``.

    Objects are resolved with get_object_or_404, so a missing object
    raises Http404. Pass ``prevent_404=True`` to receive ``None``
    instead. In that mode, if several objects match, the first one of
    the queryset is used rather than raising.

    :raises ValueError: at decoration time, if no spec is given, if the
        decorator is applied without arguments, if a spec's first member
        is not a model class, or if two specs read the same param_name.
    :raises TypeError: at decoration time, for unknown keyword options.
    :raises KeyError: at call time, if a param_name keyword is missing,
        or if obj_param_name is already present in the call's kwargs.
    """
    prevent_404 = options.pop('prevent_404', False)

    if not params_to_convert:
        raise ValueError("Must pass at least one parameter spec!")
    # Catch "@convert_params" used without parentheses: the decorated
    # function arrives as the only positional argument.
    if (len(params_to_convert) == 1
            and callable(params_to_convert[0])
            and not _is_model(params_to_convert[0])):
        raise ValueError("This decorator must have arguments!")
    # Don't silently ignore typos such as "prevent404=True".
    if options:
        raise TypeError(
            'convert_params() got unexpected keyword argument(s): %s'
            % ', '.join(sorted(options)))

    # Parse and validate the specs once, not on every request.
    specs = [_normalize_spec(pspec) for pspec in params_to_convert]
    seen_params = set()
    for _model, param_name, _field, _obj_name in specs:
        if param_name in seen_params:
            raise ValueError('%s is already converted' % param_name)
        seen_params.add(param_name)

    def convert_params_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for model, param_name, field_name, obj_param_name in specs:
                # Raises KeyError if the caller did not supply the value.
                field_value = kwargs.pop(param_name)
                if obj_param_name in kwargs:
                    raise KeyError(
                        "'%s' already exists as a parameter" % obj_param_name)
                lookup = {field_name: field_value}
                if prevent_404:
                    # Use the default manager, as get_object_or_404 does,
                    # so models without an "objects" manager still work.
                    kwargs[obj_param_name] = model._default_manager.filter(
                        **lookup).first()
                else:
                    kwargs[obj_param_name] = get_object_or_404(
                        model, **lookup)
            return func(*args, **kwargs)
        return wrapper
    return convert_params_decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.test import TestCase | |
from django.contrib.auth.models import User | |
from django.http import Http404 | |
from .helper import convert_params | |
class HelperTest(TestCase):
    """Tests for the convert_params decorator."""

    def setUp(self):
        self.user = User.objects.create_user(username='test', password='test')

    def test_convert_params_decor(self):
        # The decorator must be called with at least one spec.
        with self.assertRaises(ValueError):
            @convert_params
            def bad_func_1():
                pass  # pragma: nocover

        with self.assertRaises(ValueError):
            @convert_params()
            def bad_func_2():
                pass  # pragma: nocover

        # A spec whose first member is not a model class is rejected.
        with self.assertRaises(ValueError):
            @convert_params('aoeu')
            def bad_func_3():
                pass  # pragma: nocover
            bad_func_3()

        # The raw user_id keyword is consumed and replaced by "user", so a
        # function still expecting user_id fails with TypeError. Use the
        # real pk: a guessed pk like 1 may not exist and would raise
        # Http404 before the TypeError is reached.
        with self.assertRaises(TypeError):
            @convert_params(User)
            def bad_func_4(user_id):
                pass  # pragma: nocover
            bad_func_4(user_id=self.user.id)

        @convert_params(User)
        def good_func_1(user):
            return user

        # Missing user_id keyword (none, or only a positional arg).
        with self.assertRaises(KeyError):
            good_func_1()
        with self.assertRaises(KeyError):
            good_func_1(1)
        self.assertEqual(self.user, good_func_1(user_id=self.user.id))

        # Two specs reading the same keyword parameter are rejected.
        with self.assertRaises(ValueError):
            @convert_params(User, User)
            def bad_func_5():
                pass  # pragma: nocover
            bad_func_5(user_id=self.user.id)

        # Two specs writing the same object parameter are rejected.
        with self.assertRaises(KeyError):
            @convert_params(User, (User, 'user_id_2', None, 'user'))
            def bad_func_6():
                pass  # pragma: nocover
            bad_func_6(user_id=self.user.id, user_id_2=1)

        @convert_params((User, 'user_id_2', None, 'my_user'))
        def good_func_2(my_user):
            return my_user

        self.assertEqual(self.user, good_func_2(user_id_2=self.user.id))

        # Lookup by a field other than the primary key.
        @convert_params((User, 'username', 'username'))
        def good_func_3(user):
            return user

        self.assertEqual(self.user, good_func_3(username=self.user.username))
        with self.assertRaises(Http404):
            good_func_3(username='badusername')

        @convert_params(User, prevent_404=True)
        def good_func_4(user):
            return user

        self.assertEqual(self.user, good_func_4(user_id=self.user.id))
        # self.user is the only user created, so the next pk cannot exist.
        self.assertIsNone(good_func_4(user_id=self.user.id + 1))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment