Created
February 22, 2021 17:59
-
-
Save olegborzov/b9d08c830fdada1bee2e6e1abbbb0032 to your computer and use it in GitHub Desktop.
Marshmallow to Sanic-Openapi converter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Union

from marshmallow import Schema, fields
from sanic_openapi import doc
# Marshmallow field class -> sanic-openapi field class used to document it.
# Field types missing from this table are documented as ``doc.String``.
MAP_MARSHMALLOW_IN_OPEN_API = {
    fields.DateTime: doc.DateTime,
    fields.Date: doc.Date,
    fields.Integer: doc.Integer,
    # Without this entry floats would be documented as strings.
    fields.Float: doc.Float,
    fields.String: doc.String,
    fields.Boolean: doc.Boolean,
    fields.List: doc.List,
}
def open_api_schemas(
    schema: Schema, is_json: bool = False, many: bool = False
) -> Union[doc.Field, dict]:
    """
    Generate a sanic-openapi (swagger) description from a marshmallow schema.

    Every field of ``schema`` is converted with ``marshmallow_to_swagger`` and
    keyed by the name it has in the payload: the field's ``data_key`` when one is
    set (marshmallow 3), otherwise the attribute name.

    Example::

        from typing import Optional

        from marshmallow import Schema, fields
        from sanic import response
        from sanic_openapi import doc

        class DictResponseScheme(doc.Dictionary):
            def __init__(self, result_scheme: doc.Field, description: Optional[str] = None):
                response_fields = {
                    "success": doc.Boolean(required=True),
                    "result": result_scheme,
                }
                super().__init__(response_fields, description=description)

        class CreateSiteSchema(Schema):
            domain = fields.String(required=True, allow_none=False)
            name = fields.String(required=False, allow_none=True)
            page_urls = fields.List(fields.String(), required=False)

        class SiteSchema(Schema):
            class PageSchema(Schema):
                page_id = fields.Integer(required=True, allow_none=False)
                url = fields.String(required=True, allow_none=False)

            site_id = fields.Integer(required=True, allow_none=False)
            domain = fields.String(required=True, allow_none=False)
            name = fields.String(required=False, allow_none=True)
            pages = fields.Nested(PageSchema, many=True, required=False)

        @doc.summary('Add web site')
        @doc.consumes(
            open_api_schemas(CreateSiteSchema(), is_json=True), location='body',
            content_type="application/json", required=True
        )
        @doc.produces(DictResponseScheme(open_api_schemas(SiteSchema())),
                      content_type="application/json")
        async def add_web_site(request):
            data = CreateSiteSchema().load(request.json)
            site = await create_site(data)  # application-specific
            return response.json({
                'success': True,
                'result': SiteSchema().dump(site),
            })

        @doc.summary('Get all web-sites')
        @doc.produces(DictResponseScheme(open_api_schemas(SiteSchema(), many=True)),
                      content_type="application/json")
        async def get_all_sites(request):
            sites = await get_sites()  # application-specific
            return response.json({
                'success': True,
                'result': SiteSchema(many=True).dump(sites),
            })

    :param schema: marshmallow schema instance; a Schema subclass is also accepted
        and instantiated with no arguments
    :param is_json: wrap the result in ``doc.JsonBody`` (ignored when a list is returned)
    :param many: wrap the result in ``doc.List``; this also happens when ``schema.many``
        is set, and takes precedence over ``is_json``
    :return: ``doc.List`` or ``doc.JsonBody`` wrapping the field mapping, or the plain
        ``{name: description}`` dict when no wrapper applies
    """
    if isinstance(schema, type):
        # A Schema class has no ``fields`` until instantiated.
        schema = schema()

    swagger_schema = {}
    for field_name, field_schema in schema.fields.items():
        # Document the key clients actually send/receive, not the Python attribute name.
        key = getattr(field_schema, "data_key", None) or field_name
        swagger_schema[key] = marshmallow_to_swagger(key, field_schema)

    if many or schema.many:
        return doc.List(swagger_schema)
    if is_json:
        return doc.JsonBody(swagger_schema)
    return swagger_schema
def open_api_schemas_params(schema: Schema) -> List[doc.Field]:
    """
    Generate sanic-openapi descriptions of query parameters from a marshmallow schema.

    Each field becomes one parameter, named by its ``data_key`` when one is set
    (marshmallow 3), otherwise by its attribute name. Intended for flat schemas:
    a ``fields.Nested`` field is converted to a nested-schema description rather
    than a scalar parameter.

    Example::

        from marshmallow import Schema, fields
        from sanic import response
        from sanic_openapi import doc

        class FindSitesArgsSchema(Schema):
            site_ids = fields.String(required=True, allow_none=False)
            active = fields.Boolean(required=False, allow_none=False, missing=True)

        @doc.summary('Find web-sites')
        @doc.consumes(*open_api_schemas_params(FindSitesArgsSchema()), location='query',
                      required=True)
        @doc.produces(DictResponseScheme(open_api_schemas(SiteSchema(), many=True)),
                      content_type="application/json")
        async def find_sites(request):
            sites = await get_sites()  # application-specific
            return response.json({
                'success': True,
                'result': SiteSchema(many=True).dump(sites),
            })

    :param schema: marshmallow schema instance; a Schema subclass is also accepted
        and instantiated with no arguments
    :return: list of sanic-openapi field descriptions, one per schema field
    """
    if isinstance(schema, type):
        # A Schema class has no ``fields`` until instantiated.
        schema = schema()

    return [
        # Name the parameter as clients send it, honouring marshmallow 3's data_key.
        marshmallow_to_swagger(getattr(field_schema, "data_key", None) or field_name, field_schema)
        for field_name, field_schema in schema.fields.items()
    ]
def marshmallow_to_swagger(field_name: str, field_schema: fields.Field) -> Union[doc.Field, dict]:
    """
    Convert one marshmallow field into its sanic-openapi description.

    * ``fields.Nested``: the nested schema's description via ``open_api_schemas``.
    * ``fields.List``: a ``doc.List``. Nested-schema and list items are converted
      recursively; scalar items use their mapped ``doc`` type.
    * Anything else: the ``doc`` class mapped in ``MAP_MARSHMALLOW_IN_OPEN_API``
      for the field's class or its nearest mapped base class, else ``doc.String``.

    :param field_name: name under which the field is documented
    :param field_schema: marshmallow field instance
    :return: sanic-openapi field instance, or the description ``open_api_schemas``
        builds (possibly a plain dict) for a nested schema
    """
    if isinstance(field_schema, fields.Nested):
        # NOTE(review): the Nested field's own name/required flags are not carried over.
        return open_api_schemas(field_schema.schema)

    if isinstance(field_schema, fields.List):
        # marshmallow 3 keeps the item field in ``inner``; marshmallow 2 used ``container``.
        inner = getattr(field_schema, "inner", None)
        if inner is None:
            inner = getattr(field_schema, "container", None)
        if isinstance(inner, (fields.Nested, fields.List)):
            # Compound items need a full description, not just a scalar type.
            items = marshmallow_to_swagger(field_name, inner)
        else:
            items = _lookup_doc_type(inner)
        return doc.List(items, name=field_name, required=field_schema.required)

    doc_type = _lookup_doc_type(field_schema)
    return doc_type(name=field_name, required=field_schema.required)


def _lookup_doc_type(field_schema: fields.Field) -> type:
    """
    Return the sanic-openapi field class mapped to ``field_schema``'s class.

    Walks the class MRO so that subclasses of mapped marshmallow fields (e.g. a
    custom ``DateTime`` subclass) reuse their base's mapping. Falls back to
    ``doc.String`` when nothing in the MRO is mapped.
    """
    for klass in type(field_schema).__mro__:
        doc_type = MAP_MARSHMALLOW_IN_OPEN_API.get(klass)
        if doc_type is not None:
            return doc_type
    return doc.String
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment