Skip to content

Instantly share code, notes, and snippets.

@aweffr
Created September 30, 2020 08:15
Show Gist options
  • Save aweffr/719f08ce96b7244e43c8c1183636a687 to your computer and use it in GitHub Desktop.
Save aweffr/719f08ce96b7244e43c8c1183636a687 to your computer and use it in GitHub Desktop.
flask 装饰器: 自动把 query / body 里的参数转到 view_function 声明的参数里
import inspect
import logging
import types
from functools import wraps
from inspect import Parameter
from typing import Optional, Union

from flask import Flask, request, jsonify
# Sentinel ``inspect`` uses for "no annotation" / "no default value".
# ``Parameter.empty`` is the public, documented alias of the private
# ``inspect._empty`` (it is the very same object), so no protected access
# or IDE suppression is needed.
empty = Parameter.empty
class InjectParamsError(Exception):
    """Raised when a function cannot be used with ``inject_params``.

    ``_check_can_inject_params`` raises it when a parameter has no type
    annotation.
    """
class UnsupportedParamType(Exception):
    """Raised by ``_extract_value`` when a request value is present but the
    parameter's annotation is not a type it knows how to convert to."""
class ParamIsRequiredException(Exception):
    """Raised by ``_extract_value`` when a required parameter is missing
    from the request; ``inject_params`` collects these into a 400 reply."""
def _check_can_inject_params(func):
    """Ensure every parameter of ``func`` carries a type annotation.

    ``inject_params`` relies on the annotation to convert request values,
    so an unannotated parameter cannot be injected.

    Raises:
        InjectParamsError: if any parameter of ``func`` is unannotated.
    """
    declared = inspect.signature(func).parameters.values()
    if any(p.annotation is empty for p in declared):
        # Message: "functions using inject_params must annotate all parameter types!"
        raise InjectParamsError('使用 inject_params 的函数必须标注所有参数类型!')
def _is_union_none(param: Parameter):
    """Return True if ``param`` is annotated as optional.

    Optional means a union that includes ``None``: ``Optional[X]``,
    ``Union[X, None]`` or, on Python 3.10+, ``X | None``.

    The check inspects the annotation's union origin and member types rather
    than its ``str()``. Since Python 3.9, ``str(Optional[X])`` reads
    ``'typing.Optional[X]'``, so the old ``'typing.Union'`` text match missed
    it. The text match could also fire on unrelated generics that merely
    contain a ``None`` union, such as ``Dict[str, Union[int, None]]``.
    """
    annotation = param.annotation
    pep604_union = getattr(types, 'UnionType', None)  # ``X | Y``, Python 3.10+
    is_union = getattr(annotation, '__origin__', None) is Union or (
        pep604_union is not None and isinstance(annotation, pep604_union)
    )
    return is_union and type(None) in getattr(annotation, '__args__', ())
def _is_required(param: Parameter):
    """Return True when ``param`` must be present in the request.

    A parameter is optional when its annotation is a union with ``None``
    (as detected by ``_is_union_none``) or when it declares a default value;
    otherwise it is required.
    """
    return not (_is_union_none(param) or param.default is not empty)
def _extract_value(param: Parameter, value_from_request: dict):
    """Return the value to inject for ``param``.

    Args:
        param: the view-function parameter being filled.
        value_from_request: raw values collected from the request, keyed by
            parameter name.

    When ``param.name`` is present, the raw value is converted to the
    annotated type:

    - ``int``, ``float``, ``str`` and ``dict`` annotations are called on the
      value directly.
    - For an optional annotation (``Optional[X]``), a ``None`` value stays
      ``None``. Anything else is converted with the first non-``None`` member
      of the union.

    When the name is absent, the parameter's default is returned, or ``None``
    for an optional parameter without a default.

    Raises:
        ParamIsRequiredException: the value is absent and the parameter is
            required (see ``_is_required``).
        UnsupportedParamType: the value is present but the annotation is not
            one of the types listed above.
        ValueError, TypeError: propagated from the conversion itself, e.g.
            ``int('abc')``.
    """
    name = param.name
    if name not in value_from_request:
        if _is_required(param):
            raise ParamIsRequiredException(f'params {param.name} is required!')
        # Not required: either a default exists or the annotation is optional.
        return param.default if param.default is not empty else None

    raw_value = value_from_request[name]
    annotation = param.annotation
    if annotation in (int, float, str, dict):
        return annotation(raw_value)
    if _is_union_none(param):
        if raw_value is None:
            return None
        # Skip NoneType wherever it appears in the union, so Union[None, int]
        # converts with int instead of calling NoneType(value).
        none_type = type(None)
        target_type = next(a for a in annotation.__args__ if a is not none_type)
        return target_type(raw_value)
    raise UnsupportedParamType('不支持的类型标注:' + str(param.annotation))
def inject_params(func):
    """Decorate a Flask view so its parameters are filled from the request.

    Each parameter of ``func`` that was not already passed as a keyword
    argument is looked up by name. Flask passes URL-rule variables such as
    ``<int:id>`` as keyword arguments, so those are left alone. The lookup
    uses the query string / form data (``request.values``) and, for JSON
    requests, the JSON object body, which takes precedence over the query
    string. ``_extract_value`` converts the value to the annotated type, or
    falls back to the parameter's default.

    Missing required parameters and values that fail type conversion are
    reported together in one JSON response with status 400. The view is not
    called in that case.

    Raises:
        InjectParamsError: at decoration time, if a parameter of ``func``
            has no type annotation.
    """
    # The signature is fixed once ``func`` is defined, so validate and
    # inspect it once here instead of on every request.
    _check_can_inject_params(func)
    func_params = inspect.signature(func).parameters
    logger = logging.getLogger(__name__)

    @wraps(func)
    def wrapper(*args, **kwargs):
        # to_dict() keeps the first value of each multi-valued query/form key.
        values_from_request = request.values.to_dict()
        if request.is_json:
            body = request.get_json()
            # Only a JSON object maps names to values; ignore arrays/scalars.
            if isinstance(body, dict):
                values_from_request.update(body)
        err_msgs = []
        for key, param in func_params.items():
            if key in kwargs:
                # Already supplied (e.g. a URL-rule variable, converted by
                # Flask); a same-named query/body value must not override it.
                continue
            if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                # *args / **kwargs cannot be filled from a single named value.
                continue
            try:
                kwargs[key] = _extract_value(param, values_from_request)
            except ParamIsRequiredException as e:
                err_msgs.append(str(e))
            except (TypeError, ValueError) as e:
                # e.g. int('abc'): a client error, not a server error.
                err_msgs.append(f'params {key} is invalid: {e}')
            else:
                logger.debug('extract %s as %r', key, kwargs[key])
        if err_msgs:
            return jsonify({'msg': '\n'.join(err_msgs), 'code': 400}), 400
        return func(*args, **kwargs)

    return wrapper
if __name__ == '__main__':
    # Demo server: GET/POST /items/<id>?author=...&tag=...&search=...
    app = Flask(__name__)

    @app.route('/items/<int:id>', methods=['GET', 'POST'])
    @inject_params
    def test_view(id: int, author: str, tag: Optional[str],
                  search: str = '', limit: int = 0, offset: int = 10):
        # NOTE(review): limit=0 / offset=10 look swapped for a paging API
        # (usually limit=10, offset=0) — confirm intent.
        payload = dict(
            id=id,
            author=author,
            tag=tag,
            search=search,
            limit=limit,
            offset=offset,
        )
        return jsonify(payload)

    # debug=True enables the Werkzeug debugger; keep this to local use.
    app.run(debug=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment