Skip to content

Instantly share code, notes, and snippets.

@smahs
Last active March 28, 2016 11:49
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save smahs/f9d68d1f61869f535301 to your computer and use it in GitHub Desktop.
Save smahs/f9d68d1f61869f535301 to your computer and use it in GitHub Desktop.
Micro REST Framework for Django: A Prototype
from django.views.generic import View
from django.http import (
QueryDict, HttpResponse, HttpResponseBadRequest,
HttpResponseNotFound, HttpResponseNotAllowed,
)
from django.core.serializers.json import DjangoJSONEncoder
from django.core.exceptions import (
ObjectDoesNotExist, ValidationError,
)
from django.db import transaction
from django.db.models import (
Model, ForeignKey, ManyToManyField, Q, FieldDoesNotExist,
)
from django.db.models.fields.related import RelatedField
from json import dumps, loads
from collections import defaultdict
from functools import partial
class Unauthorised(Exception):
    """Raised by an auth class when a request cannot be authenticated."""
class InputValidation(object):
    """
    A descriptor class for the input validation decorator.

    Wraps a view handler (``get``/``post``/``put``/``delete``). It runs after
    ``View.dispatch`` but before the handler: the request payload is parsed
    against ``view.model_class``, the presence of primary keys is checked and
    the parsed result is stored on the view as ``view.params``. A
    ``ValidationError`` raised along the way is converted into an
    ``HttpResponseBadRequest``.

    NOTE(review): ``view``, ``request`` and ``data`` are kept as attributes of
    this object, and one object is created per decorated handler -- confirm
    this is acceptable when requests are served concurrently.
    """

    def __init__(self, f):
        # The wrapped handler; invoked from __call__ after validation.
        self._func = f

    def get_data(self):
        """
        Parse the payload of the current request.

        GET requests are read from the query string; every other method is
        expected to carry a JSON document in the request body.

        Returns the parsed parameters as a (possibly nested) dict.
        Raises ValidationError when the body cannot be parsed.
        """
        request = self.request
        if request.method == 'GET':
            self.data = request.GET
            return self.parse_form(self.view.model_class, self.data.keys())
        try:
            data = loads(request.body)
            return self.parse_data(self.view.model_class, data)
        except (TypeError, ValueError):
            raise ValidationError({'input': ['parsing failed']})

    def parse_data(self, klass, hmap):
        """
        All HTTP methods except GET are supposed to send the object data
        in their request.body. This data is expected as JSON, following the
        same structure as their model class.

        ``klass`` is the model class the keys of ``hmap`` are checked against.
        Scalar values are converted with the model field's ``to_python``;
        nested dicts/lists are parsed recursively against the related model.
        Raises ValidationError for unknown keys or unconvertible values.
        """
        if not isinstance(hmap, dict):
            # A JSON array/scalar where an object is required is a client
            # error, not a server error.
            raise ValidationError({'input': ['parsing failed']})
        local, rels = self.view.__class__.split_params(hmap)
        params = dict()
        for key, raw in local.items():
            try:
                field = klass._meta.get_field(key)
                if isinstance(field, RelatedField):
                    # A relation must be given as a nested object.
                    raise FieldDoesNotExist()
                params[key] = field.to_python(raw)
            except (FieldDoesNotExist, ValidationError, ValueError):
                raise ValidationError({key: ['invalid input data']})
        for key, raw in rels.items():
            try:
                field = klass._meta.get_field(key)
                if not isinstance(field, RelatedField):
                    raise FieldDoesNotExist()
                vals = raw if isinstance(raw, list) else [raw]
                store = [self.parse_data(field.related.parent_model, i)
                         for i in vals]
                # A single related object is stored unwrapped.
                params[key] = store[0] if len(store) == 1 else store
            except FieldDoesNotExist:
                raise ValidationError({key: ['invalid input data']})
        return params

    def parse_form(self, klass, keys, sup=None):
        """
        GET data come as URLEncoded, however the query fields should
        follow Django's '__' notation for model relationships.
        For example: comment=0&user__id=1 will give all comments for user 1.

        ``keys`` are the field names relative to ``klass``; ``sup`` is the
        '__'-joined prefix leading to ``klass`` from the top-level model and
        is used to look the values up in the query string.
        Raises ValidationError for unknown keys or unconvertible values.
        """
        view_class = self.view.__class__
        local, rels = view_class.split_fields(keys)
        rels = view_class.tokenize(rels)
        params = dict()
        for key in local:
            supkey = sup + '__' + key if sup else key
            vals = self.data.getlist(supkey)
            try:
                field = klass._meta.get_field(key)
                if isinstance(field, RelatedField):
                    raise FieldDoesNotExist()
                params[key] = [field.to_python(i) for i in vals]
            except (FieldDoesNotExist, ValidationError, ValueError):
                raise ValidationError({key: ['Invalid input data']})
        for key, subkeys in rels.items():
            try:
                field = klass._meta.get_field(key)
                if not isinstance(field, RelatedField):
                    raise FieldDoesNotExist()
                # Build the prefix in a fresh name: reusing ``sup`` would
                # leak this relation's prefix into the next loop iteration.
                nested_sup = sup + '__' + key if sup else key
                params[key] = self.parse_form(field.related.parent_model,
                                              subkeys, sup=nested_sup)
            except FieldDoesNotExist:
                raise ValidationError({key: ['Invalid input data']})
        return params

    def check_pk(self, field, data):
        """
        Raise ValidationError unless ``field.name`` is present in ``data``.
        """
        try:
            present = field.name in data
        except TypeError:
            # ``data`` is not a container at all.
            present = False
        if not present:
            raise ValidationError({'pk': ['validation failed']})

    def validate_pk(self, klass, data):
        """
        Validates the presence of primary keys in the input data.

        The model's own primary key is required for every method except
        POST; nested related objects must always carry their primary key.
        TODO: Allow a way to skip this validation.
        """
        meta = klass._meta
        relations = [i for i in meta.local_fields
                     if isinstance(i, RelatedField)]
        relations.extend(meta.local_many_to_many)
        if self.request.method != 'POST':
            self.check_pk(meta.pk, data)
        for field in relations:
            related_pk = field.related.parent_model._meta.pk
            val = data.get(field.name, None)
            if isinstance(val, list):
                for datum in val:
                    self.check_pk(related_pk, datum)
            elif isinstance(val, dict):
                self.check_pk(related_pk, val)

    def __call__(self, *args, **kwargs):
        """
        Validate the request, then run the wrapped handler.

        ``args[0]`` is the view instance and ``args[1]`` the request.
        Returns the handler's response, or an HttpResponseBadRequest when a
        ValidationError is raised during validation or by the handler.
        """
        self.view = args[0]
        self.request = args[1]
        try:
            params = self.get_data()
            self.validate_pk(self.view.model_class, params)
            setattr(self.view, 'params', params)
            return self._func(*args, **kwargs)
        except ValidationError as ve:
            message = self.view.cleaning_errors(ve)
            return HttpResponseBadRequest(message)

    def __get__(self, obj, objtype=None):
        # Bind the view instance as the first positional argument so that
        # ``view.get(request)`` reaches __call__ as ``(view, request)``.
        return partial(self.__call__, obj)
class RestBaseView(View):
    """
    A boilerplate class providing some nice extensions to Django's View
    class.

    Subclasses configure the view through class attributes:

    * ``model_class``   -- the Django model served by the view (required).
    * ``methods``       -- upper-case HTTP method names the view accepts.
    * ``return_fields`` -- field names to serialize; related fields use the
      '__' notation (e.g. ``owner__id``).
    * ``auth_class``    -- optional class whose ``process_request(request)``
      raises ``Unauthorised`` to reject a request.
    """

    def __init__(self, *args, **kwargs):
        model_class = getattr(self, 'model_class', None)
        if not (isinstance(model_class, type)
                and issubclass(model_class, Model)):
            raise TypeError('Model class not defined or not supported')
        super(RestBaseView, self).__init__(*args, **kwargs)

    def dispatch(self, request, *args, **kwargs):
        """
        Reject disallowed methods (405) and unauthenticated requests (401)
        before handing over to Django's regular dispatch.
        """
        if request.method not in self.methods:
            # The first argument is the list of permitted methods (it feeds
            # the ``Allow`` header); the message goes in as the content.
            return HttpResponseNotAllowed(self.methods, 'Method not allowed')
        try:
            if hasattr(self, 'auth_class'):
                self.auth_class().process_request(request)
        except Unauthorised:
            return HttpResponse(status=401)
        return super(RestBaseView, self).dispatch(request, *args, **kwargs)

    # ------------------------------------------------------------------
    # Class methods for generic algorithms
    # ------------------------------------------------------------------

    @staticmethod
    def _is_nested(value):
        # Strings are iterable on Python 3 but are scalar values here.
        return (hasattr(value, '__iter__')
                and not isinstance(value, (str, bytes)))

    @classmethod
    def tokenize(cls, arr, sep='__'):
        """
        Group ``'a__b'`` style names by their first component.

        ``['a__b', 'a__c__d']`` becomes ``{'a': ['b', 'c__d']}``. Every item
        must contain ``sep``.
        """
        d = defaultdict(list)
        for i in arr:
            j, k = i.split(sep, 1)
            d[j].append(k)
        return d

    @classmethod
    def split_fields(cls, fields_list):
        """Split names into ``(local, related)`` by the presence of '__'."""
        local, rels = [], []
        for name in fields_list:
            (rels if '__' in name else local).append(name)
        return (local, rels)

    @classmethod
    def split_params(cls, params):
        """
        Split a mapping into ``(local, related)`` dicts: scalar values are
        local, nested containers (dicts/lists) are related.
        """
        local, rels = {}, {}
        for key, value in params.items():
            (rels if cls._is_nested(value) else local)[key] = value
        return (local, rels)

    @classmethod
    def flatten(cls, hmap):
        """
        Collapse nested dicts into a single level using '__'-joined keys.

        ``{'a': {'b': [1]}}`` becomes ``{'a__b': [1]}``. Dicts found inside
        lists are flattened individually; other list items are kept.
        """
        flat = {}
        for key, value in hmap.items():
            if isinstance(value, dict):
                for subkey, subvalue in cls.flatten(value).items():
                    flat[key + '__' + subkey] = subvalue
            elif isinstance(value, list):
                flat[key] = [cls.flatten(item) if isinstance(item, dict)
                             else item for item in value]
            else:
                flat[key] = value
        return flat

    # ------------------------------------------------------------------
    # Utility methods for View classes
    # ------------------------------------------------------------------

    def send_json(self, data):
        """Return ``data`` as an ``application/json`` HttpResponse."""
        return HttpResponse(dumps(data, cls=DjangoJSONEncoder),
                            content_type='application/json')

    def send_error(self, code, message):
        """Return ``{'error': message}`` as JSON with HTTP status ``code``."""
        return HttpResponse(dumps({'error': message}), status=code,
                            content_type='application/json')

    def cleaning_errors(self, exc):
        """
        Render a ValidationError as text, one ``field: messages`` line per
        field. Returns None when ``exc`` is not a ValidationError.
        """
        if not isinstance(exc, ValidationError):
            return None
        if hasattr(exc, 'error_dict'):
            return '\n'.join(k + ': ' + ' '.join(v)
                             for k, v in exc.message_dict.items())
        # Errors built without a dict have no per-field mapping.
        return '\n'.join(exc.messages)

    # ------------------------------------------------------------------
    # Serialization methods
    # ------------------------------------------------------------------

    def get_serializable(self, val):
        """Return ``val`` if it can be JSON-encoded, else ``str(val)``."""
        try:
            dumps(val, cls=DjangoJSONEncoder)
            return val
        except (ValueError, TypeError):
            return str(val)

    def serialize_local(self, obj, fields):
        """Serialize the named attributes of ``obj`` into a dict."""
        return {i: self.get_serializable(getattr(obj, i))
                for i in fields}

    def serialize_related(self, obj, fields):
        """
        Serialize related objects of ``obj``.

        ``fields`` are '__'-separated names whose first component is a
        ForeignKey or ManyToManyField on ``obj``'s model; other field types
        are ignored.
        """
        attributes = RestBaseView.tokenize(fields)
        out = dict()
        for name, attr in attributes.items():
            field = obj.__class__._meta.get_field(name)
            if isinstance(field, ForeignKey):
                out.update(self.serialize_fk(obj, field, attr))
            elif isinstance(field, ManyToManyField):
                out.update(self.serialize_m2m(obj, field, attr))
        return out

    def serialize_fk(self, obj, field, attr):
        """
        Serialize the object behind a ForeignKey as ``{field.name: {...}}``.

        When only ``id`` is requested the stored key is used directly and
        the related object is not loaded. An empty key yields None.
        """
        val = field.value_from_object(obj)
        if not val:
            out = None
        elif len(attr) == 1 and 'id' in attr:
            out = {'id': val}
        else:
            sub = getattr(obj, field.name)
            local, rels = RestBaseView.split_fields(attr)
            out = self.serialize_local(sub, local)
            out.update(self.serialize_related(sub, rels))
        return {field.name: out}

    def serialize_m2m(self, obj, field, attr):
        """
        Serialize the objects behind a ManyToManyField as
        ``{field.name: [{...}, ...]}``.
        """
        local, rels = RestBaseView.split_fields(attr)
        manager = getattr(obj, field.name)
        # Fetch the pk alongside the requested columns so that nested data
        # can be matched by key rather than by row position: two separate
        # queries are not guaranteed to return rows in the same order.
        rows = list(manager.values_list('pk', *local))
        out = [dict(zip(local, row[1:])) for row in rows]
        if rels:
            nested = {sub.pk: self.serialize_related(sub, rels)
                      for sub in manager.all()}
            out = [dict(item, **nested.get(row[0], {}))
                   for row, item in zip(rows, out)]
        return {field.name: out}

    def serialize(self, obj):
        """Serialize ``obj`` according to ``self.return_fields``."""
        local, rels = RestBaseView.split_fields(self.return_fields)
        serialized = self.serialize_local(obj, local)
        related = self.serialize_related(obj, rels)
        if related:
            serialized.update(related)
        return serialized

    # ------------------------------------------------------------------
    # Database writes for relations
    # ------------------------------------------------------------------

    def update_m2m(self, obj, field, vals):
        """
        Make the many-to-many relation ``field`` of ``obj`` contain exactly
        the primary keys in ``vals``.
        """
        attr = getattr(obj, field.name)
        pkname = field.related.parent_model._meta.pk.name
        current = set(attr.values_list(pkname, flat=True))
        wanted = set(vals)
        stale = current - wanted
        missing = wanted - current
        if stale:
            attr.remove(*stale)
        if missing:
            attr.add(*missing)

    def update_related(self, obj, rels):
        """
        Apply nested relation data to ``obj``.

        ``rels`` maps relation names to a dict (one related object) or a
        list of dicts, each carrying the related primary key.
        Raises ValidationError when a ForeignKey value is not a dict.
        """
        for name, val in rels.items():
            field = obj.__class__._meta.get_field(name)
            if isinstance(field, ForeignKey):
                if not isinstance(val, dict):
                    raise ValidationError({name: ['invalid input data']})
                pkname = field.related_field.name
                setattr(obj, field.attname, val.get(pkname))
            elif isinstance(field, ManyToManyField):
                pkname = field.related.parent_model._meta.pk.name
                if isinstance(val, dict):
                    val = [val]
                if not isinstance(val, list):
                    # Nothing usable to apply for this relation.
                    continue
                vals = [i.get(pkname) for i in val if pkname in i]
                self.update_m2m(obj, field, vals)

    # ------------------------------------------------------------------
    # DB fetch, override for custom gets
    # ------------------------------------------------------------------

    def get_records(self, params):
        """
        Return a queryset of ``model_class`` filtered by ``params``.

        ``params`` maps (possibly nested) field names to lists of values.
        A list containing 0 acts as a wildcard and drops the filter; one
        value becomes an exact match, several values an ``__in`` lookup.
        """
        lookups = {}
        for key, values in RestBaseView.flatten(params).items():
            if 0 in values:
                continue
            if len(values) == 1:
                lookups[key] = values[0]
            elif len(values) > 1:
                lookups[key + '__in'] = values
            else:
                lookups[key] = values
        return self.model_class.objects.filter(Q(**lookups))

    # ------------------------------------------------------------------
    # HTTP methods
    # ------------------------------------------------------------------

    @InputValidation
    def get(self, request, *args, **kwargs):
        """List the records matching the query string."""
        try:
            objs = self.get_records(self.params)
            out = [self.serialize(i) for i in objs]
            return self.send_json({self.model_class.__name__.lower(): out})
        except ObjectDoesNotExist:
            return HttpResponseNotFound('Object not found')

    @InputValidation
    @transaction.atomic()
    def post(self, request, *args, **kwargs):
        """Create a record from the JSON body and return it serialized."""
        try:
            local, rels = RestBaseView.split_params(self.params)
            obj = self.model_class(**local)
            if self.model_class._meta.local_many_to_many:
                # Many-to-many rows need the object's primary key.
                obj.save()
            self.update_related(obj, rels)
            obj.full_clean()
            obj.save()
            return self.send_json(self.serialize(obj))
        except ValidationError as e:
            # The exception is handled inside the atomic block, so the
            # transaction must be told explicitly to undo earlier writes.
            transaction.set_rollback(True)
            return HttpResponseBadRequest(self.cleaning_errors(e))

    @InputValidation
    @transaction.atomic()
    def put(self, request, *args, **kwargs):
        """Update the record identified by its primary key in the body."""
        try:
            local, rels = RestBaseView.split_params(self.params)
            pkname = self.model_class._meta.pk.name
            pk = local.pop(pkname, None)
            if not pk:
                # get_records treats 0 as a wildcard; never update an
                # arbitrary record because of an empty key.
                raise ObjectDoesNotExist()
            obj = self.get_records({pkname: [pk]})
            if not obj:
                raise ObjectDoesNotExist()
            obj = obj[0]
            for key, val in local.items():
                setattr(obj, key, val)
            self.update_related(obj, rels)
            obj.full_clean()
            obj.save()
            return self.send_json(self.serialize(obj))
        except ObjectDoesNotExist:
            return HttpResponseNotFound('Object not found')
        except ValidationError as e:
            # Undo relation writes made before validation failed.
            transaction.set_rollback(True)
            return HttpResponseBadRequest(self.cleaning_errors(e))

    @InputValidation
    def delete(self, request, *args, **kwargs):
        """Delete the record identified by its primary key in the body."""
        try:
            pkname = self.model_class._meta.pk.name
            matches = self.model_class.objects.filter(
                pk=self.params.get(pkname))
            if not matches.exists():
                # filter().delete() never raises for a missing row.
                raise ObjectDoesNotExist()
            matches.delete()
            return self.send_json('Deletion successful')
        except ObjectDoesNotExist:
            return HttpResponseNotFound('Object not found')
# ============== auth.py ===========================
class CustomAuth(object):
    """
    Middleware style class, called before dispatch.
    """

    def process_request(self, request):
        """
        Raise Unauthorised when the request carries no ``HTTP_AUTH`` entry
        in ``request.META``; otherwise return None.
        """
        auth_token = request.META.get('HTTP_AUTH')
        if not auth_token:
            raise Unauthorised()
        # set request.user from token
# ============== models.py =========================
class Comment(models.Model):
    """
    Example model class, for the classical blog example
    """
    title = models.CharField(max_length=256)
    body = models.TextField(null=True, blank=True)
    owner = models.ForeignKey(User, blank=True)

    def validate_user(self):
        """
        Return the User referenced by ``owner_id``.

        Raises ValidationError when ``owner_id`` is missing, is not an
        integer, or does not match an existing user.
        """
        try:
            return User.objects.get(pk=int(self.owner_id))
        except (ObjectDoesNotExist, TypeError, ValueError):
            raise ValidationError({'owner_id': ['invalid data']})

    def full_clean(self, *args, **kwargs):
        """Check the owner first, then run the regular model validation."""
        # Store the key, not the User instance, in the ``_id`` attribute.
        self.owner_id = self.validate_user().pk
        super(Comment, self).full_clean(*args, **kwargs)
# ============== views.py ============================
class CommentView(RestBaseView):
    """REST endpoint for Comment objects, protected by CustomAuth."""
    auth_class = CustomAuth
    model_class = Comment
    methods = ['GET', 'POST', 'PUT', 'DELETE']
    # Comment's foreign key is ``owner``, so its key attribute is
    # ``owner_id``.
    return_fields = ['id', 'title', 'body', 'owner_id']
# =============== tests.py ===========================
from django.utils import unittest
from django.test.client import Client
from django.core.urlresolvers import reverse
from django.contrib.auth.models import User
from json import dumps, loads
class CommentViewTests(unittest.TestCase):
    """
    End-to-end tests for the comment endpoint.

    Every test builds the state it needs through the helper methods: a
    fresh TestCase instance is created per test, so nothing stored on
    ``self`` by one test is visible to another.
    Assumes URL names ``login`` and ``comment_view`` exist and that the
    login view answers with an ``AUTH`` response header -- confirm against
    the project's URL configuration.
    """

    def setUp(self):
        self.user = User.objects.create_user(username='user',
            email='user@domain.com', password='secret')
        self.client = Client()
        self.auth_token = None

    def tearDown(self):
        # Remove the user so the next setUp can create it again.
        self.user.delete()

    def login(self):
        """Log in as the test user and remember the returned token."""
        payload = {
            'username': 'user',
            'password': 'secret',
        }
        response = self.client.post(reverse('login'), payload)
        self.assertEqual(response.status_code, 200)
        self.assertTrue(response.has_header('AUTH'))
        self.auth_token = response['AUTH']
        return self.auth_token

    def auth_headers(self):
        """Extra request arguments carrying the auth token."""
        if self.auth_token is None:
            self.login()
        # Extra keyword arguments end up in request.META unchanged, so the
        # key has to be spelled the way the auth class reads it.
        return {'HTTP_AUTH': self.auth_token}

    def headers(self):
        """Extra request arguments for an authenticated JSON request."""
        extra = self.auth_headers()
        extra['content_type'] = 'application/json'
        return extra

    def create_comment(self):
        """Create a comment owned by the test user; return its JSON."""
        payload = dumps({
            'title': 'A title',
            'body': 'None',
            'owner': {
                'id': self.user.id,
            },
        })
        response = self.client.post(reverse('comment_view'),
            payload, **self.headers())
        self.assertEqual(response.status_code, 200)
        return loads(response.content)

    def test_denies_anonymous(self):
        response = self.client.get(reverse('comment_view'))
        self.assertEqual(response.status_code, 401)

    def test_login(self):
        self.assertTrue(self.login())

    def test_post_comment(self):
        comment = self.create_comment()
        self.assertEqual(comment.get('title'), 'A title')

    def test_post_comment_bad(self):
        payload = dumps({
            'title': 'A title',
            'body': 'None',
        })
        response = self.client.post(reverse('comment_view'),
            payload, **self.headers())
        self.assertEqual(response.status_code, 400)

    def test_put_comment(self):
        comment = self.create_comment()
        payload = dumps({
            'id': comment['id'],
            'title': comment['title'],
            'body': 'Not None',
            'owner': {
                'id': self.user.id,
            },
        })
        response = self.client.put(reverse('comment_view'),
            payload, **self.headers())
        self.assertEqual(response.status_code, 200)
        updated = loads(response.content)
        self.assertEqual(updated.get('body'), 'Not None')

    def test_get_comments_for_user(self):
        created = self.create_comment()
        payload = {
            # 0 acts as a wildcard for the comment id.
            'id': 0,
            'owner__id': self.user.id,
        }
        response = self.client.get(reverse('comment_view'),
            payload, **self.auth_headers())
        self.assertEqual(response.status_code, 200)
        comments = loads(response.content)
        self.assertTrue(created['id'] in [comment.get('id')
            for comment in comments.get('comment')])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment