Skip to content

Instantly share code, notes, and snippets.

@flayman
Created July 12, 2019 10:13
Show Gist options
  • Save flayman/e92d88e63e45a0d8b6cd1c41c48dc6c0 to your computer and use it in GitHub Desktop.
Save flayman/e92d88e63e45a0d8b6cd1c41c48dc6c0 to your computer and use it in GitHub Desktop.
My first attempt at giving the flask_restplus SchemaModel limited support for @api.marshal_with
import copy
import re
from six import iteritems
from collections import OrderedDict, MutableMapping
from flask_restplus.model import RawModel
from flask_restplus.fields import *
from flask_restplus.fields import get_value
# NOTE(review): MutableMapping is imported from ``collections`` at the top of
# this file; that alias was removed in Python 3.10 (use ``collections.abc``).
class SchemaModel(RawModel, dict, MutableMapping):
    '''
    Stores API doc metadata based on a json schema.

    Besides keeping the raw schema (``_schema``), the schema's ``properties``
    are converted into flask_restplus fields and stored as dict items
    (property name -> field), which gives the model limited support for
    ``@api.marshal_with``. A schema without ``properties`` is converted as a
    whole and stored under the key ``''``.

    :param str name: The model public name
    :param dict schema: The json schema we are documenting
    :param bool allow_null: ``allow_null`` for every ``Nested`` field built
        from an ``object`` fragment, at any nesting depth
    :param bool skip_none: ``skip_none`` for every ``Nested`` field built
        from an ``object`` fragment, at any nesting depth
    '''
    allow_null = None
    skip_none = None

    class JsonTypeParser(object):
        '''
        Resolves the given json schema fragment into a field derived from fields.Raw

        The converter class is chosen from the fragment's ``type`` keyword
        (``<type>_to_field``); missing or unknown types fall back to
        ``default_to_field``, which builds a plain ``Raw`` field.
        Calling the parser returns the field instance.

        :param object obj: The json schema fragment
        :param bool allow_null: Forwarded to Nested fields, including those
            built for nested child fragments
        :param bool skip_none: Forwarded to Nested fields, including those
            built for nested child fragments
        '''
        def __init__(self, obj, allow_null=False, skip_none=False):
            self.jsonObject = obj
            self.allow_null = allow_null
            self.skip_none = skip_none
            typeName = self._resolve_type_name(self.jsonObject.get('type', 'default'))
            clazz = getattr(self, "{0}_to_field".format(typeName), self.default_to_field)
            self.parser = clazz(self)

        def __call__(self):
            return self.parser()

        @staticmethod
        def _resolve_type_name(json_type):
            '''Reduce a json schema ``type`` value to a single type name.'''
            # "type" may be a list such as ["string", "null"]. With exactly one
            # non-null member, use it; genuine unions fall back to Raw.
            if isinstance(json_type, (list, tuple)):
                candidates = [t for t in json_type if t != 'null']
                return candidates[0] if len(candidates) == 1 else 'default'
            return json_type

        def _child_field(self, obj):
            '''Build the field for a child fragment, inheriting allow_null/skip_none.'''
            # type(self) keeps subclasses of JsonTypeParser in control of recursion.
            return type(self)(obj, self.allow_null, self.skip_none)()

        class default_to_field(object):
            '''Base converter: builds ``fieldClass(*makeArguments(), **makeKeywordArguments())``.'''
            fieldClass = Raw

            def __init__(self, container):
                self.container = container

            def makeArguments(self):
                # Overridden in Nested and List types, which have expected positional parameters.
                return []

            def makeKeywordArguments(self):
                # A throwaway instance is built only to discover which
                # attributes this field class understands.
                try:
                    instance = self.fieldClass(None)
                except MarshallingError:
                    # presumably List, which needs a Raw item type
                    instance = self.fieldClass(Raw)
                except TypeError:
                    # field classes that take no positional argument
                    instance = self.fieldClass()
                # We'll copy the relevant attributes even though they don't really matter (apart from default).
                # Others don't affect marshalling and validation uses the json schema directly.
                # Useful perhaps for reconstructing a schema out of the fields. Fidelity check?
                kwargs = {}
                for key, val in iteritems(self.container.jsonObject):
                    if key == 'multipleOf':
                        key = 'multiple'
                    if hasattr(instance, key):
                        if key.endswith('imum'):
                            # remove that suffix (min, max, exclusiveMin, exclusiveMax are valid)
                            key = key[:-4]
                        kwargs[key] = val
                    else:
                        # convert from camelHump style to underscore_style and try again
                        converted = re.sub("([A-Z])", r'_\1', key).lower()
                        if converted != key and hasattr(instance, converted):
                            kwargs[converted] = val
                return kwargs

            def __call__(self):
                args = self.makeArguments()
                kwargs = self.makeKeywordArguments()
                return self.fieldClass(*args, **kwargs)

        class array_to_field(default_to_field):
            '''``array`` fragments become a List of the converted ``items`` schema.'''
            fieldClass = List

            def makeArguments(self):
                args = super().makeArguments()
                items = self.container.jsonObject.get('items', {})
                if isinstance(items, list):
                    # Tuple validation (one schema per position) has no field
                    # equivalent. Wrapping the per-position fields in Nested
                    # produced a model that marshal cannot iterate, so the
                    # elements are passed through unchanged instead.
                    args.append(Raw())
                else:
                    args.append(self.container._child_field(items))
                return args

        class object_to_field(default_to_field):
            '''``object`` fragments become a Nested field of their converted ``properties``.'''
            fieldClass = Nested

            def makeArguments(self):
                args = super().makeArguments()
                children = self.container.jsonObject.get('properties', {})
                args.append({child_name: self.container._child_field(child)
                             for child_name, child in iteritems(children)})
                return args

            def makeKeywordArguments(self):
                kwargs = super().makeKeywordArguments()
                kwargs['allow_null'] = self.container.allow_null
                kwargs['skip_none'] = self.container.skip_none
                return kwargs

        class string_to_field(default_to_field):
            '''``string`` fragments; some ``format`` values select a dedicated field class.'''
            fieldClass = String

            def date_time(self):
                return DateTime

            def date(self):
                return Date

            def __init__(self, container):
                super().__init__(container)
                fmt = container.jsonObject.get('format', None)
                if fmt:
                    # 'date-time' -> 'date_time', so it matches the method name above
                    fmt = fmt.replace('-', '_')
                    if hasattr(self, fmt):
                        self.fieldClass = getattr(self, fmt)()

        class number_to_field(default_to_field):
            fieldClass = Float

        class integer_to_field(number_to_field):
            fieldClass = Integer

        class boolean_to_field(default_to_field):
            fieldClass = Boolean

        class null_to_field(default_to_field):
            pass

    wrapper = dict
    # Per-instance raw schema, assigned in __init__.
    # NOTE(review): presumably this class attribute shadows a read-only
    # ``_schema`` on the flask_restplus base so the assignment works -- confirm.
    _schema = {}

    def __init__(self, name, schema=None, allow_null=False, skip_none=False, *args, **kwargs):
        self._schema = schema or {}
        self.allow_null = allow_null
        self.skip_none = skip_none
        # Forward extra arguments (e.g. the copied items and ``mask`` passed by
        # __deepcopy__) to the base classes instead of silently dropping them.
        super(SchemaModel, self).__init__(name, *args, **kwargs)
        properties = None
        if isinstance(self._schema, dict):
            properties = self._schema.get('properties', {})
            for prop_name, prop_schema in iteritems(properties):
                self[prop_name] = self.JsonTypeParser(prop_schema, allow_null, skip_none)()
        if not properties:
            # No properties: convert the schema as a whole under the '' key.
            self[''] = self.JsonTypeParser(self._schema, allow_null, skip_none)()

    def __deepcopy__(self, memo):
        copied_items = [(key, copy.deepcopy(value, memo)) for key, value in iteritems(self)]
        # __init__ forwards ``copied_items`` and ``mask`` to the base class.
        obj = self.__class__(self.name, self._schema, self.allow_null, self.skip_none,
                             copied_items, mask=self.__mask__)
        obj.__parents__ = self.__parents__
        return obj

    def __unicode__(self):
        return 'SchemaModel({name},{schema})'.format(name=self.name, schema=self._schema)

    __str__ = __unicode__
# Monkey-patch flask_restplus: point its model module and Namespace.schema_model
# at the SchemaModel above, which provides some limited marshalling support.
import flask_restplus.model as fr_model
setattr(fr_model, 'SchemaModel', SchemaModel)


def redefine_schema_model(self, name=None, schema=None, allow_null=True, skip_none=False):
    '''
    Replacement for ``Namespace.schema_model``: build a SchemaModel from a
    json schema and register it through ``self.add_model``.

    :param str name: The model public name
    :param dict schema: The json schema to document and marshal with
    :param bool allow_null: Forwarded to SchemaModel (defaults to True here)
    :param bool skip_none: Forwarded to SchemaModel
    :return: the result of ``self.add_model`` for the new model
    '''
    schema_model = SchemaModel(name, schema, allow_null=allow_null, skip_none=skip_none)
    return self.add_model(name, schema_model)


import flask_restplus.namespace as fr_namespace
setattr(fr_namespace.Namespace, 'schema_model', redefine_schema_model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment