Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@nettok
Created May 15, 2012 03:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nettok/2698893 to your computer and use it in GitHub Desktop.
WebServices module
import json
from functools import wraps
from collections import namedtuple
from pyramid.request import Request
from pyramid.httpexceptions import (
HTTPException,
HTTPNotFound,
HTTPMethodNotAllowed,
HTTPNotAcceptable
)
from pyramid.exceptions import PredicateMismatch
from webob.multidict import MultiDict, NestedMultiDict
from .data import decode_request_body
from .response import error
class WebService(object):
    """Entry point for declaring REST resources on a Pyramid configurator.

    :param pyramid_configurator: configurator on which routes and views are
        registered (kept as ``self.config``).
    :param default_representations: list of ``(accept, renderer)`` pairs used
        by resources that do not specify their own.  When ``None``, a single
        JSON representation ``('application/json', 'json')`` is used.
    """

    def __init__(self, pyramid_configurator, default_representations=None):
        self.config = pyramid_configurator
        self.reprs = (
            [('application/json', 'json')]
            if default_representations is None
            else default_representations
        )

    def make_resource(self, name, path, default_representations=None):
        """Create a :class:`Resource` named ``name`` at ``path`` on this service.

        ``default_representations`` overrides this service's representations
        for the new resource only.
        """
        return Resource(self, name, path, default_representations)
class Resource(object):
    """A named route together with the views registered on it.

    Views are added per HTTP method with :meth:`get`, :meth:`post`,
    :meth:`put` and :meth:`delete`.  Each view is registered once for every
    ``(accept, renderer)`` representation and wrapped so that request
    initialisation errors and validator errors short-circuit the view.

    :param webservice: the owning :class:`WebService`; its ``config`` is used
        for registration and its ``reprs`` are the default representations.
    :param name: route name passed to ``config.add_route``.
    :param path: route pattern passed to ``config.add_route``.
    :param default_representations: list of ``(accept, renderer)`` pairs for
        this resource; ``None`` means use the web service's ``reprs``.
    """

    def __init__(self, webservice, name, path, default_representations=None):
        self.ws = webservice
        self.name = name
        self.path = path
        if default_representations is None:
            self.reprs = self.ws.reprs
        else:
            self.reprs = default_representations
        self.ws.config.add_route(name, path)

    def _check_request_init_errors(self, view):
        """Wrap ``view`` so requests that already carry errors skip it.

        If ``req.has_errors()`` is true when the request arrives (errors were
        recorded while the request was being built), ``error(req)`` is
        returned instead of calling the view.  This must be the outermost
        wrapper so that it runs before any other wrapper, such as validation.
        """
        @wraps(view)
        def wrapper(req):
            if req.has_errors():
                return error(req)
            return view(req)
        return wrapper

    def _validate(self, view, validators):
        """Wrap ``view`` so that ``validators`` run before it.

        Every validator is called with the request, even if an earlier one
        recorded an error.  If the request has errors afterwards,
        ``error(req, '400 Request validation failed')`` is returned instead of
        calling the view.
        """
        @wraps(view)
        def wrapper(req):
            for validator in validators:
                validator(req)
            if req.has_errors():
                return error(req, '400 Request validation failed')
            return view(req)
        return wrapper

    def _add_view(self, view, request_method, representations=None, validators=None, **kw):
        """Register ``view`` for ``request_method`` on this resource's route.

        :param view: view callable or dotted name (resolved with
            ``config.maybe_dotted``).
        :param request_method: HTTP method predicate, e.g. ``'GET'``.
        :param representations: list of ``(accept, renderer)`` pairs; the
            view is registered once per pair.  Defaults to ``self.reprs``.
        :param validators: a single validator or an iterable of validators
            (callables or dotted names).  Validators listed in the resolved
            view's own ``validators`` attribute run after these.
        :param kw: extra keyword arguments forwarded to ``config.add_view``.
        :raises TypeError: if ``accept`` or ``renderer`` is given in ``kw``;
            they must come from ``representations``.
        """
        for disallow_kw in ('accept', 'renderer'):
            if disallow_kw in kw:
                raise TypeError('`{0}` not allowed as a keyword argument. '
                                'Must be passed as part of `representations`.'.format(disallow_kw))
        maybe_dotted = self.ws.config.maybe_dotted
        view = maybe_dotted(view)

        if validators is None:
            validators = []
        else:
            # Treat a lone validator as a one-element sequence.  Strings are
            # checked explicitly because they are iterable on Python 3; a
            # dotted name would otherwise be split into single characters.
            if isinstance(validators, (str, bytes)) or not hasattr(validators, '__iter__'):
                validators = (validators,)
            # Build a real list: ``map`` is lazy on Python 3, and this list is
            # extended and tested for emptiness below.
            validators = [maybe_dotted(validator) for validator in validators]
        view_validators_attr = getattr(view, 'validators', None)
        if view_validators_attr is not None:
            validators.extend(view_validators_attr)
        if validators:
            view = self._validate(view, validators)
        # make sure this is the outermost wrapper of `view`
        view = self._check_request_init_errors(view)

        if representations is None:
            representations = self.reprs
        for accept, renderer in representations:
            self.ws.config.add_view(
                view,
                route_name=self.name,
                request_method=request_method,
                accept=accept,
                renderer=renderer,
                **kw
            )

    def get(self, view, **kw):
        """Register ``view`` for GET; ``kw`` is passed to :meth:`_add_view`."""
        self._add_view(view, 'GET', **kw)

    def post(self, view, **kw):
        """Register ``view`` for POST; ``kw`` is passed to :meth:`_add_view`."""
        self._add_view(view, 'POST', **kw)

    def put(self, view, **kw):
        """Register ``view`` for PUT; ``kw`` is passed to :meth:`_add_view`."""
        self._add_view(view, 'PUT', **kw)

    def delete(self, view, **kw):
        """Register ``view`` for DELETE; ``kw`` is passed to :meth:`_add_view`."""
        self._add_view(view, 'DELETE', **kw)
class Errors(list):
    """List of request errors, each stored as an ``Errors.Error`` named tuple.

    ``location`` names the part of the request where the problem was found
    (the helper methods use ``'querystring'``, ``'headers'``, ``'body'`` and
    ``'path'``); ``name`` and ``description`` are free-form and may be None.

    :param request: the request these errors belong to (``self.request``).
    :param httpexception: optional HTTP exception associated with the errors
        (``self.httpexception``).
    """

    Error = namedtuple('Error', ['location', 'name', 'description'])

    def __init__(self, request, httpexception=None):
        super(Errors, self).__init__()
        self.httpexception = httpexception
        self.request = request

    def add(self, location, name=None, description=None):
        """Append an error found at ``location``."""
        entry = Errors.Error(location=location, name=name, description=description)
        self.append(entry)

    def add_querystring(self, name, description=None):
        """Append an error located in the query string."""
        self.add('querystring', name, description=description)

    def add_headers(self, name, description=None):
        """Append an error located in the request headers."""
        self.add('headers', name, description=description)

    def add_body(self, name, description=None):
        """Append an error located in the request body."""
        self.add('body', name, description=description)

    def add_path(self, name, description=None):
        """Append an error located in the URL path."""
        self.add('path', name, description=description)
class WSRequest(Request):
    """Pyramid request that decodes its body once, up front, and collects errors.

    During construction a non-empty body is decoded as follows:

    * form bodies (``application/x-www-form-urlencoded`` or
      ``multipart/form-data``) are taken from ``self.POST``;
    * any other content type is passed to ``decode_request_body``;
    * a missing ``Content-Type`` header is recorded as a body error.

    Decoding failures do not propagate out of the constructor.  They are
    recorded in ``self.errors``; an HTTP exception is also stored in
    ``self.errors.httpexception``.

    Attributes:
        decoded_body: result of decoding the body, or ``None``.
        valid: initially empty dict (not used in this class; presumably
            filled by validators -- confirm).
        errors: :class:`Errors` collected for this request.
        DECODED: ``MultiDict`` populated from ``decoded_body``.
    """

    # Content-Type prefixes whose bodies are read from ``self.POST``.
    _FORM_CONTENT_TYPES = (
        'application/x-www-form-urlencoded',
        'multipart/form-data',
    )
    # True when the body was a form, i.e. DECODED mirrors POST.
    _body_is_form = False

    def __init__(self, environ, **kw):
        super(WSRequest, self).__init__(environ, **kw)
        self.decoded_body = None
        self.valid = dict()
        self.errors = Errors(self)
        self.DECODED = MultiDict()
        try:
            if self.body:
                content_type = self.headers.get('content-type', None)
                if content_type:
                    # Media types are case-insensitive, so compare lower-cased.
                    if content_type.lower().startswith(self._FORM_CONTENT_TYPES):
                        self.decoded_body = self.POST
                        self._body_is_form = True
                    else:
                        self.decoded_body = decode_request_body(self)
                    self.DECODED.update(self.decoded_body)
                else:
                    self.errors.add_body('**decoding**', 'Content-Type header not found')
        except Exception as e:
            # Deliberately best-effort: any decoding failure becomes a request
            # error that views can report, instead of escaping construction.
            if isinstance(e, HTTPException):
                self.errors.httpexception = e
            self.errors.add_body('**decoding**', str(e))

    @property
    def params(self):
        """Query-string, form and decoded-body parameters, in that order.

        For form bodies, ``DECODED`` is a copy of ``POST``, so it is left out.
        Including both would list every form field twice: ``getall`` would
        return duplicates, and ``mixed()`` would turn single values into lists.
        """
        if self._body_is_form:
            return NestedMultiDict(self.GET, self.POST)
        return NestedMultiDict(self.GET, self.POST, self.DECODED)

    def has_errors(self):
        """Return True if any error has been recorded for this request."""
        return bool(self.errors)
def notfound(req):
    """Not-found view that distinguishes 404, 405 and 406.

    If no route matched, an ``HTTPNotFound`` is returned.  Otherwise the
    route matched but none of its views did.  The views registered on the
    matched route are inspected through the configuration introspector to
    choose the response:

    * ``HTTPMethodNotAllowed`` (``Allow`` lists the registered methods) if
      no view handles ``req.method``;
    * ``HTTPNotAcceptable`` if views handle the method but none of their
      ``accept`` values matches the request's Accept header.  For non-HEAD
      requests the body is ``error(req, acceptable=[...])`` serialised as
      JSON;
    * otherwise :class:`PredicateMismatch` is raised, because some other
      view predicate must have failed.

    :param req: the current request.
    :returns: an HTTP exception response object.
    :raises PredicateMismatch: when neither 405 nor 406 applies.
    """
    if req.matched_route is None:
        return HTTPNotFound()

    introspector = req.registry.introspector
    route_intr = introspector.get('routes', req.matched_route.name)

    # Map each request method to the set of media types its views accept.
    # ``None`` is used both as a method key (a view without a request_method
    # predicate handles every method) and as a set member (a view without an
    # accept predicate accepts any media type).
    views_info = {}
    for view_intr in introspector.related(route_intr):
        if view_intr.category_name != 'views':
            continue
        methods = view_intr.get('request_methods')
        if methods is None:
            methods = (None,)
        accept = view_intr.get('accept')
        for method in methods:
            views_info.setdefault(method, set()).add(accept)

    if req.method not in views_info and None not in views_info:
        resp = HTTPMethodNotAllowed()
        # All keys are method names here; sort for a stable Allow header.
        resp.allow = sorted(views_info)
        return resp

    accepts = views_info.get(req.method, set()) | views_info.get(None, set())
    if None not in accepts:
        acceptable = sorted(accepts)
        if req.accept.best_match(acceptable) is None:
            resp = HTTPNotAcceptable()
            if req.method != 'HEAD':
                # Response.body must be bytes (a str raises TypeError on Python 3).
                resp.body = json.dumps(error(req, acceptable=acceptable)).encode('utf-8')
                resp.content_type = 'application/json'
            return resp

    # Method and Accept header are both satisfiable, so some other view
    # predicate rejected the request.
    raise PredicateMismatch(req.matched_route.name)
def includeme(config):
    """Pyramid ``config.include`` hook for this module.

    Makes :class:`WSRequest` the request factory and registers
    :func:`notfound` as the not-found view.
    """
    request_factory = WSRequest
    notfound_view = notfound
    config.set_request_factory(request_factory)
    config.add_notfound_view(notfound_view)
################################################################
# Example: Basic usage
#
# Wrapped in a function so that importing this module does not execute it;
# neither ``settings`` nor ``Configurator`` exists at module level.
def example_basic_usage(settings):
    """Build a configurator exposing ``/some_resource`` with GET/POST/DELETE views.

    Assumes this module is importable as ``ws`` (it is passed to
    ``config.include``) and that a ``views`` module defines ``someview``,
    ``someview_post`` and ``someview_delete`` -- adjust to your project.

    :param settings: Pyramid settings dict passed to ``Configurator``.
    :returns: the populated configurator.
    """
    from pyramid.config import Configurator

    config = Configurator(settings=settings)
    config.include('ws')
    ws = WebService(config)
    some_resource = ws.make_resource('some_resource', '/some_resource')
    some_resource.get('views.someview')
    some_resource.post('views.someview_post')
    some_resource.delete('views.someview_delete')
    return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment