Last active
May 16, 2023 00:59
-
-
Save sborquez/ef35c229104ecb093f6ff07a324bc4d9 to your computer and use it in GitHub Desktop.
A factory register pattern
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
from typing import Any, Dict, Optional, Type
class Registry:
    """
    A factory register pattern.

    A registry maps class names to concrete subclasses of one base class and
    instantiates them by name.

    How to Use this Pattern
    =======================

    1. Define a base class to register the concrete implementations

        class Model(metaclass=ABCMeta):
            '''Image classification architecture.'''

            @abstractmethod
            def __call__(self, *args: Any, **kwds: Any) -> Any:
                '''A method'''

    2. Create the component registry.

        ModelRegistry = Registry(Model)

    3. Register concrete components

        @ModelRegistry.register
        class ResNet50(Model):
            '''A model with support for a single head'''

            def __init__(self, num_classes: int) -> None:
                super().__init__()
                self.num_classes = num_classes

            def __call__(self, *args: Any, **kwds: Any) -> Any:
                ...  # concrete implementation

    4. Build from registry

        resnet = ModelRegistry.build('ResNet50', build_arguments={'num_classes': 1})

    5. Inspect the registry

        print(ModelRegistry)  # base class help plus one line per factory
    """

    def __init__(self, class_type: Type) -> None:
        """Create an empty registry for subclasses of ``class_type``.

        Args:
            class_type: Base class every registered component must derive from.
                Its name and docstring are used by ``__str__``.
        """
        self.class_type = class_type
        self.name = class_type.__name__
        # May be None when the base class has no docstring.
        self._help = class_type.__doc__
        # Maps a class name to the registered class itself.
        self._factories: Dict[str, Type] = {}

    def register(self, class_type: Type) -> Type:
        """Add ``class_type`` to the registry; usable as a class decorator.

        Args:
            class_type: A subclass of the registry's base class.

        Returns:
            ``class_type`` unchanged, so the decorated name stays bound to it.

        Raises:
            ValueError: If ``class_type`` is not a class, is not a subclass of
                the registry's base class, or a class with the same name is
                already registered.
        """
        # Non-class objects (e.g. a decorated function) may lack __name__.
        type_name = getattr(class_type, '__name__', repr(class_type))
        # issubclass() raises TypeError on non-classes; check isinstance first
        # so every invalid input surfaces as the same ValueError.
        if not (isinstance(class_type, type)
                and issubclass(class_type, self.class_type)):
            logging.error('%s is not a subclass of %s.', type_name, self.class_type)
            raise ValueError(f'{type_name} is not a valid subclass.')
        if type_name in self._factories:
            logging.error('%s already on registry.', type_name)
            raise ValueError(f'{type_name} already registered.')
        logging.debug('Add %s to %s registry.', type_name, self.name)
        self._factories[type_name] = class_type
        return class_type

    def build(self, name: str,
              build_arguments: Optional[Dict[str, Any]] = None) -> Any:
        """Instantiate the registered class called ``name``.

        Args:
            name: Name of a previously registered class.
            build_arguments: Keyword arguments forwarded to the class
                constructor. ``None`` (the default) means no arguments.

        Returns:
            A new instance of the registered class.

        Raises:
            ValueError: If ``name`` has not been registered.
        """
        try:
            factory = self._factories[name]
        except KeyError:
            available = ', '.join(self._factories) or '<none>'
            raise ValueError(
                f'`{name}` not in {self.name} registry. Available: {available}'
            ) from None
        logging.debug('Building %s. Build Arguments: %s', name, build_arguments)
        return factory(**(build_arguments or {}))

    @staticmethod
    def _summary(doc: Optional[str]) -> str:
        """Return the first non-empty line of ``doc``, or '' if there is none."""
        for line in (doc or '').splitlines():
            stripped = line.strip()
            if stripped:
                return stripped
        return ''

    def __str__(self) -> str:
        """Build a string representation with the available factories."""
        lines = [f'{self.name} registry']
        help_text = (self._help or '').strip()
        if help_text:
            lines.append(help_text)
        if self._factories:
            lines.append('Available factories:')
            for factory_name, factory in self._factories.items():
                summary = self._summary(factory.__doc__)
                entry = f'  - {factory_name}'
                lines.append(f'{entry}: {summary}' if summary else entry)
        else:
            lines.append('No factories registered.')
        return '\n'.join(lines)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment