Skip to content

Instantly share code, notes, and snippets.

@gergelypolonkai
Last active August 29, 2015 14:22
Show Gist options
  • Save gergelypolonkai/498a32297f39b4960ad7 to your computer and use it in GitHub Desktop.
Save gergelypolonkai/498a32297f39b4960ad7 to your computer and use it in GitHub Desktop.
@ParamConverter à la Django
# -*- coding: utf-8 -*-
import re
from django.shortcuts import get_object_or_404
from django.db import models
def _camel_to_snake(name):
    """Convert a CamelCase identifier to snake_case (MyModel -> my_model)."""
    partial = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', partial).lower()


def convert_params(*params_to_convert, **options):
    """
    Build a decorator that turns keyword arguments into model objects.

    Each positional argument is a parameter spec: either a model class
    (a subclass of django.db.models.Model) or a tuple with the
    following members, of which only the first is required:

    * model: a Model subclass
    * param_name: the name of the keyword argument that holds the value
      to be matched. If missing or None, the model's class name is
      converted from ModelName to model_name form and suffixed with
      "_id". E.g. for MyModel, the default is my_model_id
    * the field name against which the value in param_name is matched.
      If missing or None, the default is "id"
    * obj_param_name: the name of the keyword argument that will hold
      the resolved object. If missing or None, the default is the
      model's class name converted from ModelName to model_name form,
      e.g. for MyModel, the default is my_model

    The param_name argument is removed from the keyword arguments and
    the resolved object is passed to the decorated function under
    obj_param_name instead.

    Objects are looked up with get_object_or_404, so a missing object
    raises Http404. Pass prevent_404=True as a keyword argument to get
    None for a missing object instead.

    Raises ValueError immediately if no spec is given, or if the
    decorator is applied without arguments (its only argument is a
    callable that is not a model class).

    When the decorated function is called, raises:

    * ValueError if a spec is empty, does not start with a model class,
      or reuses a param_name that was already converted
    * KeyError if the param_name keyword argument is absent, or if
      obj_param_name is already present among the keyword arguments
    """
    # Imported locally so this function does not depend on a
    # module-level import being added elsewhere in the file.
    from functools import wraps

    prevent_404 = options.pop('prevent_404', False)

    def is_model(m):
        # Model classes are instances of the ModelBase metaclass.
        return issubclass(type(m), models.base.ModelBase)

    if len(params_to_convert) == 0:
        raise ValueError("Must pass at least one parameter spec!")

    # A single callable that is not a model class means the decorator
    # was used bare (@convert_params), receiving the function itself.
    if (len(params_to_convert) == 1
            and hasattr(params_to_convert[0], '__call__')
            and not is_model(params_to_convert[0])):
        raise ValueError("This decorator must have arguments!")

    def convert_params_decorator(func):
        # wraps() keeps the decorated function's __name__, __doc__ and
        # __module__ visible on the wrapper.
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Spec validation happens here, at call time, so an invalid
            # spec surfaces when the decorated function is invoked.
            converted_params = set()

            for pspec in params_to_convert:
                # If the current pspec is not a tuple, assume it is a
                # model class
                if not isinstance(pspec, tuple):
                    pspec = (pspec,)

                # First, and the only required element in the spec is
                # the model class this object belongs to. The emptiness
                # check turns an empty tuple into a ValueError as well.
                if not pspec or not is_model(pspec[0]):
                    raise ValueError(
                        "First value in pspec must be a Model subclass!")

                model = pspec[0]
                calc_obj_name = _camel_to_snake(model.__name__)

                def member(index, default):
                    # Optional spec member, falling back to default
                    # when it is absent or None.
                    if len(pspec) <= index or pspec[index] is None:
                        return default
                    return pspec[index]

                # The second element is the keyword parameter name that
                # holds the value to convert
                param_name = member(1, calc_obj_name + '_id')

                if param_name in converted_params:
                    raise ValueError('%s is already converted' % param_name)

                converted_params.add(param_name)
                # Deliberately raises KeyError when the caller did not
                # supply the parameter.
                field_value = kwargs.pop(param_name)

                # The third element is the field name which must be
                # equal to the specified value
                obj_field_name = member(2, 'id')

                # The fourth element is the parameter name for the
                # object. If the parameter already exists, we consider
                # it an error
                obj_param_name = member(3, calc_obj_name)

                if obj_param_name in kwargs:
                    raise KeyError(
                        "'%s' already exists as a parameter" % obj_param_name)

                filter_kwargs = {obj_field_name: field_value}

                if prevent_404:
                    # first() yields None when nothing matches
                    kwargs[obj_param_name] = model.objects.filter(
                        **filter_kwargs).first()
                else:
                    kwargs[obj_param_name] = get_object_or_404(
                        model,
                        **filter_kwargs)

            return func(*args, **kwargs)

        return wrapper

    return convert_params_decorator
from django.test import TestCase
from django.contrib.auth.models import User
from django.http import Http404
from .helper import convert_params
class HelperTest(TestCase):
    """Tests for the convert_params decorator."""

    def setUp(self):
        # A single stored user that the lookups below resolve against.
        self.user = User.objects.create_user(username='test', password='test')

    def test_convert_params_decor(self):
        """Exercise the error paths and the successful conversions."""
        # Used without arguments: the decorator receives the function.
        with self.assertRaises(ValueError):
            @convert_params
            def bad_func_1():
                pass  # pragma: nocover

        # Called with no parameter specs at all.
        with self.assertRaises(ValueError):
            @convert_params()
            def bad_func_2():
                pass  # pragma: nocover

        # The spec is not a model class.
        with self.assertRaises(ValueError):
            @convert_params('aoeu')
            def bad_func_3():
                pass  # pragma: nocover

            bad_func_3()

        # user_id is replaced by user, which bad_func_4 does not accept.
        with self.assertRaises(TypeError):
            @convert_params(User)
            def bad_func_4(user_id):
                pass  # pragma: nocover

            bad_func_4(user_id=1)

        @convert_params(User)
        def good_func_1(user):
            return user

        # The value to convert must arrive as the user_id keyword.
        with self.assertRaises(KeyError):
            good_func_1()

        with self.assertRaises(KeyError):
            good_func_1(1)

        self.assertEqual(self.user, good_func_1(user_id=self.user.id))

        # Both specs default to the same param_name (user_id).
        with self.assertRaises(ValueError):
            @convert_params(User, User)
            def bad_func_5():
                pass  # pragma: nocover

            bad_func_5(user_id=1)

        # Both specs would store their object under the name 'user'.
        with self.assertRaises(KeyError):
            @convert_params(User, (User, 'user_id_2', None, 'user'))
            def bad_func_6():
                pass  # pragma: nocover

            bad_func_6(user_id=self.user.id, user_id_2=1)

        # Custom source parameter and custom object parameter names.
        @convert_params((User, 'user_id_2', None, 'my_user'))
        def good_func_2(my_user):
            return my_user

        self.assertEqual(self.user, good_func_2(user_id_2=self.user.id))

        # Match against the username field instead of id.
        @convert_params((User, 'username', 'username'))
        def good_func_3(user):
            return user

        self.assertEqual(self.user, good_func_3(username=self.user.username))

        with self.assertRaises(Http404):
            good_func_3(username='badusername')

        # With prevent_404 a missing object resolves to None.
        @convert_params(User, prevent_404=True)
        def good_func_4(user):
            return user

        self.assertEqual(self.user, good_func_4(user_id=self.user.id))
        self.assertIsNone(good_func_4(user_id=9999))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment