Skip to content

Instantly share code, notes, and snippets.

@piglei
Last active June 9, 2022 09:27
Show Gist options
  • Save piglei/323fb465d2578076e97738e864e5ede8 to your computer and use it in GitHub Desktop.
Save piglei/323fb465d2578076e97738e864e5ede8 to your computer and use it in GitHub Desktop.
A Django command that prints a DRF serializer's schema in Markdown table format.
"""
A Django command that prints a DRF serializer's schema in Markdown table format.
## Tutorial
Make sure the "drf_yasg" package is installed, save this file as `slz_table_doc.py` in your Django app's
`management/commands` directory, then run the command below to print the documentation for a serializer:
python manage.py slz_table_doc --name ModuleSLZ
"""
from collections import defaultdict
from dataclasses import dataclass
from typing import DefaultDict, Dict, Iterable, List, Optional, Type

from django.core.management.base import BaseCommand, CommandError
from drf_yasg.inspectors.field import get_basic_type_info
from rest_framework.serializers import ListSerializer, Serializer
# Column titles of the printed Markdown table. The generated docs are in Chinese; in English
# they read "parameter name", "parameter type", "required" and "parameter description".
TITLE_NAME: str = '参数名称'
TITLE_TYPE: str = '参数类型'
TITLE_REQUIRED: str = '必须'
TITLE_DESCRIPTION: str = '参数说明'
class Command(BaseCommand):
    """Management command printing a DRF serializer's fields as Markdown tables."""

    help = "Print a serializer's schema as Markdown table"

    def add_arguments(self, parser):
        parser.add_argument("-n", "--name", type=str, required=True, help="The name of serializer, e.g. ModuleSLZ")
        parser.add_argument(
            "--is-request-form", action='store_true', help='If True, include "required" and other infos'
        )

    def handle(self, name: str, is_request_form: bool, *args, **options):
        """Look up the serializer class called ``name`` and print its documentation tables.

        :param name: class name of the serializer, given by ``--name``
        :param is_request_form: value of ``--is-request-form``; adds a "required" column when True
        :raise CommandError: when no serializer class with the given name can be found
        """
        try:
            slz_cls = find_serializer_class(name)
        except ValueError:
            # Django reports a CommandError as a readable message on stderr instead of a traceback.
            raise CommandError(f'Cannot find a serializer named "{name}"') from None
        self.print_doc(get_doc_objects(slz_cls()), is_request_form)

    def print_doc(self, fields: Dict[str, List], is_request_form: bool):
        """Print serializer fields data as markdown tables, one table per object path.

        :param fields: mapping of object path ('' for the root object) to its documented fields
        :param is_request_form: when True, a "required" column is printed after the name column
        """
        table_printer = MDTablePrinter()
        if is_request_form:
            cols = [TITLE_NAME, TITLE_REQUIRED, TITLE_TYPE, TITLE_DESCRIPTION]
        else:
            cols = [TITLE_NAME, TITLE_TYPE, TITLE_DESCRIPTION]

        for key, value in fields.items():
            # Section heading: "member" for the root object, "child member `<path>`" for nested
            # ones, followed by "field descriptions of the object".
            obj_name = f'子成员 `{key}` ' if key else '成员'
            print(f'{obj_name}对象各字段说明: \n')
            table_printer.print_col(cols)
            table_printer.print_sep(len(cols))
            for doc_field in value:
                if is_request_form:
                    table_printer.print_col(
                        [
                            doc_field.name,
                            doc_field.get_required_display(),
                            doc_field.get_type_display(),
                            doc_field.description,
                        ]
                    )
                else:
                    table_printer.print_col([doc_field.name, doc_field.get_type_display(), doc_field.description])
            print()
def find_serializer_class(name: str) -> Type[Serializer]:
    """Look up a ``Serializer`` subclass by its class name.

    Subclasses are searched recursively (as produced by ``get_subclasses``); the first one whose
    ``__name__`` equals ``name`` is returned.

    :raise: ValueError when no Serializer type can be found by given name
    """
    matched = next(
        (klass for klass in get_subclasses(Serializer) if klass.__name__ == name),
        None,
    )
    if matched is None:
        raise ValueError(name)
    return matched
def get_subclasses(cls: Type) -> Iterable[Type]:
    """Get all subclasses of the given type recursively, yielding each class only once.

    The walk is depth-first and a class's own subclasses are yielded before the class itself.
    A class reachable through several bases (multiple inheritance) would otherwise be found
    repeatedly; it is yielded only the first time it is reached.
    """
    seen = set()

    def _walk(klass: Type) -> Iterable[Type]:
        # Recursive helper sharing ``seen`` so duplicates are skipped across the whole walk.
        for subclass in klass.__subclasses__():
            if subclass in seen:
                continue
            seen.add(subclass)
            yield from _walk(subclass)
            yield subclass

    yield from _walk(cls)
def get_doc_objects(serializer: Serializer) -> Dict[str, list]:
    """Collect documentation objects for every field of ``serializer``.

    :return: the container filled by ``traverse_serializer`` — a defaultdict mapping each object
        path ('' for the root serializer) to the list of field doc objects found at that path
    """
    container: DefaultDict[str, list] = defaultdict(list)
    traverse_serializer(serializer, container)
    return container
def traverse_serializer(serializer: Serializer, container: DefaultDict[str, list], path: str = ''):
    """Traverse a serializer, store all fields doc details into given container.

    Every field becomes a ``FieldDocObj`` appended to ``container[path]``. A nested serializer
    field (including a ``many=True`` one) gets a placeholder entry that points at its own
    section, and its fields are collected recursively under the key ``f'{path}.{name}'``.

    :param serializer: the serializer instance whose ``fields`` are documented
    :param container: holds all result data, keyed by object path
    :param path: current path, '' means root.
    """
    for name, field in serializer.fields.items():
        # A ``many=True`` nested serializer is wrapped in a ListSerializer, which is not a
        # ``Serializer`` subclass; unwrap it so the child object is still documented.
        is_list = isinstance(field, ListSerializer) and isinstance(field.child, Serializer)
        nested = field.child if is_list else field

        if isinstance(nested, Serializer):
            sub_path = f'{path}.{name}'
            # Append the placeholder *before* recursing. This creates the current path's section
            # ahead of the nested one, so the nested section is really printed afterwards, as the
            # description ("详见之后的" = "see ... below") promises.
            container[path].append(
                FieldDocObj(
                    name=name,
                    required=field.required,
                    type={'type': 'array'} if is_list else None,
                    description=f'详见之后的 {sub_path} 对象说明',
                )
            )
            traverse_serializer(nested, container, path=sub_path)
        else:
            desc_texts = []
            # Add label to description when it's not identical with name. Label and help text
            # may be lazy translation objects (not ``str`` instances), which ``str.join``
            # rejects, so convert them to plain strings first.
            if field.label and not text_fuzzy_equal(str(field.label), name):
                desc_texts.append(str(field.label))
            if field.help_text:
                desc_texts.append(str(field.help_text))
            container[path].append(
                FieldDocObj(
                    name=name,
                    required=field.required,
                    type=get_basic_type_info(field),
                    description='; '.join(desc_texts),
                )
            )
def text_fuzzy_equal(s1: str, s2: str) -> bool:
    """Check whether two strings are equal, ignoring case and treating '_' as a space."""

    def normalize(text: str) -> str:
        # Lower-case first, then turn underscores into spaces.
        return text.lower().replace('_', ' ')

    return normalize(s1) == normalize(s2)
@dataclass
class FieldDocObj:
    """A simple type that stores the data needed to render one documentation table row."""

    # Field name as it appears in the serializer.
    name: str
    # Whether the field is required; callers pass DRF's boolean ``Field.required``.
    required: bool
    # Type info dict read by ``get_type_display`` (keys 'type' and optional 'format'),
    # or None when no type should be shown.
    type: Optional[Dict]
    # Free-text description shown in the table.
    description: str

    def get_required_display(self) -> str:
        """Return 'T' when the field is required, otherwise 'F'."""
        return 'T' if self.required else 'F'

    def get_type_display(self) -> str:
        """Return a simple formatted string of `type`, e.g. ``string(date-time)``.

        An empty string is returned when `type` is empty or None.
        """
        if not self.type:
            return ''
        # ``or ''`` guards against an explicit None so the suffix concatenation below is safe.
        value = self.type.get('type') or ''
        if fmt := self.type.get('format'):
            value += f'({fmt})'
        return value
class MDTablePrinter:
    """Helper class for printing markdown table rows to stdout."""

    @staticmethod
    def _escape_cell(value) -> str:
        """Render ``value`` as text that fits inside a single Markdown table cell."""
        text = str(value).replace('|', '\\|')
        # A raw "|" would open a new column and a line break would end the row early, so pipes
        # are escaped above and line breaks become inline <br> tags here.
        return '<br>'.join(text.splitlines())

    def print_col(self, data):
        """Print one table row; each item of ``data`` becomes a cell (converted with ``str``)."""
        ret = '| ' + ' | '.join(self._escape_cell(item) for item in data) + ' |'
        print(ret)

    def print_sep(self, length: int):
        """Print the header separator row for ``length`` columns, e.g. ``|---|---|`` for 2."""
        parts = ['|'] * (length + 1)
        print('---'.join(parts))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment