Created
November 12, 2016 09:01
-
-
Save moriyoshi/d9fdc1ebf0479f137b71467e86804e7b to your computer and use it in GitHub Desktop.
Exclusive request method predicate for Pyramid (>= 1.7)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyramid.config import Configurator | |
from pyramid.config.predicates import RequestMethodPredicate | |
from pyramid.config.util import as_sorted_tuple | |
from pyramid.compat import string_types, text_type | |
from pyramid.response import Response | |
from pyramid.httpexceptions import HTTPMethodNotAllowed | |
from pyramid.view import view_config | |
from pyramid.viewderivers import predicated_view, INGRESS | |
from webtest import TestApp | |
# Placeholder so EverythingType.__new__ can tell whether the singleton
# already exists; rebound to the real instance right after the class.
Everything = None


class EverythingType(object):
    """Singleton whose membership test succeeds for every item."""

    def __new__(cls):
        # Reuse the module-level instance once it has been created; only
        # the very first construction allocates a new object.
        existing = Everything
        if existing is not None:
            return existing
        return super().__new__(cls)

    def __contains__(self, item):
        # ``anything in Everything`` is always true.
        return True


Everything = EverythingType()
class Exclusively(text_type):
    """Text-typed marker that wraps a ``request_method`` value.

    The original value is kept on ``self.wrapped``.  When that value is a
    string it also becomes the instance's own text value; otherwise the
    text value is left empty.  Comparisons and membership tests are all
    delegated to ``wrapped`` rather than to the text value.

    NOTE(review): presumably this subclasses the text type so Pyramid's
    ``as_sorted_tuple`` keeps the marker as a single element -- confirm
    against the Pyramid version in use.
    """

    def __new__(cls, wrapped):
        # Only a string can seed the underlying text value; any other
        # wrapped object (e.g. a list of methods) yields an empty text.
        if isinstance(wrapped, string_types):
            self = super().__new__(cls, wrapped)
        else:
            self = super().__new__(cls)
        self.wrapped = wrapped
        return self

    # Overriding __eq__ already clears the inherited hash implicitly; say
    # so explicitly, since ``wrapped`` may be an unhashable list.
    __hash__ = None

    def __eq__(self, that):
        return self.wrapped == that

    def __ne__(self, that):
        # Without this, ``!=`` falls back to the text type's comparison
        # of the text values and can disagree with ``==``.
        return self.wrapped != that

    def __gt__(self, that):
        return self.wrapped > that

    def __ge__(self, that):
        # Kept in step with __gt__/__eq__ instead of the inherited
        # text-value comparison.
        return self.wrapped >= that

    def __lt__(self, that):
        return self.wrapped < that

    def __le__(self, that):
        # Kept in step with __lt__/__eq__ instead of the inherited
        # text-value comparison.
        return self.wrapped <= that

    def __contains__(self, item):
        return item in self.wrapped
def exclusive_request_method_view_deriver(view, info):
    """View deriver that turns an ``Exclusively`` request method into a 405.

    Looks for a ``RequestMethodPredicate`` among ``info.predicates`` whose
    value contains an ``Exclusively`` marker.  If there is none, ``view``
    is returned unchanged.  Otherwise the predicate's value is replaced
    with ``Everything`` (so the predicate itself accepts any method) and
    the returned wrapper performs the method check, raising
    ``HTTPMethodNotAllowed`` on a mismatch.

    :param view: the view callable being derived.
    :param info: deriver info object; only ``info.predicates`` is read.
    :returns: ``view`` itself, or a wrapper enforcing the allowed methods.
    """
    request_method_pred = next(
        (p for p in info.predicates if isinstance(p, RequestMethodPredicate)),
        None,
    )
    if request_method_pred is None:
        return view

    exclusive = next(
        (v for v in request_method_pred.val if isinstance(v, Exclusively)),
        None,
    )
    if exclusive is None:
        return view

    request_methods = as_sorted_tuple(exclusive.wrapped)
    # Allowing GET also allows HEAD.
    if 'GET' in request_methods and 'HEAD' not in request_methods:
        request_methods = as_sorted_tuple(request_methods + ('HEAD',))

    # Neutralise the predicate: method filtering now happens in the
    # wrapper below, which answers with 405 instead of a predicate miss.
    request_method_pred.val = Everything

    # A 405 response is required to advertise the permitted methods.
    allow_header = ', '.join(request_methods)

    def wrapped_view(context, request):
        if request.method not in request_methods:
            raise HTTPMethodNotAllowed(
                text='Predicate mismatch for view %s (request_method = %s)' % (
                    getattr(view, '__name__', view),
                    ','.join(request_methods),
                ),
                headers=[('Allow', allow_header)],
            )
        return view(context, request)

    return wrapped_view
@view_config(route_name='test1', request_method=Exclusively(['GET', 'POST']))
def test_view1(context, request):
    """View for route ``test1``; GET/POST are wrapped in ``Exclusively``."""
    response = Response('HEY')
    return response
@view_config(route_name='test2', request_method=['GET', 'POST'])
def test_view2(context, request):
    """View for route ``test2``; GET/POST given as a plain list."""
    response = Response('HEY')
    return response
# Build the application: register the deriver, two routes, and the
# decorated views found in this module.
c = Configurator()
c.add_view_deriver(exclusive_request_method_view_deriver, under=INGRESS)
for _route_name, _pattern in (('test1', '/1'), ('test2', '/2')):
    c.add_route(_route_name, _pattern)
c.scan(__name__)

t = TestApp(c.make_wsgi_app())

# (verb, path, status handed to WebTest, status code expected).
# '/1' uses the Exclusively marker and is expected to answer PUT with 405;
# '/2' uses the plain predicate and is expected to answer PUT with 404.
_checks = (
    ('get', '/1', None, 200),
    ('post', '/1', None, 200),
    ('put', '/1', 405, 405),
    ('get', '/2', None, 200),
    ('post', '/2', None, 200),
    ('put', '/2', 404, 404),
)
for _verb, _path, _status, _expected in _checks:
    resp = getattr(t, _verb)(_path, status=_status)
    assert resp.status_int == _expected
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment