Skip to content

Instantly share code, notes, and snippets.

@sborquez
Last active May 16, 2023 00:59
Show Gist options
  • Save sborquez/ef35c229104ecb093f6ff07a324bc4d9 to your computer and use it in GitHub Desktop.
Save sborquez/ef35c229104ecb093f6ff07a324bc4d9 to your computer and use it in GitHub Desktop.
A factory register pattern
import logging
from typing import Any, Dict, Optional, Type
class Registry:
    """A factory registry mapping class names to concrete implementations.

    How to use this pattern
    =======================
    1. Define a base class for the concrete implementations. Any class works;
       an abstract class is recommended. Its docstring, if present, is shown
       by ``str(registry)``::

        from abc import ABCMeta, abstractmethod

        class Model(metaclass=ABCMeta):
            '''Image classification architecture.'''

            @abstractmethod
            def __call__(self, *args: Any, **kwds: Any) -> Any:
                '''A method'''

    2. Create the component registry::

        ModelRegistry = Registry(Model)

    3. Register concrete components (``register`` works as a class decorator)::

        @ModelRegistry.register
        class ResNet50(Model):
            '''A model with support for a single head'''

            def __init__(self, num_classes: int) -> None:
                super().__init__()
                self.num_classes = num_classes

            def __call__(self, *args: Any, **kwds: Any) -> Any:
                ...  # concrete implementation

    4. Build from the registry::

        resnet = ModelRegistry.build('ResNet50', build_arguments={'num_classes': 1})

    5. List what is available::

        print(ModelRegistry)
    """

    def __init__(self, class_type: Type) -> None:
        """Create an empty registry for subclasses of *class_type*.

        Args:
            class_type: Base class that every registered factory must subclass.
        """
        self.class_type = class_type
        # Registry name, taken from the base class name (e.g. 'Model').
        self.name = class_type.__name__
        # Base-class docstring (None when the class has none); used by __str__.
        self._help = class_type.__doc__
        # Registered classes keyed by their __name__, in registration order.
        self._factories: Dict[str, Type] = {}

    def register(self, class_type: Type) -> Type:
        """Add *class_type* to the registry; usable as a class decorator.

        Args:
            class_type: Concrete subclass of ``self.class_type``. It is stored
                under its ``__name__``.

        Returns:
            *class_type* unchanged, so a decorated class remains usable.

        Raises:
            ValueError: If *class_type* is not a class deriving from
                ``self.class_type``, or if a class with the same name is
                already registered.
        """
        type_name = getattr(class_type, '__name__', repr(class_type))
        # Check isinstance first: issubclass() raises TypeError for non-class
        # arguments (e.g. a function), which would escape as an unrelated
        # error instead of the ValueError raised for invalid registrations.
        if not isinstance(class_type, type) or not issubclass(class_type, self.class_type):
            logging.error('%s is not a subclass of %s.', type_name, self.class_type)
            raise ValueError(f'{type_name} is not a valid subclass.')
        if type_name in self._factories:
            logging.error('%s already on registry.', type_name)
            raise ValueError(f'{type_name} already registered.')
        logging.debug('Add %s to %s registry.', type_name, self.name)
        self._factories[type_name] = class_type
        return class_type

    def build(
        self, name: str, build_arguments: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Instantiate the class registered under *name*.

        Args:
            name: Class name used at registration time.
            build_arguments: Keyword arguments forwarded to the constructor.
                ``None`` (the default) means the constructor gets no arguments.

        Returns:
            A new instance of the registered class.

        Raises:
            ValueError: If *name* is not registered.
        """
        try:
            factory = self._factories[name]
        except KeyError:
            available = ', '.join(sorted(self._factories)) or 'none'
            # `from None` hides the internal KeyError from the traceback.
            raise ValueError(
                f'{name!r} not in {self.name} registry. Available: {available}.'
            ) from None
        kwargs = {} if build_arguments is None else build_arguments
        logging.debug('Building %s. Build Arguments: %s', name, kwargs)
        return factory(**kwargs)

    @staticmethod
    def _first_line(doc: Optional[str]) -> str:
        """Return the first line of a docstring, or '' when it is None/blank."""
        stripped = (doc or '').strip()
        return stripped.splitlines()[0].strip() if stripped else ''

    def __str__(self) -> str:
        """Build a string representation listing the available factories.

        The first line names the registry. It is followed by the first line of
        the base-class docstring, when there is one. Each registered factory
        then appears on its own line, in registration order, with the first
        line of its docstring when there is one.
        """
        header = f'{self.name} registry'
        summary = self._first_line(self._help)
        if summary:
            header = f'{header}: {summary}'
        lines = [header]
        if not self._factories:
            lines.append('  (no factories registered)')
        for factory_name, factory in self._factories.items():
            doc = self._first_line(factory.__doc__)
            lines.append(f'  - {factory_name}: {doc}' if doc else f'  - {factory_name}')
        return '\n'.join(lines)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment