Skip to content

Instantly share code, notes, and snippets.

@Garciat
Last active April 27, 2021 19:33
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Garciat/ad8a3afbb3cef141fcc500ae6ba96bf4 to your computer and use it in GitHub Desktop.
Save Garciat/ad8a3afbb3cef141fcc500ae6ba96bf4 to your computer and use it in GitHub Desktop.
import inspect
from typing import Any, Dict, Iterable, List, get_type_hints
def linearize_type_hierarchy(root: type) -> Iterable[type]:
    """Yield ``root`` followed by its ancestor classes, excluding ``object``.

    The hierarchy is walked depth-first: a class is yielded, then each of its
    direct parents (in the order ``inspect.getclasstree`` reports them, which
    is not necessarily base-declaration order) is walked in turn.

    Every class is yielded at most once. An ancestor that is reachable through
    several bases (diamond inheritance) appears only at the position where it
    is first encountered.

    Args:
        root: The class whose hierarchy should be walked. Passing ``object``
            yields nothing.

    Yields:
        ``root`` and each distinct ancestor other than ``object``.
    """
    visited = set()
    # Explicit stack instead of recursion; parents are pushed reversed so that
    # they are popped (and therefore yielded) in their reported order.
    pending = [root]
    while pending:
        current = pending.pop()
        if current is object or current in visited:
            continue
        visited.add(current)
        yield current
        # For a single class, the top-level tuple entries of getclasstree()
        # are its direct parents; nested lists hold the class itself.
        parents = [
            node[0]
            for node in inspect.getclasstree([current])
            if isinstance(node, tuple)
        ]
        pending.extend(reversed(parents))
class Resolver:
    """Interface for objects that can look up a dependency by interface."""

    def resolve(self, interface, *, name=None):
        """Abstract lookup hook; this base version always raises.

        Args:
            interface: The type being requested.
            name: Optional keyword-only qualifier for the request.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
class Provider:
    """Interface for objects that know how to produce one dependency."""

    @property
    def typeof(self):
        """Type of the value this provider produces.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def provide(self, resolver):
        """Produce the value, using ``resolver`` for any further lookups.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
class InstanceProvider(Provider):
    """Provider that hands out one pre-built object."""

    def __init__(self, instance):
        # The object is stored as-is and returned unchanged by provide().
        self._instance = instance

    @property
    def typeof(self):
        """Runtime type of the wrapped object."""
        wrapped = self._instance
        return type(wrapped)

    def provide(self, resolver):
        """Return the wrapped object; ``resolver`` is not consulted."""
        return self._instance
def kwargs_for_annotations(resolver: Resolver, annotations: Dict[str, type]):
    """Build keyword arguments by resolving every annotated parameter.

    Args:
        resolver: Object whose ``resolve(type, name=...)`` supplies each value.
        annotations: Mapping of parameter name to annotated type. The
            ``'return'`` entry, if present, is skipped.

    Returns:
        A dict mapping each parameter name to the value the resolver returned
        for that parameter's type and name.
    """
    return {
        arg_name: resolver.resolve(arg_type, name=arg_name)
        for arg_name, arg_type in annotations.items()
        if arg_name != 'return'
    }
class CallableProvider(Provider):
    """Provider that calls a type-annotated callable to build its value."""

    def __init__(self, kallable):
        self._callable = kallable
        # Hints are read once, at construction time, and reused afterwards.
        self._annotations = get_type_hints(kallable)

    @property
    def typeof(self):
        """The callable's return annotation (KeyError if it has none)."""
        hints = self._annotations
        return hints['return']

    def provide(self, resolver):
        """Resolve each annotated parameter and call the callable with them."""
        arguments = kwargs_for_annotations(resolver, self._annotations)
        return self._callable(**arguments)
class ClassProvider(Provider):
    """Provider that instantiates a class and injects its dependencies.

    A class that does not define its own ``__init__`` is created without
    arguments and has its annotated attributes assigned afterwards. Any other
    class is constructed by passing resolved values for the annotated
    ``__init__`` parameters.
    """

    def __init__(self, cls):
        self._cls = cls

    @property
    def typeof(self):
        """The class this provider instantiates."""
        return self._cls

    def provide(self, resolver):
        """Create an instance, choosing the injection style from ``__init__``."""
        has_custom_init = self._cls.__init__ is not object.__init__
        if has_custom_init:
            return self._provide_as_callable(resolver)
        return self._provide_as_autowired(resolver)

    def _provide_as_autowired(self, resolver):
        """Instantiate with no arguments, then set each annotated attribute."""
        hints = get_type_hints(self._cls)
        obj = self._cls()
        for attr_name, attr_type in hints.items():
            value = resolver.resolve(attr_type, name=attr_name)
            setattr(obj, attr_name, value)
        return obj

    def _provide_as_callable(self, resolver):
        """Instantiate by passing resolved values for ``__init__`` parameters."""
        init_hints = get_type_hints(self._cls.__init__)
        arguments = kwargs_for_annotations(resolver, init_hints)
        return self._cls(**arguments)
class ModuleMethodProvider(Provider):
    """Provider backed by a method of a module class.

    The module class is itself resolved through the resolver, so the method
    is called on a resolved instance of that class.
    """

    def __init__(self, module_cls, unbound_method):
        self._module_cls = module_cls
        self._unbound_method = unbound_method

    @property
    def typeof(self):
        """The method's return annotation (KeyError if it has none)."""
        hints = get_type_hints(self._unbound_method)
        return hints['return']

    def provide(self, resolver):
        """Resolve the module, then call the method with resolved arguments."""
        owner = resolver.resolve(self._module_cls)
        # Look the method up by name on the instance to get the bound version.
        bound = getattr(owner, self._unbound_method.__name__)
        arguments = kwargs_for_annotations(resolver, get_type_hints(bound))
        return bound(**arguments)
class CachedProvider(Provider):
    """Provider wrapper that asks its subject only once and reuses the result."""

    # Sentinel meaning "nothing cached yet"; a plain None could be a real value.
    _UNSET = object()

    def __init__(self, subject: Provider):
        self._subject = subject
        self._instance = self._UNSET

    @property
    def typeof(self):
        """Same type as the wrapped provider reports."""
        return self._subject.typeof

    def provide(self, resolver):
        """Return the cached value, producing it on the first call."""
        if self._instance is not self._UNSET:
            return self._instance
        self._instance = self._subject.provide(resolver)
        return self._instance
class DependencyModuleSupport:
    """Helpers for recognising module classes and listing their methods."""

    @staticmethod
    def is_module(cls):
        """Return True when ``cls`` is a class whose name ends in ``Module``."""
        if not inspect.isclass(cls):
            return False
        return cls.__name__.endswith('Module')

    @staticmethod
    def get_methods(cls):
        """Yield ``(name, function)`` for each public function on ``cls``.

        Names come from ``dir(cls)``; those starting with an underscore are
        skipped, as are attributes that ``inspect.isfunction`` rejects.
        """
        public_names = (n for n in dir(cls) if not n.startswith('_'))
        for attr_name in public_names:
            candidate = getattr(cls, attr_name)
            if inspect.isfunction(candidate):
                yield attr_name, candidate
class includes:
    """Class decorator that records which items a module class includes.

    The items are stored on the decorated class under ``_PROP_NAME`` and can
    be fetched again with ``includes.read``.
    """

    _PROP_NAME = '_includes'

    def __init__(self, *items):
        # Kept as the tuple of positional arguments passed to the decorator.
        self._items = items

    def __call__(self, cls):
        """Attach the recorded items to ``cls`` and return it.

        Raises:
            TypeError: If ``cls`` is not recognised as a module class.
        """
        if DependencyModuleSupport.is_module(cls):
            self.write(cls, self._items)
            return cls
        raise TypeError('not a module')

    @classmethod
    def write(cls, target, items):
        """Store ``items`` on ``target``."""
        attribute = cls._PROP_NAME
        setattr(target, attribute, items)

    @classmethod
    def read(cls, target):
        """Return the items stored on ``target``, or an empty list if none."""
        attribute = cls._PROP_NAME
        return getattr(target, attribute, [])
class ResolutionError(Exception):
    """Raised when a requested dependency is missing or ambiguous."""
class RegistrationError(Exception):
    """Raised when a provider slot is already taken by another provider."""
class Container:
    """Registry of providers that resolves dependencies by type and name.

    Each provider is registered under its own type and under every ancestor
    type reported by ``linearize_type_hierarchy``, so a dependency can be
    requested by a base class. Providers are wrapped in ``CachedProvider``
    on registration, so each one produces its value at most once.
    """

    # interface -> {registration name -> provider}; the name may be None.
    _providers: Dict[type, Dict[str, Provider]]

    def __init__(self):
        self._providers = {}

    def register_instance(self, instance, *, name=None):
        """Register an already-built object under its runtime type."""
        return self.register(InstanceProvider(instance), name=name)

    def register_class(self, cls, *, name=None):
        """Register a class to be instantiated on demand."""
        return self.register(ClassProvider(cls), name=name)

    def register_callable(self, kallable, *, name=None):
        """Register a callable; its return annotation is the provided type."""
        return self.register(CallableProvider(kallable), name=name)

    def register_module(self, cls):
        """Register a module class, its includes, and its public methods.

        The module class itself is registered without a name. Every item
        returned by ``includes.read(cls)`` is then registered as a nested
        module or as a plain class. Finally each public function of the
        module becomes a provider named after that function.

        Raises:
            RegistrationError: If the module class or one of its methods
                collides with an existing registration.
            TypeError: If an included item is neither a module nor a class.
        """
        self.register(ClassProvider(cls), name=None)
        for item in includes.read(cls):
            try:
                if DependencyModuleSupport.is_module(item):
                    self.register_module(item)
                elif inspect.isclass(item):
                    self.register_class(item)
                else:
                    raise TypeError('included item cannot be handled')
            except RegistrationError:
                # includes are declarative; dupes are fine for only modules & classes
                # NOTE: this also swallows a RegistrationError raised part-way
                # through registering a nested module.
                pass
        for name, method in DependencyModuleSupport.get_methods(cls):
            self.register(ModuleMethodProvider(cls, method), name=name)

    def register(self, provider: Provider, *, name=None):
        """Register ``provider`` under its type and all of its ancestor types.

        Raises:
            RegistrationError: If another provider already holds ``name`` for
                one of those types.
        """
        cached_provider = CachedProvider(provider)
        for interface in linearize_type_hierarchy(provider.typeof):
            self._add(interface, name, cached_provider)

    def _add(self, interface, name, provider):
        """Bind ``provider`` to ``(interface, name)``; re-adding it is a no-op."""
        named_providers = self._providers.setdefault(interface, {})
        if named_providers.setdefault(name, provider) is not provider:
            raise RegistrationError('provider already registered')

    def resolve(self, interface, *, name=None):
        """Return the dependency registered for ``interface``.

        When exactly one provider exists for ``interface`` it is used
        regardless of ``name``. When several exist, ``name`` selects one.

        Args:
            interface: The type being requested.
            name: Registration name used to disambiguate multiple providers.

        Raises:
            ResolutionError: If no provider exists for ``interface``, if
                several exist and ``name`` is None, or if several exist and
                none is registered under ``name``.
        """
        # Read-only lookup: a failed resolve must not add an empty entry to
        # the registry (the previous setdefault() call did).
        named_providers = self._providers.get(interface)
        if not named_providers:
            raise ResolutionError(
                f'missing dependency name={name!r} interface={interface!r}'
            )
        if len(named_providers) == 1:
            # name doesn't matter for unique providers
            (provider,) = named_providers.values()
            return provider.provide(self)
        if name is None:
            raise ResolutionError(
                f'ambiguous dependency interface={interface!r}'
            )
        if name not in named_providers:
            raise ResolutionError(
                f'missing dependency name={name!r} interface={interface!r}'
            )
        return named_providers[name].provide(self)
if __name__ == "__main__":
    # Demo: wire a small application out of two modules and print the
    # values that end up injected into it.

    class Michael:
        def speak(self):
            return 'what'

    # Module whose two methods both provide an int, told apart by name.
    @includes(Michael)
    class WhateverModule:
        def whatever(self) -> int:
            return 42

        def poop(self) -> int:
            return 100

    # No __init__ here, so the annotated attributes are assigned after
    # construction; each attribute name doubles as the lookup name.
    class Application:
        my_dir: str
        whatever: 'int'
        poop: int
        m: Michael

        def start(self):
            print(self.my_dir)
            print(self.whatever)
            print(self.poop)
            print(self.m.speak())

    # Michael is listed here as well as in WhateverModule; the duplicate
    # include is tolerated by register_module.
    @includes(WhateverModule, Application, Michael)
    class ApplicationModule:
        config_dir: str

        def another_dir(self) -> str:
            return self.config_dir + '/another'

        def my_dir(self, another_dir: 'str') -> str:
            return another_dir + '/ñe'

    container = Container()
    # Supplies the str that ApplicationModule.config_dir asks for by name.
    container.register_instance('/etc/hello/world', name='config_dir')
    container.register_module(ApplicationModule)
    container.resolve(Application).start()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment