
@lmmx
Last active June 4, 2023 10:54
GPT-4 explains datamodel-code-generator

I'm working with a library called datamodel-code-generator, which has a command, datamodel-codegen, that exposes its internal generate function as a command-line tool. The library has two main supported formats, jsonschema and openapi, each with its own parser. After the 'raw' parsing done by these format-specific parsers, the base parser runs a series of intricate routines that I find quite inscrutable, and I was wondering if you could help me understand them if I show them to you. These routines have method names beginning with two underscores (I call these 'private' methods). Please read the following code and explain it in summary to me.

def parse(
    self,
    with_import: Optional[bool] = True,
    format_: Optional[bool] = True,
    settings_path: Optional[Path] = None,
) -> Union[str, Dict[Tuple[str, ...], Result]]:
    self.parse_raw()

    if with_import:
        if self.target_python_version != PythonVersion.PY_36:
            self.imports.append(IMPORT_ANNOTATIONS)

    if format_:
        code_formatter: Optional[CodeFormatter] = CodeFormatter(
            self.target_python_version,
            settings_path,
            self.wrap_string_literal,
            skip_string_normalization=not self.use_double_quotes,
        )
    else:
        code_formatter = None

    _, sorted_data_models, require_update_action_models = sort_data_models(
        self.results
    )

    results: Dict[Tuple[str, ...], Result] = {}

    def module_key(data_model: DataModel) -> Tuple[str, ...]:
        return tuple(data_model.module_path)

    # process in reverse order to correctly establish module levels
    grouped_models = groupby(
        sorted(sorted_data_models.values(), key=module_key, reverse=True),
        key=module_key,
    )

    module_models: List[Tuple[Tuple[str, ...], List[DataModel]]] = []
    unused_models: List[DataModel] = []
    model_to_models: Dict[DataModel, List[DataModel]] = {}

    previous_module = ()  # type: Tuple[str, ...]
    for module, models in (
        (k, [*v]) for k, v in grouped_models
    ):  # type: Tuple[str, ...], List[DataModel]

        for model in models:
            model_to_models[model] = models
        self.__delete_duplicate_models(models)
        self.__replace_duplicate_name_in_module(models)
        if len(previous_module) - len(module) > 1:
            for parts in range(len(previous_module) - 1, len(module), -1):
                module_models.append(
                    (
                        previous_module[:parts],
                        [],
                    )
                )
        module_models.append(
            (
                module,
                models,
            )
        )
        previous_module = module

    class Processed(NamedTuple):
        module: Tuple[str, ...]
        models: List[DataModel]
        init: bool
        imports: Imports

    processed_models: List[Processed] = []
    for module, models in module_models:
        init = False
        if module:
            parent = (*module[:-1], '__init__.py')
            if parent not in results:
                results[parent] = Result(body='')
            if (*module, '__init__.py') in results:
                module = (*module, '__init__.py')
                init = True
            else:
                module = (*module[:-1], f'{module[-1]}.py')
        else:
            module = ('__init__.py',)

        imports = Imports()
        scoped_model_resolver = ModelResolver()

        self.__change_from_import(models, imports, scoped_model_resolver, init)
        self.__extract_inherited_enum(models)
        self.__set_reference_default_value_to_field(models)
        self.__reuse_model(models, require_update_action_models)
        self.__collapse_root_models(models, unused_models)
        self.__set_default_enum_member(models)
        self.__override_required_field(models)
        self.__sort_models(models, imports)
        self.__set_one_literal_on_default(models)

        processed_models.append(Processed(module, models, init, imports))

    for unused_model in unused_models:
        if unused_model in model_to_models[unused_model]:  # pragma: no cover
            model_to_models[unused_model].remove(unused_model)

    for module, models, init, imports in processed_models:
        result: List[str] = []
        if with_import:
            result += [str(self.imports), str(imports), '\n']

        code = dump_templates(models)
        result += [code]

        if self.dump_resolve_reference_action is not None:
            result += [
                '\n',
                self.dump_resolve_reference_action(
                    m.reference.short_name
                    for m in models
                    if m.path in require_update_action_models
                ),
            ]

        body = '\n'.join(result)
        if code_formatter:
            body = code_formatter.format_code(body)

        results[module] = Result(
            body=body, source=models[0].file_path if models else None
        )

    # retain existing behaviour
    if [*results] == [('__init__.py',)]:
        return results[('__init__.py',)].body

    return results
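The module bookkeeping in parse() (the reverse sort and groupby, the loop that fills in skipped package levels, and the choice between __init__.py and <name>.py) can be isolated into a small function over plain tuples. This is a simplified sketch, not code from datamodel-code-generator: plan_output_files is a hypothetical name, and models, imports and formatting are left out.

```python
from itertools import groupby
from typing import Dict, List, Set, Tuple

Module = Tuple[str, ...]


def plan_output_files(module_paths: List[Module]) -> Dict[Module, Module]:
    """Map each module path (like DataModel.module_path) to its output file.

    Hypothetical, simplified sketch of the bookkeeping in parse().
    """
    # Deduplicate and reverse-sort, so deeper modules come before their parents.
    ordered = [key for key, _ in groupby(sorted(module_paths, reverse=True))]

    # When the depth drops by more than one level, insert the skipped
    # intermediate packages as modules with no models of their own.
    modules: List[Module] = []
    previous: Module = ()
    for module in ordered:
        if len(previous) - len(module) > 1:
            for parts in range(len(previous) - 1, len(module), -1):
                modules.append(previous[:parts])
        modules.append(module)
        previous = module

    # Mirrors the keys parse() puts into `results` before rendering.
    registered: Set[Module] = set()
    files: Dict[Module, Module] = {}
    for module in modules:
        if not module:
            files[module] = ("__init__.py",)
            continue
        # Every non-root module turns its parent into a package.
        registered.add((*module[:-1], "__init__.py"))
        if (*module, "__init__.py") in registered:
            # A deeper module was processed first, so this one is a package.
            files[module] = (*module, "__init__.py")
        else:
            files[module] = (*module[:-1], f"{module[-1]}.py")
    return files


print(plan_output_files([("a", "b", "c"), ()]))
```

With a model at a.b.c and one at the root, the gap-filling loop adds empty modules ("a", "b") and ("a",); because a.b.c was seen first, both become packages, so the output files are a/b/c.py, a/b/__init__.py, a/__init__.py and the root __init__.py.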

I'll now show you some of the private methods; I've left out the ones I'm not interested in. Please read the following code and explain what these methods do in clear detail. I am particularly interested in anything in the code that deals with inheritance of a base class (whose fields are then implicitly part of the subclass model).

def __reuse_model(
    self, models: List[DataModel], require_update_action_models: List[str]
) -> None:
    if not self.reuse_model:
        return None
    model_cache: Dict[Tuple[str, ...], Reference] = {}
    duplicates = []
    for model in models[:]:
        model_key = tuple(
            to_hashable(v) for v in (model.render(class_name='M'), model.imports)
        )
        cached_model_reference = model_cache.get(model_key)
        if cached_model_reference:
            if isinstance(model, Enum):
                for child in model.reference.children[:]:
                    # child is resolved data_type by reference
                    data_model = get_most_of_parent(child)
                    # TODO: replace reference in all modules
                    if data_model in models:  # pragma: no cover
                        child.replace_reference(cached_model_reference)
                duplicates.append(model)
            else:
                index = models.index(model)
                inherited_model = model.__class__(
                    fields=[],
                    base_classes=[cached_model_reference],
                    description=model.description,
                    reference=Reference(
                        name=model.name,
                        path=model.reference.path + '/reuse',
                    ),
                )
                if cached_model_reference.path in require_update_action_models:
                    require_update_action_models.append(inherited_model.path)
                models.insert(index, inherited_model)
                models.remove(model)

        else:
            model_cache[model_key] = model.reference

    for duplicate in duplicates:
        models.remove(duplicate)

def __collapse_root_models(
    self, models: List[DataModel], unused_models: List[DataModel]
) -> None:
    if not self.collapse_root_models:
        return None
    for model in models:
        for model_field in model.fields:
            for data_type in model_field.data_type.all_data_types:
                reference = data_type.reference
                if not reference or not isinstance(
                    reference.source, self.data_model_root_type
                ):
                    continue

                # Use root-type as model_field type
                root_type_model = reference.source
                root_type_field = root_type_model.fields[0]

                if (
                    self.field_constraints
                    and isinstance(root_type_field.constraints, ConstraintsBase)
                    and root_type_field.constraints.has_constraints
                    and any(
                        d
                        for d in model_field.data_type.all_data_types
                        if d.is_dict or d.is_list or d.is_union
                    )
                ):
                    continue

                # set copied data_type
                copied_data_type = root_type_field.data_type.copy()
                if isinstance(data_type.parent, self.data_model_field_type):
                    # for field
                    # override empty field by root-type field
                    model_field.extras = dict(
                        root_type_field.extras, **model_field.extras
                    )
                    if self.field_constraints:
                        if isinstance(
                            root_type_field.constraints, ConstraintsBase
                        ):  # pragma: no cover
                            model_field.constraints = root_type_field.constraints.copy(
                                update={
                                    k: v
                                    for k, v in model_field.constraints.dict().items()
                                    if v is not None
                                }
                                if isinstance(
                                    model_field.constraints, ConstraintsBase
                                )
                                else {}
                            )
                    else:
                        pass
                        # skip function type-hint kwargs overriding

                    data_type.parent.data_type = copied_data_type
                elif isinstance(data_type.parent, DataType):
                    # for data_type
                    data_type_id = id(data_type)
                    data_type.parent.data_types = [
                        d
                        for d in (*data_type.parent.data_types, copied_data_type)
                        if id(d) != data_type_id
                    ]
                else:  # pragma: no cover
                    continue

                data_type.remove_reference()

                root_type_model.reference.children = [
                    c for c in root_type_model.reference.children if c.parent
                ]

                if not root_type_model.reference.children:
                    unused_models.append(root_type_model)


def __delete_duplicate_models(self, models: List[DataModel]) -> None:
    model_class_names: Dict[str, DataModel] = {}
    model_to_duplicate_models: DefaultDict[
        DataModel, List[DataModel]
    ] = defaultdict(list)
    for model in models[:]:
        if isinstance(model, self.data_model_root_type):
            root_data_type = model.fields[0].data_type

            # backward compatible
            # Remove duplicated root model
            if (
                root_data_type.reference
                and not root_data_type.is_dict
                and not root_data_type.is_list
                and root_data_type.reference.source in models
                and root_data_type.reference.name
                == self.model_resolver.get_class_name(
                    model.reference.original_name, unique=False
                ).name
            ):
                # Replace referenced duplicate model to original model
                for child in model.reference.children[:]:
                    child.replace_reference(root_data_type.reference)
                models.remove(model)
                for data_type in model.all_data_types:
                    if data_type.reference:
                        data_type.remove_reference()
                continue

            #  Custom root model can't be inherited on restriction of Pydantic
            for child in model.reference.children:
                # inheritance model
                if isinstance(child, DataModel):
                    for base_class in child.base_classes[:]:
                        if base_class.reference == model.reference:
                            child.base_classes.remove(base_class)
                    if not child.base_classes:  # pragma: no cover
                        child.set_base_class()

        class_name = model.duplicate_class_name or model.class_name
        if class_name in model_class_names:
            model_key = tuple(
                to_hashable(v)
                for v in (
                    model.render(class_name=model.duplicate_class_name),
                    model.imports,
                )
            )
            original_model = model_class_names[class_name]
            original_model_key = tuple(
                to_hashable(v)
                for v in (
                    original_model.render(
                        class_name=original_model.duplicate_class_name
                    ),
                    original_model.imports,
                )
            )
            if model_key == original_model_key:
                model_to_duplicate_models[original_model].append(model)
                continue
        model_class_names[class_name] = model
    for model, duplicate_models in model_to_duplicate_models.items():
        for duplicate_model in duplicate_models:
            for child in duplicate_model.reference.children[:]:
                child.replace_reference(model.reference)
            models.remove(duplicate_model)

@classmethod
def __replace_duplicate_name_in_module(cls, models: List[DataModel]) -> None:
    scoped_model_resolver = ModelResolver(
        exclude_names={i.alias or i.import_ for m in models for i in m.imports},
        duplicate_name_suffix='Model',
    )

    model_names: Dict[str, DataModel] = {}
    for model in models:
        class_name: str = model.class_name
        generated_name: str = scoped_model_resolver.add(
            model.path, class_name, unique=True, class_name=True
        ).name
        if class_name != generated_name:
            model.class_name = generated_name
        model_names[model.class_name] = model

    for model in models:
        duplicate_name = model.duplicate_class_name
        # check only first desired name
        if duplicate_name and duplicate_name not in model_names:
            del model_names[model.class_name]
            model.class_name = duplicate_name
            model_names[duplicate_name] = model


def __override_required_field(
    self,
    models: List[DataModel],
) -> None:
    for model in models:
        if isinstance(model, (Enum, self.data_model_root_type)):
            continue
        for index, model_field in enumerate(model.fields[:]):
            data_type = model_field.data_type
            if (
                not model_field.original_name
                or data_type.data_types
                or data_type.reference
                or data_type.type
                or data_type.literals
                or data_type.dict_key
            ):
                continue

            original_field = _find_field(
                model_field.original_name, _find_base_classes(model)
            )
            if not original_field:  # pragma: no cover
                model.fields.remove(model_field)
                continue
            copied_original_field = original_field.copy()
            if original_field.data_type.reference:
                data_type = self.data_type_manager.data_type(
                    reference=original_field.data_type.reference,
                )
            elif original_field.data_type.data_types:
                data_type = original_field.data_type.copy()
                data_type.data_types = _copy_data_types(
                    original_field.data_type.data_types
                )
                for data_type_ in data_type.data_types:
                    data_type_.parent = data_type
            else:
                data_type = original_field.data_type.copy()
            data_type.parent = copied_original_field
            copied_original_field.data_type = data_type
            copied_original_field.parent = model
            copied_original_field.required = True
            model.fields.insert(index, copied_original_field)
            model.fields.remove(model_field)

def __sort_models(
    self,
    models: List[DataModel],
    imports: Imports,
) -> None:
    if not self.keep_model_order:
        return

    models.sort(key=lambda x: x.class_name)

    imported = {i for v in imports.values() for i in v}
    model_class_name_baseclasses: Dict[DataModel, Tuple[str, Set[str]]] = {}
    for model in models:
        class_name = model.class_name
        model_class_name_baseclasses[model] = class_name, {
            b.type_hint for b in model.base_classes if b.reference
        } - {class_name}

    changed: bool = True
    while changed:
        changed = False
        resolved = imported.copy()
        for i in range(len(models) - 1):
            model = models[i]
            class_name, baseclasses = model_class_name_baseclasses[model]
            if not baseclasses - resolved:
                resolved.add(class_name)
                continue
            models[i], models[i + 1] = models[i + 1], model
            changed = True

These private methods deal with a range of issues related to data models and how they are structured, transformed and deduplicated.

Here's a detailed explanation of each:

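  1. __reuse_model(): This runs only when the reuse_model option is set. It renders each model under the placeholder class name 'M' and combines that with the model's imports into a cache key, so two models get the same key when they are identical apart from their name. The first model with a given key is cached; each later match is treated as a duplicate. If the duplicate is an Enum, every reference to it held by a model in the current module is repointed at the cached enum (a TODO notes that other modules are not handled yet), and the duplicate is removed. Otherwise, the duplicate is replaced, at the same position in the list, by a new model of the same class that keeps its name and description, has no fields of its own, and has the cached model as its only base class (its reference path gets a '/reuse' suffix). If the cached model needs a resolve-reference action (the code emitted by dump_resolve_reference_action), the new model is added to require_update_action_models too. This is where inheritance is used most directly: the duplicate's fields are not lost but inherited from the first structurally identical model.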

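  2. __collapse_root_models(): This runs only when the collapse_root_models option is set. It walks every data type of every field, including types nested inside lists, dicts and unions. Whenever one of them references a root model (an instance of data_model_root_type, i.e. a model that just wraps a single type), that reference is replaced by a copy of the root model's single field type, so the field uses the underlying type directly. For a reference held directly by a field, the root field's extras are merged in (the referencing field's own extras win) and, when field_constraints is enabled, so are its constraints (the field's non-None constraints override the root's). For a reference nested inside another type, the copied type is swapped into the parent's data_types. The substitution is skipped when field_constraints is on, the root field has constraints, and the field's type involves a dict, list or union. Once a root model has no references left, it is added to unused_models, and parse() later removes it from its module's model list before rendering. Inheritance is not handled here: only field types are inspected, not base classes.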

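  3. __delete_duplicate_models(): This method does two things, and the first does involve inheritance. For root models, it first removes (for backward compatibility) a root model that merely wraps a reference, not a list or dict, to another model in the same module whose name equals the root model's own non-uniquified class name; references to the root model are redirected to the wrapped model. Every other root model is kept, but because Pydantic does not allow a custom root model to be inherited (as the code comment says), any model that lists it as a base class has that base class removed, and set_base_class() is called if no base classes are left. In other words, a subclass of a root model stops inheriting from it. Second, for all models, it looks for models that share a class name (duplicate_class_name if set, otherwise class_name) and render identically with the same imports; each such duplicate is removed, and all references to it are repointed at the first model with that name.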

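  4. __replace_duplicate_name_in_module(): This classmethod resolves class-name clashes within a module. It creates a scoped ModelResolver whose excluded names are all names imported by the module's models (the alias if there is one, otherwise the imported name), with 'Model' as the suffix for duplicates. Each model's class name is registered in turn; if it clashes with an import or with an earlier model, the model is renamed to the unique name the resolver generates. A second pass then checks each model's duplicate_class_name (its originally desired name): if no model in the module currently uses it, the model is renamed back to it. It doesn't involve inheritance directly.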

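  5. __override_required_field(): This method is directly about inheritance. It skips Enum and root models, then looks for placeholder fields: fields that have an original_name but an empty data type (no nested data_types, reference, type, literals or dict_key). Such a field names a property without describing it, which presumably happens when a subclass marks an inherited property as required without redefining it. For each placeholder, it searches the model's base classes (via _find_base_classes and _find_field) for the field with the same original name. If none is found, the placeholder is simply dropped. Otherwise it copies the inherited field together with its data type (creating a fresh reference through data_type_manager if the type is a reference, or copying and re-parenting the nested data_types), attaches the copy to the subclass, sets required=True, and puts it where the placeholder was. The subclass thus redeclares the inherited field as required instead of inheriting it as optional.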

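  6. __sort_models(): This runs only when the keep_model_order option is set. It first sorts the models alphabetically by class name, then repeatedly sweeps the list: a model whose generated base classes (ignoring itself) are not all defined earlier in the module or imported is swapped one position later, and sweeps continue until nothing moves. The only dependencies it considers are base classes, so this does involve inheritance: it ensures a base class is defined before its subclasses in the same module, which Python requires when the class statement runs.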
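The cache-key trick in item 1 can be shown with a toy model class. ToyModel and reuse_models are hypothetical names for illustration only; the real code also folds the model's imports into the key and has a separate Enum branch, both omitted here.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ToyModel:
    """Hypothetical stand-in for a generated DataModel."""
    name: str
    fields: Dict[str, str] = field(default_factory=dict)  # field name -> type
    base: Optional[str] = None  # name of a generated base class, if any

    def render(self, class_name: Optional[str] = None) -> str:
        lines = [f"    {k}: {v}" for k, v in self.fields.items()] or ["    pass"]
        header = f"class {class_name or self.name}({self.base or 'BaseModel'}):"
        return "\n".join([header, *lines])


def reuse_models(models: List[ToyModel]) -> List[ToyModel]:
    """Replace each structurally identical model by an empty subclass of the first."""
    first_seen: Dict[str, str] = {}  # rendering under placeholder name -> model name
    result: List[ToyModel] = []
    for model in models:
        # Rendering every model as "M" gives models that differ only in name
        # the same key.
        key = model.render(class_name="M")
        if key in first_seen:
            # Same name, no fields of its own: everything comes from the base.
            result.append(ToyModel(name=model.name, base=first_seen[key]))
        else:
            first_seen[key] = model.name
            result.append(model)
    return result


cat = ToyModel("Cat", {"name": "str", "age": "int"})
dog = ToyModel("Dog", {"name": "str", "age": "int"})
for m in reuse_models([cat, dog]):
    print(m.render(), end="\n\n")
```

Dog comes out as an empty subclass of Cat, which is the same shape as the inherited_model that __reuse_model() builds with fields=[] and base_classes=[cached_model_reference].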
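The metadata merge in item 2, for a reference held directly by a field, can be sketched with two small dataclasses. Constraints, FieldSpec and collapse_into are hypothetical; unlike the real code, this sketch always merges constraints (the real code only does so when field_constraints is set) and does not cover references nested inside lists or unions.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Optional


@dataclass
class Constraints:
    """Hypothetical subset of a field's constraints."""
    min_length: Optional[int] = None
    max_length: Optional[int] = None


@dataclass
class FieldSpec:
    type_: str
    extras: Dict[str, Any] = field(default_factory=dict)
    constraints: Constraints = field(default_factory=Constraints)


def collapse_into(model_field: FieldSpec, root_field: FieldSpec) -> FieldSpec:
    """Use the root model's type directly, merging the field-level metadata."""
    # Extras: the root field supplies defaults, the referencing field's keys win.
    extras = dict(root_field.extras, **model_field.extras)
    # Constraints: start from the root's, overwrite with the field's non-None values.
    overrides = {k: v for k, v in vars(model_field.constraints).items() if v is not None}
    constraints = replace(root_field.constraints, **overrides)
    return FieldSpec(type_=root_field.type_, extras=extras, constraints=constraints)


# A root model "PetId" wrapping str, and a field "id: PetId" that tightens max_length.
pet_id_root = FieldSpec("str", {"description": "Pet identifier"}, Constraints(1, 64))
id_field = FieldSpec("PetId", {"example": "abc"}, Constraints(max_length=32))
print(collapse_into(id_field, pet_id_root))
```

The collapsed field has type str, keeps both extras, and ends up with min_length=1 from the root and max_length=32 from the field itself.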
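The inheritance side of item 3 (dropping root models from base-class lists) reduces to a few lines over names. ToyModel, DEFAULT_BASE and strip_root_model_bases are hypothetical; in particular, DEFAULT_BASE only stands in for whatever set_base_class() assigns.

```python
from dataclasses import dataclass, field
from typing import List

DEFAULT_BASE = "BaseModel"  # hypothetical stand-in for what set_base_class() assigns


@dataclass
class ToyModel:
    name: str
    is_root: bool = False  # True for a custom root model (one wrapped type)
    bases: List[str] = field(default_factory=list)


def strip_root_model_bases(models: List[ToyModel]) -> None:
    """Remove root models from every base-class list."""
    root_names = {m.name for m in models if m.is_root}
    for model in models:
        kept = [b for b in model.bases if b not in root_names]
        if kept != model.bases:
            # Only models that actually lost a base are touched; if nothing
            # is left, they fall back to the default base class.
            model.bases = kept or [DEFAULT_BASE]


models = [
    ToyModel("PetId", is_root=True),
    ToyModel("Pet", bases=["PetId"]),
    ToyModel("Cat", bases=["Pet", "PetId"]),
]
strip_root_model_bases(models)
print([(m.name, m.bases) for m in models])
```

Pet loses its only base and falls back to the default, while Cat keeps Pet and only drops PetId.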
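The two naming passes in item 4 can be approximated without ModelResolver. Its exact uniquifying scheme is not shown in the code above, so this sketch assumes it appends the 'Model' suffix and then a counter; assign_names and its (class_name, duplicate_class_name) input format are hypothetical.

```python
from typing import List, Optional, Set, Tuple


def assign_names(
    models: List[Tuple[str, Optional[str]]],  # (class_name, duplicate_class_name)
    imported: Set[str],
    suffix: str = "Model",
) -> List[str]:
    """Two-pass renaming within one module (hypothetical sketch)."""
    taken: Set[str] = set(imported)  # imported names are excluded up front
    names: List[str] = []
    # Pass 1: give every model a unique name (assumed scheme: suffix, then counter).
    for class_name, _ in models:
        name = class_name
        if name in taken:
            name, counter = f"{class_name}{suffix}", 1
            while name in taken:
                name = f"{class_name}{suffix}{counter}"
                counter += 1
        taken.add(name)
        names.append(name)
    # Pass 2: restore a model's preferred name if no model in the module uses it.
    # Like the original, this checks model names only, not imported names.
    in_use = set(names)
    for i, (_, preferred) in enumerate(models):
        if preferred and preferred not in in_use:
            in_use.discard(names[i])
            names[i] = preferred
            in_use.add(preferred)
    return names


print(assign_names([("Decimal", None), ("Pet", None), ("Pet", None), ("Animal1", "Animal")], {"Decimal"}))
# ['DecimalModel', 'Pet', 'PetModel', 'Animal']
```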
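Item 5 can be illustrated with toy classes. ToyField, ToyModel, find_inherited_field and override_required_fields are hypothetical; a field whose type_ is None plays the role of the placeholder with an empty data type, and the depth-first base-class search is an assumption about what _find_base_classes and _find_field do.

```python
from copy import copy
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyField:
    name: str
    type_: Optional[str] = None  # None = placeholder with no type information
    required: bool = False


@dataclass
class ToyModel:
    name: str
    fields: List[ToyField] = field(default_factory=list)
    bases: List["ToyModel"] = field(default_factory=list)


def find_inherited_field(name: str, model: ToyModel) -> Optional[ToyField]:
    """Depth-first search through the base classes for a field called `name`."""
    for base in model.bases:
        for f in base.fields:
            if f.name == name:
                return f
        found = find_inherited_field(name, base)
        if found is not None:
            return found
    return None


def override_required_fields(model: ToyModel) -> None:
    for index, placeholder in enumerate(model.fields[:]):
        if placeholder.type_ is not None:
            continue  # a fully described field, leave it alone
        original = find_inherited_field(placeholder.name, model)
        if original is None:
            model.fields.remove(placeholder)  # nothing to inherit: drop it
            continue
        redeclared = copy(original)  # copy, so the base class is not modified
        redeclared.required = True
        model.fields.insert(index, redeclared)
        model.fields.remove(placeholder)


pet = ToyModel("Pet", [ToyField("name", "str"), ToyField("tag", "str")])
cat = ToyModel("Cat", [ToyField("name"), ToyField("indoor", "bool")], bases=[pet])
override_required_fields(cat)
print([(f.name, f.type_, f.required) for f in cat.fields])
# [('name', 'str', True), ('indoor', 'bool', False)]
```

Pet's own name field stays optional; only Cat gets a required copy, which is why the real method copies the field and its data type instead of mutating the base class.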
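The sweep in item 6 can be mirrored on plain class names. In this sketch, bases maps each class name to the names of its generated base classes (what the real code collects as b.type_hint for base classes with a reference), and imported stands for names the module already imports; sort_by_base_classes is a hypothetical name.

```python
from typing import Dict, List, Set


def sort_by_base_classes(bases: Dict[str, Set[str]], imported: Set[str]) -> List[str]:
    """Alphabetical order, then move subclasses after their base classes."""
    models = sorted(bases)
    changed = True
    while changed:
        changed = False
        resolved = set(imported)  # names usable so far in this sweep
        for i in range(len(models) - 1):
            name = models[i]
            if not (bases[name] - {name}) - resolved:
                resolved.add(name)  # all of its bases are defined above or imported
                continue
            # A base is not defined yet: push this model one slot later.
            models[i], models[i + 1] = models[i + 1], name
            changed = True
    return models


print(sort_by_base_classes({"Animal": set(), "Cat": {"Pet"}, "Pet": {"Animal"}}, set()))
# ['Animal', 'Pet', 'Cat']
```

Because a model only moves one slot per swap, a model with a late base class can take several sweeps to settle, which is why the outer loop repeats until a sweep makes no change.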

Please note that the actual behavior might depend on the specifics of how these methods are called and what the DataModel object's state is at the time of the call. This is a high-level overview based on the provided code.
