Skip to content

Instantly share code, notes, and snippets.

@prschmid
Created January 14, 2016 18:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save prschmid/c56e4aa9da58e42b175f to your computer and use it in GitHub Desktop.
Save prschmid/c56e4aa9da58e42b175f to your computer and use it in GitHub Desktop.
Decorator to validate HTTP parameters and submitted JSON for Flask routes
from marshmallow import (
fields,
Schema)
def marshmallow_schema_to_dict(schema):
    """Convert a :class:`marshmallow.Schema` to a dict describing its fields.

    Only the schema's declared fields are described. For each one, the
    result records the field class name and its ``required`` and
    ``allow_none`` flags.

    :param schema: The :class:`marshmallow.Schema` class to convert. Any
        object exposing a ``_declared_fields`` mapping of name -> field works.
    :returns: A dict of the form
        ``{'fields': [{name: {'type': ..., 'required': ...,
        'allow_none': ...}}, ...]}`` with one entry per declared field.
    """
    return {
        'fields': [
            {
                name: {
                    'type': field.__class__.__name__,
                    'required': field.required,
                    'allow_none': field.allow_none,
                }
            }
            # ``.items()`` works on Python 2 and 3; ``iteritems`` is Py2-only.
            for name, field in schema._declared_fields.items()
        ]
    }
def build_schema_from_dict(d, allow_nested=True):
    """Build a Marshmallow schema class based on a dictionary of fields.

    Each value of ``d`` may be one of:

    * a marshmallow field instance (or any other non-dict, non-tuple value),
      which is kept unchanged;
    * a dict, treated as a nested schema definition. It is built recursively
      and wrapped in ``fields.Nested``;
    * a tuple ``(definition, options)``. ``definition`` is either such a dict
      or an existing schema class. ``options`` is a dict of keyword arguments
      for ``fields.Nested`` (e.g. ``required``, ``many``). The ``options``
      element may be omitted.

    Note: ``d`` is modified in place (nested definitions are replaced by
    ``fields.Nested`` instances). Pass a copy if the caller's dict must stay
    untouched.

    :param d: The dict of fields to use to build the Schema
    :param allow_nested: Whether or not nested schemas are allowed. If
                         ``True`` then a fields.Nested() will be created
                         when there is a nested value.
    :raises ValueError: If a nested value is found and ``allow_nested`` is
                        ``False``.
    :return: A new subclass of Marshmallow's ``Schema`` built from ``d``
    """
    # Iterate over a snapshot since entries are reassigned inside the loop.
    for key, value in list(d.items()):
        if isinstance(value, tuple):
            nested = value[0]
            # The options element is optional. Default to no options instead of
            # leaving the name unbound (or stale from a previous iteration).
            options = value[1] if len(value) > 1 else {}
        elif isinstance(value, dict):
            nested, options = value, {}
        else:
            # A plain field: nothing to build
            continue
        if not allow_nested:
            raise ValueError("Nested attributes not allowed.")
        if isinstance(nested, dict):
            # Recursively generate the nested schema(s). Schema classes given
            # directly in a tuple are used as-is.
            nested = build_schema_from_dict(nested)
        # Update the current dict with the Nested schema
        d[key] = fields.Nested(nested, **options)
    return type('Schema', (Schema, ), d)
def ensure(params=None, input=None):
    """Decorator to validate HTTP parameters and submitted JSON.

    Usage::

        from marshmallow import fields, Schema

        # Using this by just defining the attributes and no explicit schema
        @route('/foo')
        @ensure(
            input={
                'bar': fields.Str(required=True),
                'baz': fields.Str(),
            })
        def foo():
            pass

        # Support for options on loading (e.g. loading a list with many=True)
        @route('/foo')
        @ensure(
            input=(
                {
                    'bar': fields.Str(required=True),
                    'baz': fields.Str(),
                },
                {'many': True}))
        def foo():
            pass

        # Support for nested "schemas"
        @route('/foo')
        @ensure(
            input={
                'bar': fields.Str(required=True),
                'baz': {
                    'bam': fields.Str(),
                },
            })
        def foo():
            pass

        # Support for nested "schemas" with options
        @route('/foo')
        @ensure(
            input={
                'bar': fields.Str(required=True),
                'baz': (
                    {'bam': fields.Str()},
                    {'required': True, 'many': True}),
            })
        def foo():
            pass

        # Defining an explicit schema for the validation
        class FooSchema(Schema):
            bar = fields.Str(required=True)
            baz = fields.Str()

        @route('/foo')
        @ensure(input=FooSchema)
        def foo():
            pass

    Each of ``params`` / ``input`` may be a schema class, a dict of fields,
    or a tuple ``(schema_or_dict, options)``. Here ``options`` is passed to
    ``Schema.load()``. The extra option ``'required': True`` makes empty
    loaded data a validation error. Nested definitions are not allowed for
    ``params``.

    On success the loaded data is stored on ``request.params`` (set to
    ``{}`` when no ``params`` schema is given) and on ``request.input``
    (only set when an ``input`` schema is given). On failure the wrapper
    raises ``BadRequestApiError`` carrying the errors and a description of
    the schema that was expected.

    :param params: The input :class:`marshmallow.Schema` or a dict of the fields
                   of the schema to use for the request parameters
    :param input: The input :class:`marshmallow.Schema` or a dict of the fields
                  of the schema to use for the input JSON data
    :returns: A decorator to apply to a route function
    """
    import copy
    import functools
    from collections import namedtuple

    # A simple named tuple to keep track of schemas and their loading options
    SchemaDefinition = namedtuple("SchemaDefinition", ["schema", "options"])

    def to_definition(spec, allow_nested):
        """Normalize a user supplied spec into a SchemaDefinition (or None)."""
        if not spec:
            return None
        if isinstance(spec, tuple):
            schema = spec[0]
            options = spec[1] if len(spec) > 1 else {}
        else:
            schema, options = spec, {}
        if isinstance(schema, dict):
            # Deep copy: build_schema_from_dict rewrites its argument in place
            schema = build_schema_from_dict(
                copy.deepcopy(schema), allow_nested=allow_nested)
        return SchemaDefinition(schema, options)

    # Convert the definitions into Schemas (once, at decoration time).
    # URL parameters are flat, so nested definitions are rejected for them.
    schemas = {
        'params': to_definition(params, allow_nested=False),
        'input': to_definition(input, allow_nested=True),
    }

    def load(args, schema, options=None):
        """Perform the loading of the data into the given schema

        :param args: The arguments provided by the user from the endpoint
        :param schema: The :class:`marshmallow.Schema` class to load the data
                       into
        :param options: A dict of options to pass to the load() method
        :raises ValueError: If 'required' was requested and no data was loaded
        :returns: The ``(data, errors)`` pair from ``Schema.load()``
        """
        # Work on a copy. The options dict is shared by every request, so
        # removing 'required' from it in place would silently disable the
        # check after the first call. 'required' is not a load() option,
        # but we still want to allow a user to validate against it.
        options = dict(options or {})
        required = options.pop('required', False)
        # If we allow many, but only a singleton was provided, convert the
        # input args to a list
        if options.get('many', False) and not isinstance(args, list):
            args = [args]
        data, errors = schema().load(args, **options)
        if required and not data:
            raise ValueError("No data provided")
        return data, errors

    def describe(exc):
        """Return the human readable message of ``exc``."""
        # ``BaseException.message`` only exists on Python 2. On Python 3,
        # reading it inside an except clause would raise AttributeError
        # and turn a 400 into a 500.
        message = getattr(exc, 'message', None)
        return message if message is not None else str(exc)

    def validate(args, definition, label):
        """Load ``args`` with ``definition``; raise BadRequestApiError on failure."""
        try:
            data, errors = load(args, definition.schema, definition.options)
        except (ValueError, ValidationError) as exc:
            errors = {
                'parsing':
                    "Could not validate {}. {}".format(label, describe(exc))
            }
        except Exception:
            # Deliberate catch-all: any loading failure is reported as a 400
            errors = {'parsing': "Could not validate {}".format(label)}
        if errors:
            raise BadRequestApiError(
                message=errors,
                schema=marshmallow_schema_to_dict(definition.schema))
        return data

    def wrap(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Ensure the URL parameters
            if schemas['params']:
                request.params = validate(
                    request.args, schemas['params'], "HTTP arguments")
            else:
                request.params = {}
            # Ensure the input
            if schemas['input']:
                # NOTE(review): presumably rejects non-JSON content types
                require_json_content_type()
                payload = {}
                # Don't fail on requests with no JSON
                try:
                    payload = request.json
                except BadRequest:
                    pass
                request.input = validate(payload, schemas['input'], "input")
            # Ok, do what we came to do
            return f(*args, **kwargs)
        return wrapper
    return wrap
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment