Last active
June 9, 2022 09:27
-
-
Save piglei/323fb465d2578076e97738e864e5ede8 to your computer and use it in GitHub Desktop.
A Django command that prints a DRF serializer's schema in Markdown table format.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A Django command that prints a DRF serializer's schema in Markdown table format. | |
## Tutorial | |
Make sure "drf_yasg" package is installed, place this file into your Django app's | |
`/management/commands` directory, then run the command below to print the documentation for a serializer: | |
python manage.py slz_table_doc --name ModuleSLZ | |
""" | |
from collections import defaultdict | |
from dataclasses import dataclass | |
from typing import DefaultDict, Dict, Iterable, List, Optional, Type | |
from django.core.management.base import BaseCommand | |
from drf_yasg.inspectors.field import get_basic_type_info | |
from rest_framework.serializers import Serializer | |
# Column headers of the generated Markdown table. They are user-facing text
# (Chinese) and are emitted verbatim by Command.print_doc.
TITLE_NAME = '参数名称'  # "parameter name"
TITLE_TYPE = '参数类型'  # "parameter type"
TITLE_REQUIRED = '必须'  # "required"; only used with --is-request-form
TITLE_DESCRIPTION = '参数说明'  # "parameter description"
class Command(BaseCommand):
    """Management command that prints a serializer's fields as Markdown tables.

    One table is printed for the serializer itself and one for every nested
    serializer collected by ``get_doc_objects``.
    """

    help = "Print a serializer's schema as Markdown table"

    def add_arguments(self, parser):
        """Register the ``--name`` and ``--is-request-form`` command line options."""
        parser.add_argument(
            "-n",
            "--name",
            type=str,
            required=True,
            help="The name of serializer, e.g. ModuleSLZ",
        )
        parser.add_argument(
            "--is-request-form",
            action='store_true',
            help='If True, include "required" and other infos',
        )

    def handle(self, name: str, is_request_form: bool, *args, **options):
        """Locate the serializer class called *name*, instantiate it and print its docs.

        :param name: class name of the serializer to document
        :param is_request_form: when true, the tables get an extra "required" column
        """
        serializer = find_serializer_class(name)()
        doc_objects = get_doc_objects(serializer)
        self.print_doc(doc_objects, is_request_form)

    def print_doc(self, fields: Dict[str, List], is_request_form: bool):
        """Print every group of documented fields as its own Markdown table.

        :param fields: maps an object path ('' for the root serializer) to the
            list of field doc objects that belong to it
        :param is_request_form: when true, insert the "required" column right
            after the name column
        """
        printer = MDTablePrinter()

        headers = [TITLE_NAME, TITLE_TYPE, TITLE_DESCRIPTION]
        if is_request_form:
            headers.insert(1, TITLE_REQUIRED)

        for path, doc_fields in fields.items():
            # The root object is stored under the empty path and gets the generic heading.
            heading = '成员' if not path else f'子成员 `{path}` '
            print(f'{heading}对象各字段说明: \n')

            printer.print_col(headers)
            printer.print_sep(len(headers))
            for item in doc_fields:
                row = [item.name, item.get_type_display(), item.description]
                if is_request_form:
                    row.insert(1, item.get_required_display())
                printer.print_col(row)
            # Blank line so consecutive tables are not glued together.
            print()
def find_serializer_class(name: str) -> Type[Serializer]:
    """Look up a ``Serializer`` subclass by its class name.

    All direct and indirect subclasses of ``Serializer`` reported by
    ``get_subclasses`` are inspected; the first one whose ``__name__`` equals
    *name* wins. Because the lookup relies on ``__subclasses__()``, only
    classes that have already been defined (imported) can be found.

    :param name: exact class name to search for
    :return: the matching serializer class
    :raise: ValueError (carrying *name*) when no Serializer type can be found
    """
    matches = (candidate for candidate in get_subclasses(Serializer) if candidate.__name__ == name)
    found = next(matches, None)
    if found is None:
        raise ValueError(name)
    return found
def get_subclasses(cls: Type) -> Iterable[Type]:
    """Yield every direct and indirect subclass of *cls* exactly once.

    Subclasses are produced depth-first, with the descendants of a class
    yielded before the class itself. With multiple inheritance a class can be
    reached through several parents; it is only reported the first time it is
    encountered.

    :param cls: the base type whose subclass tree is walked
    :return: a lazy iterable of subclasses (``cls`` itself is not included)
    """
    seen = set()

    def _walk(klass: Type) -> Iterable[Type]:
        for subclass in klass.__subclasses__():
            # A class reachable via several bases (diamond inheritance) must
            # not be yielded, nor its subtree re-walked, more than once.
            if subclass in seen:
                continue
            seen.add(subclass)
            yield from _walk(subclass)
            yield subclass

    yield from _walk(cls)
def get_doc_objects(serializer: Serializer) -> Dict[str, list]:
    """Collect the documentation entries of *serializer*, grouped by object path.

    :param serializer: the serializer instance to inspect
    :return: mapping filled by ``traverse_serializer``; the key is the object
        path ('' for the root) and the value the list of its field entries
    """
    collected: DefaultDict[str, list] = defaultdict(list)
    traverse_serializer(serializer, container=collected)
    return collected
def traverse_serializer(serializer: Serializer, container: DefaultDict[str, list], path: str = ''):
    """Traverse a serializer and store the doc details of all its fields in *container*.

    Plain fields are recorded under the current *path*. A field that is itself
    a ``Serializer`` is recorded as a reference entry (without type) and then
    traversed recursively under ``'<path>.<name>'``.

    The parent's entry is appended BEFORE descending into a nested serializer.
    Since *container* keeps insertion order, this guarantees that a parent
    object is always registered ahead of its nested objects, matching the
    "see the object description below" wording of the reference entry.

    NOTE(review): fields declared with ``many=True`` are presumably not
    ``Serializer`` instances and therefore take the plain-field branch without
    being expanded; confirm whether they should be traversed as well.

    :param serializer: the serializer whose ``fields`` are inspected
    :param container: holds all result data, keyed by object path
    :param path: current path, '' means root.
    """
    for name, field in serializer.fields.items():
        if isinstance(field, Serializer):
            child_path = f'{path}.{name}'
            container[path].append(
                FieldDocObj(
                    name=name,
                    required=field.required,
                    type=None,
                    description=f'详见之后的 {child_path} 对象说明',
                )
            )
            traverse_serializer(field, container, path=child_path)
            continue

        desc_texts = []
        # Add label to description when it's not identical with name
        if field.label and not text_fuzzy_equal(field.label, name):
            desc_texts.append(field.label)
        if field.help_text:
            desc_texts.append(field.help_text)

        container[path].append(
            FieldDocObj(
                name=name,
                required=field.required,
                type=get_basic_type_info(field),
                description='; '.join(desc_texts),
            )
        )
def text_fuzzy_equal(s1: str, s2: str) -> bool:
    """Tell whether two strings match, ignoring case and treating ``_`` as a space.

    :param s1: first string
    :param s2: second string
    :return: True when both strings are equal after lower-casing them and
        replacing every underscore with a space
    """
    left, right = (text.lower().replace('_', ' ') for text in (s1, s2))
    return left == right
@dataclass
class FieldDocObj:
    """A simple type that stores the data needed to render one field's documentation."""

    # Field name as declared on the serializer.
    name: str
    # Truthiness decides between 'T' and 'F' in get_required_display().
    required: bool
    # Mapping that may carry 'type' and 'format' keys; None when no type is shown.
    type: Optional[Dict]
    # Free-text description shown in the last table column.
    description: str

    def get_required_display(self) -> str:
        """Return 'T' when the field is required, otherwise 'F'."""
        return 'T' if self.required else 'F'

    def get_type_display(self) -> str:
        """Return a simple formatted string of `type`, e.g. string(date-time).

        :return: '' when no type info is set; otherwise the value of the
            'type' key, followed by the 'format' value in parentheses when a
            non-empty format is present
        """
        if not self.type:
            return ''

        value = self.type.get('type', '')
        # Named `type_format` to avoid shadowing the builtin `format`.
        if type_format := self.type.get('format'):
            value += f'({type_format})'
        return value
class MDTablePrinter:
    """Small helper that writes Markdown table lines to standard output."""

    def print_col(self, data):
        """Print one table row.

        :param data: iterable of cell values; each one is converted with ``str``
        """
        cells = [str(item) for item in data]
        print(f"| {' | '.join(cells)} |")

    def print_sep(self, length: int):
        """Print the header separator line for a table with *length* columns."""
        # length + 1 pipes joined by dashes, e.g. '|---|---|' for two columns.
        pipes = ('|' for _ in range(length + 1))
        print('---'.join(pipes))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment