Skip to content

Instantly share code, notes, and snippets.

@ahancock1
Created November 8, 2022 22:10
Show Gist options
  • Save ahancock1/3bd423248cd5bba2e25bd417ddac0f43 to your computer and use it in GitHub Desktop.
Save ahancock1/3bd423248cd5bba2e25bd417ddac0f43 to your computer and use it in GitHub Desktop.
Inversion-of-control (IoC) dependency-injection service container
from __future__ import annotations
from typing import Callable, Generic, Protocol, TypeVar
from typing import overload
from typing import get_type_hints, get_args, get_origin
# Generic type variable shared by the generic signatures, protocols and
# lifetime classes (Transient, Singleton) in this module.
T = TypeVar("T")
class IServiceProvider(Protocol):
    """Structural interface for objects that resolve registered services."""

    def get_service(self, resolve_type: type[T], service_name: str | None = None) -> T | None:
        """Return a service registered for ``resolve_type``.

        Args:
            resolve_type: The type the service was registered under.
            service_name: Optional registration name; ``None`` selects the
                unnamed registration.

        Returns:
            The resolved service, or ``None`` if it cannot be resolved.
        """
        ...
class IServiceFactory(Protocol[T]):
    """Structural interface for anything that can produce a service instance.

    The lifetime wrappers in this module (transient and singleton) satisfy it
    by exposing a ``get_instance`` method.
    """

    def get_instance(self, provider: IServiceProvider) -> T:
        """Produce an instance, using ``provider`` to resolve its dependencies."""
class ServiceProvider:
    """Resolve services from a snapshot of registered factories.

    Registrations are keyed by ``(resolve_type, service_name)``; each key maps
    to a list of factories. The provider also registers itself under
    ``(IServiceProvider, None)`` so services can depend on the provider.
    """

    _services: dict[tuple[type, str | None], list[IServiceFactory]]

    class _SelfFactory:
        """Factory that returns the provider performing the resolution."""

        @staticmethod
        def get_instance(provider: IServiceProvider) -> IServiceProvider:
            return provider

    def __init__(
        self, services: dict[tuple[type, str | None], list[IServiceFactory]]
    ) -> None:
        """Create a provider over a copy of ``services``.

        Args:
            services: Mapping of ``(resolve_type, service_name)`` to the
                factories registered for that key. It is copied, not mutated.
        """
        # Copy so the caller's registrations are never mutated and later
        # changes to them do not leak into this provider.
        self._services = {
            key: list(factories) for key, factories in services.items()
        }
        # Stored as a factory list like every other entry; storing ``self``
        # directly made _resolve_service index into the provider (TypeError).
        self._services[(IServiceProvider, None)] = [self._SelfFactory()]

    def _resolve_services(
        self, resolve_type: type[T], type_origin: type[list | tuple | set]
    ) -> T:
        """Resolve every unnamed registration for the type args of a collection.

        Nested collection arguments (e.g. ``list[list[Foo]]``) are resolved
        recursively and flattened into the result. Factories returning
        ``None`` are skipped.

        Args:
            resolve_type: A parameterised collection type such as ``list[Foo]``.
            type_origin: Its origin: ``list``, ``tuple`` or ``set``.

        Returns:
            A ``type_origin`` instance holding the resolved services.
        """
        resolved = []
        for type_arg in get_args(resolve_type):
            if (get_origin(type_arg) or type_arg) in (list, tuple, set):
                nested = self.get_service(type_arg)
                if nested:
                    resolved.extend(nested)
                continue
            for factory in self._services.get((type_arg, None), []):
                service = factory.get_instance(self)
                if service is not None:
                    resolved.append(service)
        # Collect into a list first: tuple and set have no append/extend, so
        # building ``type_origin()`` directly failed for tuple[...] / set[...].
        return type_origin(resolved)

    def _resolve_service(
        self, resolve_type: type[T], service_name: str | None = None
    ) -> T | None:
        """Return the first registered instance for the key, or ``None``."""
        factories = self._services.get((resolve_type, service_name))
        if not factories:
            return None
        return factories[0].get_instance(self)

    def get_service(
        self, resolve_type: type[T], service_name: str | None = None
    ) -> T | None:
        """Resolve ``resolve_type``.

        ``list[X]``, ``tuple[X]`` and ``set[X]`` requests return every unnamed
        registration for ``X`` in that collection type (``service_name`` is
        ignored for them). Any other type returns the first registration for
        ``(resolve_type, service_name)``, or ``None`` if there is none.
        """
        type_origin = get_origin(resolve_type) or resolve_type
        if type_origin in (list, tuple, set):
            return self._resolve_services(resolve_type, type_origin)
        return self._resolve_service(resolve_type, service_name)
def default_factory(service_type: type[T]) -> Callable[[IServiceProvider], T]:
    """Build a factory that constructs ``service_type`` by autowiring.

    Each annotated ``__init__`` parameter is resolved from the provider by its
    type hint and passed as a keyword argument. When a parameter cannot be
    resolved (the provider returns ``None``) but declares a default value, it
    is omitted so the constructor's own default applies instead of ``None``.
    Annotated ``*args`` / ``**kwargs`` parameters are never filled.

    Args:
        service_type: The concrete class to instantiate.

    Returns:
        A callable taking an ``IServiceProvider`` and returning a new instance.
    """
    import inspect

    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    def _(provider: IServiceProvider) -> T:
        initializer = service_type.__init__
        # Resolved per call (not at registration) so forward references that
        # are defined later in the user's module still work.
        type_hints = get_type_hints(initializer)
        try:
            parameters = inspect.signature(initializer).parameters
        except (TypeError, ValueError):
            # Some builtins expose no signature; fall back to hints alone.
            parameters = {}
        kwargs = {}
        for hint_name, hint_type in type_hints.items():
            if hint_name == "return":
                continue
            parameter = parameters.get(hint_name)
            if parameter is not None and parameter.kind in variadic:
                # Cannot be passed as ``name=value``; doing so raised TypeError.
                continue
            service = provider.get_service(hint_type)
            if (
                service is None
                and parameter is not None
                and parameter.default is not inspect.Parameter.empty
            ):
                # Keep the constructor's default rather than forcing None.
                continue
            kwargs[hint_name] = service
        return service_type(**kwargs)

    return _
class Transient(Generic[T]):
    """Lifetime wrapper that invokes its factory on every resolution.

    Each call to ``get_instance`` returns whatever the factory produces for
    that call; nothing is cached.
    """

    _factory: Callable[[IServiceProvider], T] | None

    def __init__(self, factory: Callable[[IServiceProvider], T] | None = None) -> None:
        """Wrap ``factory``, a callable taking the provider and returning a service."""
        self._factory = factory

    def get_instance(self, services: IServiceProvider) -> T:
        """Return a new instance built by the factory.

        Args:
            services: Provider passed to the factory for dependency resolution.

        Raises:
            TypeError: If no factory was supplied.
        """
        if self._factory is None:
            # Same exception type as calling None, but with an actionable message.
            raise TypeError(
                f"{type(self).__name__} has no factory to create an instance"
            )
        return self._factory(services)
class Singleton(Generic[T]):
    """Lifetime wrapper that invokes its factory once and caches the result.

    The cached value is reused for every later ``get_instance`` call, even
    when the factory returned ``None``.
    """

    _instance: T | None
    _factory: Callable[[IServiceProvider], T] | None
    # Separate flag because None is a legitimate cached result; using
    # ``_instance is None`` as the marker re-ran such factories every time.
    _created: bool

    def __init__(self, factory: Callable[[IServiceProvider], T] | None = None) -> None:
        """Wrap ``factory``, a callable taking the provider and returning a service."""
        self._instance = None
        self._created = False
        self._factory = factory

    def get_instance(self, services: IServiceProvider) -> T:
        """Return the cached instance, creating it on first use.

        Args:
            services: Provider passed to the factory on first creation.

        Raises:
            TypeError: If no factory was supplied.
        """
        if not self._created:
            if self._factory is None:
                raise TypeError(
                    f"{type(self).__name__} has no factory to create an instance"
                )
            self._instance = self._factory(services)
            # Only mark as created after success so a failing factory is retried.
            self._created = True
        return self._instance
class IServiceContainer(Protocol):
    """Structural interface for registering services and building a provider.

    A registration binds ``resolve_type`` (the type callers ask for) either to
    a concrete ``service_type`` or to an explicit ``service_factory``.
    Factory overloads are keyword-only: passed positionally,
    ``add_x(Foo, factory)`` would bind the factory to ``service_type`` and
    ``add_x(Foo, "name", factory)`` would bind the name to ``service_type``.
    """

    @overload
    def add_singleton(self, resolve_type: type[T]) -> None: ...

    @overload
    def add_singleton(self, resolve_type: type[T], service_type: type[T]) -> None: ...

    @overload
    def add_singleton(
        self, resolve_type: type[T], service_type: type[T], service_name: str
    ) -> None: ...

    @overload
    def add_singleton(
        self,
        resolve_type: type[T],
        *,
        service_factory: Callable[[IServiceProvider], T],
    ) -> None: ...

    @overload
    def add_singleton(
        self,
        resolve_type: type[T],
        *,
        service_name: str,
        service_factory: Callable[[IServiceProvider], T],
    ) -> None: ...

    def add_singleton(
        self,
        resolve_type: type[T],
        service_type: type[T] | None = None,
        service_name: str | None = None,
        service_factory: Callable[[IServiceProvider], T] | None = None,
    ) -> None:
        """Register ``resolve_type`` with singleton lifetime."""
        ...

    @overload
    def add_transient(self, resolve_type: type[T]) -> None: ...

    @overload
    def add_transient(self, resolve_type: type[T], service_type: type[T]) -> None: ...

    @overload
    def add_transient(
        self, resolve_type: type[T], service_type: type[T], service_name: str
    ) -> None: ...

    @overload
    def add_transient(
        self,
        resolve_type: type[T],
        *,
        service_factory: Callable[[IServiceProvider], T],
    ) -> None: ...

    @overload
    def add_transient(
        self,
        resolve_type: type[T],
        *,
        service_name: str,
        service_factory: Callable[[IServiceProvider], T],
    ) -> None: ...

    def add_transient(
        self,
        resolve_type: type[T],
        service_type: type[T] | None = None,
        service_name: str | None = None,
        service_factory: Callable[[IServiceProvider], T] | None = None,
    ) -> None:
        """Register ``resolve_type`` with transient lifetime."""
        ...

    def build(self) -> IServiceProvider:
        """Return a provider that resolves the registered services."""
        ...
class ServiceContainer:
    """Collect service registrations and build a :class:`ServiceProvider`.

    Registrations are keyed by ``(resolve_type, service_name)``. Unnamed
    registrations for the same type accumulate, so a collection request such
    as ``list[Foo]`` can resolve all of them; a named registration replaces
    any earlier one with the same type and name.
    """

    _services: dict[tuple[type, str | None], list[IServiceFactory]]

    def __init__(self) -> None:
        self._services = {}

    @overload
    def add_singleton(self, resolve_type: type[T]) -> None: ...

    @overload
    def add_singleton(self, resolve_type: type[T], service_type: type[T]) -> None: ...

    @overload
    def add_singleton(
        self, resolve_type: type[T], service_type: type[T], service_name: str
    ) -> None: ...

    # Factory overloads are keyword-only: passed positionally, the factory
    # (or the name) would bind to ``service_type`` in the implementation.
    @overload
    def add_singleton(
        self,
        resolve_type: type[T],
        *,
        service_factory: Callable[[IServiceProvider], T],
    ) -> None: ...

    @overload
    def add_singleton(
        self,
        resolve_type: type[T],
        *,
        service_name: str,
        service_factory: Callable[[IServiceProvider], T],
    ) -> None: ...

    def add_singleton(
        self,
        resolve_type: type[T],
        service_type: type[T] | None = None,
        service_name: str | None = None,
        service_factory: Callable[[IServiceProvider], T] | None = None,
    ) -> None:
        """Register ``resolve_type`` with a cached (``Singleton``) lifetime.

        Args:
            resolve_type: Type the service is requested by.
            service_type: Concrete class to autowire; defaults to ``resolve_type``.
            service_name: Optional registration name.
            service_factory: Explicit factory; takes precedence over
                ``service_type`` when given.
        """
        self._register(
            resolve_type,
            service_name,
            Singleton(self._factory_for(resolve_type, service_type, service_factory)),
        )

    @overload
    def add_transient(self, resolve_type: type[T]) -> None: ...

    @overload
    def add_transient(self, resolve_type: type[T], service_type: type[T]) -> None: ...

    @overload
    def add_transient(
        self, resolve_type: type[T], service_type: type[T], service_name: str
    ) -> None: ...

    @overload
    def add_transient(
        self,
        resolve_type: type[T],
        *,
        service_factory: Callable[[IServiceProvider], T],
    ) -> None: ...

    @overload
    def add_transient(
        self,
        resolve_type: type[T],
        *,
        service_name: str,
        service_factory: Callable[[IServiceProvider], T],
    ) -> None: ...

    def add_transient(
        self,
        resolve_type: type[T],
        service_type: type[T] | None = None,
        service_name: str | None = None,
        service_factory: Callable[[IServiceProvider], T] | None = None,
    ) -> None:
        """Register ``resolve_type`` with a per-resolution (``Transient``) lifetime.

        Args:
            resolve_type: Type the service is requested by.
            service_type: Concrete class to autowire; defaults to ``resolve_type``.
            service_name: Optional registration name.
            service_factory: Explicit factory; takes precedence over
                ``service_type`` when given.
        """
        self._register(
            resolve_type,
            service_name,
            Transient(self._factory_for(resolve_type, service_type, service_factory)),
        )

    @staticmethod
    def _factory_for(
        resolve_type: type[T],
        service_type: type[T] | None,
        service_factory: Callable[[IServiceProvider], T] | None,
    ) -> Callable[[IServiceProvider], T]:
        """Pick the explicit factory, else autowire ``service_type`` or ``resolve_type``."""
        if service_factory is not None:
            return service_factory
        return default_factory(service_type if service_type is not None else resolve_type)

    def _register(
        self,
        resolve_type: type[T],
        service_name: str | None,
        factory: IServiceFactory,
    ) -> None:
        """Store ``factory`` under ``(resolve_type, service_name)``."""
        key = (resolve_type, service_name)
        if service_name is not None:
            # Named registrations are unique: the latest one wins.
            self._services[key] = [factory]
        else:
            # Unnamed registrations accumulate for collection resolution.
            self._services.setdefault(key, []).append(factory)

    def build(self) -> ServiceProvider:
        """Build a provider over a snapshot of the current registrations.

        The factory lists are copied, so registrations added afterwards do not
        affect the provider, and the provider cannot alter this container.
        Factory objects themselves are shared, so a singleton's cached
        instance is shared by every provider built from this container.
        """
        return ServiceProvider(
            {key: list(factories) for key, factories in self._services.items()}
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment