Skip to content

Instantly share code, notes, and snippets.

@sergeyenin
Forked from nonamenix/hug_swagger.py
Created May 23, 2018 14:32
Show Gist options
  • Save sergeyenin/fbf5ed3653c4ff99bb526714b6f33054 to your computer and use it in GitHub Desktop.
Save sergeyenin/fbf5ed3653c4ff99bb526714b6f33054 to your computer and use it in GitHub Desktop.
Swagger specification for hug. Ugly draft.
import inspect
import collections
from collections import OrderedDict
import hug
import logging
from apispec import APISpec
from apispec.ext.marshmallow.swagger import field2parameter
from copy import copy
from defaultsettings import DefaultSettings
import importlib
from marshmallow import fields, Schema
from marshmallow.schema import SchemaMeta
from hug_extensions.hug_swagger.testingschemas import TestingSchema
from . import swagger
logger = logging.getLogger(__name__)
# Default configuration values for the generated Swagger document.
class Settings(DefaultSettings):
    # Passed as ``host`` to APISpec in swagger_json.
    HOST = 'localhost:9001'
    # Passed as ``schemes`` to APISpec in swagger_json.
    SCHEMES = ['http']
    # Passed as ``version`` to APISpec in swagger_json.
    VERSION = '0.1'
    # Passed as ``title`` to APISpec in swagger_json.
    TITLE = 'Swagger Application'
    # Optional dotted module path. When set, swagger_json imports that module
    # and registers every attribute named ``<Something>Schema`` as a definition.
    DEFINITIONS_PATH = None
    # When truthy, the demo endpoints guarded by ``if settings.TESTING_MODE``
    # further down in this module are registered.
    TESTING_MODE = False


# NOTE(review): 'SWAGGER_' is presumably the prefix DefaultSettings uses to
# look up overrides - confirm against the defaultsettings package.
settings = Settings('SWAGGER_')
# Only the configured instance is used from here on; remove the class name
# from the module namespace.
del Settings
def get_summary(description):
    """Return the first non-blank line of *description*, stripped.

    The original ``split('\\n')[0]`` produced an empty summary for text that
    opens with a newline (a common docstring layout) and left a trailing
    ``\\r`` on CRLF text. Blank leading lines are now skipped and surrounding
    whitespace is removed.

    :param description: free-form text, typically an endpoint description;
        ``None`` and ``''`` are accepted.
    :return: the first line that contains non-whitespace characters, or
        ``''`` when there is none.
    """
    for line in (description or '').splitlines():
        stripped = line.strip()
        if stripped:
            return stripped
    return ''
def where_is_parameter(name, url):
    """Return the Swagger location (``'path'`` or ``'query'``) of a parameter.

    A parameter is a path parameter when the URL template contains a
    placeholder for it, either bare (``{name}``) or with a converter suffix
    (``{name:int}``). Anything else is reported as a query parameter.

    :param name: parameter name as it appears in the handler signature.
    :param url: URL template of the route.
    :return: ``'path'`` or ``'query'``.
    """
    # TODO: body, header
    placeholders = ('{%s}' % name, '{%s:' % name)
    return 'path' if any(token in url for token in placeholders) else 'query'
def get_parameters(url, interface):
    """Build Swagger parameter objects for one hug interface.

    Every name in ``interface.parameters`` is looked up in the signature of
    the wrapped handler (``interface.interface.spec``) and handled by its
    annotation:

    * annotations with a truthy ``directive`` attribute are skipped;
    * unannotated parameters are ignored;
    * ``marshmallow.fields.Field`` instances become path/query parameters
      through ``field2parameter``;
    * a parameter named ``body`` annotated with a marshmallow schema (class
      or instance) becomes a body parameter referencing
      ``#/definitions/<SchemaName>``;
    * anything else is logged as an error and left out.

    :param url: URL template of the route, used to tell path from query.
    :param interface: hug interface exposing ``defaults``, ``parameters`` and
        ``interface.spec``.
    :return: dict mapping parameter name to its Swagger parameter dict.
    """
    defaults = interface.defaults
    signature = inspect.signature(interface.interface.spec)
    parameters = {}

    for name in interface.parameters:
        annotation = signature.parameters[name].annotation

        if getattr(annotation, 'directive', False):
            logger.info('Skip directive: %s for url: %s ', name, url)
            continue

        if annotation is inspect.Parameter.empty:
            # No type annotation: nothing to describe for this parameter.
            continue

        if isinstance(annotation, fields.Field):
            # path and query
            location = where_is_parameter(name, url)
            # Merge into the field's existing metadata instead of replacing
            # it, so anything the author attached to the field is kept.
            merged_metadata = dict(getattr(annotation, 'metadata', None) or {})
            merged_metadata['location'] = location
            annotation.metadata = merged_metadata
            annotation.required = name not in defaults

            parameter = field2parameter(annotation, name=name, default_in=location)
            if name in defaults:
                parameter['default'] = defaults[name]
            parameters[name] = parameter
        elif name == 'body' and isinstance(annotation, (Schema, SchemaMeta)):
            # body: the annotation is either a schema instance or a schema class
            if isinstance(annotation, Schema):
                schema_name = type(annotation).__name__
            else:
                schema_name = annotation.__name__
            parameters['body'] = {
                "in": "body",
                "name": "body",
                "required": True,
                "schema": {"$ref": "#/definitions/{}".format(schema_name)},
            }
        else:
            logger.error('Use marshmallow fields in url: %s instead of hug: %s %s', url, name, annotation)

    return parameters
def get_operation_and_define_response_schemas(interface, spec):
    """Collect the responses of a handler and register their schemas.

    Responses come from the handler's ``swagger_responses`` attribute (if
    any). A return annotation on the handler, when present, is used as the
    schema of the ``200`` response. Each ``schema`` entry is then replaced by
    a ``{'$ref': '#/definitions/<name>'}`` reference:

    * a ``str`` is taken as the name of an already known definition;
    * a marshmallow schema instance or class is registered on *spec* under
      its class name;
    * any other value is logged as an error and left untouched.

    The handler's own ``swagger_responses`` mapping and the response dicts in
    it are never modified: the original shallow ``copy`` let the ``200``
    override write into the handler's dict, so every response is copied
    before anything is changed.

    :param interface: hug interface; the handler is ``interface.interface.spec``.
    :param spec: object exposing ``definition(name, schema=...)``.
    :return: mapping of status code to response dict.
    """
    handler = interface.interface.spec
    return_schema = inspect.signature(handler).return_annotation

    responses = copy(getattr(handler, 'swagger_responses', OrderedDict()))
    # Detach every response dict from the handler before mutating anything.
    for code in list(responses):
        responses[code] = copy(responses[code])

    if return_schema is not inspect.Signature.empty:
        responses.setdefault(200, {})['schema'] = return_schema

    for response in responses.values():
        if 'schema' not in response:
            continue
        schema = response['schema']
        if isinstance(schema, str):  # schema name provided
            name = schema
        elif isinstance(schema, Schema):  # schema instance provided
            name = type(schema).__name__
            spec.definition(name, schema=schema)
        elif isinstance(schema, SchemaMeta):  # schema class provided
            name = schema.__name__
            spec.definition(name, schema=schema())
        else:
            logger.error('Wrong response schema %s', schema)
            continue
        response['schema'] = {'$ref': '#/definitions/{}'.format(name)}

    return responses
# Serve the Swagger document describing every route registered on the hug API.
#
# The document is rebuilt on each request: an APISpec is created from the
# module settings, optional schema definitions are imported from
# settings.DEFINITIONS_PATH, and one path entry is added per
# (url, method, version) combination found in hug_api.http.routes[''].
#
# NOTE(review): documented with comments rather than a docstring on purpose.
# This handler is itself one of the routes it walks, and the 'usage' text read
# below presumably comes from handler docstrings, so adding one would change
# the generated document - confirm before converting this to a docstring.
@hug.get('/swagger.json')
def swagger_json(hug_api):
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Imported locally so this block stays self-contained.
    from collections.abc import Iterable

    spec = APISpec(
        title=settings.TITLE,
        version=settings.VERSION,
        plugins=(
            'apispec.ext.marshmallow',
        ),
        schemes=settings.SCHEMES,
        host=settings.HOST
    )

    if settings.DEFINITIONS_PATH is not None:
        definitions = importlib.import_module(settings.DEFINITIONS_PATH)
        for name, schema in definitions.__dict__.items():  # type: str, Schema
            # Register ``FooSchema``-style names; a bare ``Schema`` is skipped.
            if name.endswith('Schema') and len(name) > len('Schema'):
                spec.definition(name, schema=schema)

    routes = hug_api.http.routes['']
    for url, route in routes.items():
        for method, versioned_interfaces in route.items():
            for versions, interface in versioned_interfaces.items():
                methods_data = {}

                documentation = interface.documentation()
                methods_data['content_type'] = documentation['outputs']['content_type']
                try:
                    methods_data['summary'] = get_summary(documentation['usage'])
                    methods_data['description'] = documentation['usage']
                except KeyError:
                    # No usage text for this interface: omit summary/description.
                    pass

                parameters = get_parameters(url, interface)
                if parameters:
                    methods_data['parameters'] = parameters
                    for name, parameter in parameters.items():
                        spec.add_parameter(name, parameter['in'], **parameter)

                responses = get_operation_and_define_response_schemas(interface, spec)
                if responses:
                    methods_data['responses'] = responses

                # The key may be a single version or a collection of versions.
                if not isinstance(versions, Iterable):
                    versions = [versions]

                for version in versions:
                    # Falsy versions are published without the /v<N> prefix.
                    versioned_url = '/v{}{}'.format(version, url) if version else url
                    spec.add_path(versioned_url, operations={
                        method.lower(): methods_data
                    })

    return spec.to_dict()
# Demo endpoints used to exercise the Swagger generation. They are registered
# only when settings.TESTING_MODE is truthy.
#
# The handlers carry no new docstrings: their existing docstring text is kept
# exactly as it was, because swagger_json reads a 'usage' text per interface.
if settings.TESTING_MODE:

    # Path parameters annotated with hug types instead of marshmallow fields.
    @hug.get('/swagger/hug/{hug_types_number}/{hug_types_greater_than_5}/')
    def openapi_test(
            request,
            hug_timer,
            hug_types_number: hug.types.number,
            hug_types_greater_than_5: hug.types.GreaterThan(5)):
        """Endpoint with hug.types\nNot versioned api method"""
        return dict(
            hug_types_number=hug_types_number,
            hug_types_greater_than_5=hug_types_greater_than_5,
        )

    # Versioned endpoint with one path parameter and one defaulted parameter,
    # both annotated with marshmallow fields.
    @hug.get('/swagger/marshmallow/{swagger_types_number}/', versions=[2, 3])
    @swagger.response(200, description='Good response', schema=TestingSchema)
    @swagger.response(400, description='Bad response')
    def openapi_test_swagger_types(
            request,
            hug_timer,
            swagger_types_number: fields.Integer(),
            swagger_types_number_in_query: fields.Integer() = 3) -> TestingSchema():
        """Endpoint with marshmallow types"""
        return dict(
            swagger_types_number=swagger_types_number,
            swagger_types_number_in_query=swagger_types_number_in_query,
        )

    # Body parameter annotated with a schema instance; echoes the body back.
    @hug.post('/swagger/marshmallow/post-body')  # TODO: check with last slash
    @swagger.response(200, description='Created', schema=TestingSchema())
    def openapi_post_body(body: TestingSchema()) -> TestingSchema():
        return body
from marshmallow import Schema, fields
# Sample schema declaring one field of several marshmallow field types.
# It is not referenced by any other code visible in this file.
class TestingFieldsSchema(Schema):
    integer = fields.Integer()
    float = fields.Float()
    boolean = fields.Boolean()
    datetime = fields.DateTime()
    timedelta = fields.TimeDelta()
    dictionary = fields.Dict()
    url = fields.Url()
    email = fields.Email()
# Schema used by the demo endpoints as a response schema and as the POST body
# annotation. The first two field names match the keys returned by the
# ``openapi_test`` demo handler.
class TestingSchema(Schema):
    hug_types_number = fields.Integer()
    hug_types_greater_than_5 = fields.Integer()
    hug_types_in_range_1_5 = fields.Integer()
@jshwelz
Copy link

jshwelz commented Jan 29, 2019

Did it work?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment