Last active
March 13, 2020 23:10
-
-
Save michaelhelvey/659d10ef3ccdf5ab8d1620efb5b0b1e3 to your computer and use it in GitHub Desktop.
Single-file dependency injection for Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Simple dependency injection package that builds up global tree of singleton | |
objects based on __init__ params. | |
Assumes that __init__ params will be named as the snake case equivalents of | |
the required class name. So if you need an instance of UserService | |
AppContainer assumes that the param name will be `user_service`. | |
Thanks to Python's automatic super() when child does not have __init__, | |
walks the inheritance tree by default. | |
Examples: | |
>>> @component | |
>>> class AThing: | |
>>> | |
>>> def __init__(self, b_thing): | |
>>> self.b_thing = b_thing # noqa | |
>>> | |
>>> | |
>>> @component | |
>>> class BThing: | |
>>> pass | |
>>> | |
>>> a = AppContainer().build_graph().provide(AThing) | |
>>> assert isinstance(a.b_thing, BThing) | |
""" | |
import inspect | |
class UnRegisteredDependencyException(Exception):
    """
    Signals that a requested dependency (by name or by class) has no
    registered class or built instance to satisfy it.
    """
_uppercase_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def _klass_string_to_camel_case(name):
    """
    Convert a class name such as ``UserService`` into the snake_case
    parameter name the container expects (``user_service``).

    Despite the historical ``camel_case`` in its name, the result is
    snake_case. An underscore is inserted before an uppercase letter that
    starts a new word, which is the case when either:

    * it is followed by a non-uppercase character
      (``AThing`` -> ``a_thing``, ``HTTPServer`` -> ``http_server``), or
    * it follows a character that is neither uppercase nor ``_``, i.e. an
      acronym run starting after a lowercase letter or digit
      (``MyHTTPServer`` -> ``my_http_server``, ``UserID`` -> ``user_id``).

    Other uppercase letters (the interior of an acronym run) are just
    lowercased. Only ASCII ``A``-``Z`` count as uppercase here.

    Args:
        name: the class name to convert.

    Returns:
        The snake_case form of ``name``, or ``""`` for an empty name.
    """
    if not name:
        return ""
    res = [name[0].lower()]
    for index in range(1, len(name)):
        c = name[index]
        if c not in _uppercase_letters:
            res.append(c)
            continue
        prev_letter = name[index - 1]
        # Past the end counts as uppercase, so a trailing capital only gets
        # an underscore through the "follows a lowercase/digit" rule.
        next_letter = name[index + 1] if index + 1 < len(name) else "A"
        starts_word = next_letter not in _uppercase_letters or (
            prev_letter not in _uppercase_letters and prev_letter != "_"
        )
        if starts_word:
            res.append("_")
        res.append(c.lower())
    return "".join(res)
def klass_to_camel_case(klass):
    """
    Return the registration key for ``klass``.

    The key is ``klass.__name__`` passed through
    ``_klass_string_to_camel_case``, which (despite the name) produces
    snake_case, e.g. ``UserService`` -> ``user_service``.
    """
    klass_name = klass.__name__
    return _klass_string_to_camel_case(klass_name)
class AppContainer:
    """
    Registry of component classes and cache of the singleton instances
    built from them.

    Classes are registered under ``klass_to_camel_case(klass)`` (the
    snake_case form of the class name). ``build_graph`` then instantiates
    every registered class exactly once, satisfying each ``__init__``
    parameter with the instance of the class registered under that
    parameter's name.
    """

    def __init__(self):
        # klass -> the single instance built for that class
        self._graph = {}
        # snake_case registration name -> klass
        self._registered_klasses = {}

    def register(self, klass):
        """
        Register ``klass`` under its snake_case name.

        A class registered later under the same name replaces the earlier
        one.
        """
        self._registered_klasses[klass_to_camel_case(klass)] = klass

    def _get_klass_by_arg(self, arg):
        """
        Return the class registered under the name ``arg``.

        Raises:
            UnRegisteredDependencyException: nothing is registered under
                ``arg``; the missing name is the exception's argument.
        """
        dep = self._registered_klasses.get(arg)
        if dep is None:
            raise UnRegisteredDependencyException(arg)
        return dep

    def _build_klass_instance(self, klass_name, _resolving=()):
        """
        Return the singleton for the class registered as ``klass_name``,
        building it (and, recursively, its dependencies) on first use.

        Every positional ``__init__`` parameter after ``self`` (including
        ones with default values) is resolved as the name of a registered
        class, and the resulting instances are passed positionally in
        declaration order.

        Args:
            klass_name: registration name of the class to build.
            _resolving: names currently being built further up the
                recursion; used only to detect dependency cycles.

        Returns:
            The cached instance for the class.

        Raises:
            UnRegisteredDependencyException: ``klass_name`` or one of the
                ``__init__`` parameter names is not registered.
            RuntimeError: the dependencies form a cycle.
        """
        klass = self._get_klass_by_arg(klass_name)
        # Reuse an instance built earlier (e.g. as another class's
        # dependency) so every consumer shares the same object. Membership,
        # not truthiness, so falsy instances are not rebuilt.
        if klass in self._graph:
            return self._graph[klass]
        if klass_name in _resolving:
            cycle = " -> ".join(_resolving + (klass_name,))
            raise RuntimeError("circular dependency: " + cycle)

        path = _resolving + (klass_name,)
        relevant_args = inspect.getfullargspec(klass.__init__).args[1:]
        dependencies = [
            self._build_klass_instance(init_arg, path)
            for init_arg in relevant_args
        ]
        result = klass(*dependencies)
        self._graph[klass] = result
        # Return the very object stored in the graph so dependents and
        # ObjectGraph.provide() see the same singleton.
        return result

    def build_graph(self):
        """
        Ensure every registered class has been instantiated and return an
        ``ObjectGraph`` over the resulting instances.

        Classes already built (by an earlier call or as a dependency) are
        reused, not rebuilt.

        Returns:
            ObjectGraph wrapping this container's klass -> instance dict
            (shared, not copied) and a new class-name -> instance dict.

        Raises:
            UnRegisteredDependencyException: an ``__init__`` parameter does
                not name a registered class.
            RuntimeError: the dependencies form a cycle.
        """
        for klass_name in self._registered_klasses:
            self._build_klass_instance(klass_name)
        named_graph = {
            klass.__name__: instance for klass, instance in self._graph.items()
        }
        return ObjectGraph(self._graph, named_graph)
class ObjectGraph:
    """
    Lookup table over the instances produced by ``AppContainer.build_graph``.
    """

    def __init__(self, klass_graph, named_graph):
        """
        Args:
            klass_graph: class -> instance dict.
            named_graph: class name (str) -> instance dict.
        """
        self._klass_graph = klass_graph
        self._named_graph = named_graph

    def provide(self, klass):
        """
        Get an instance from the dependency graph.

        Args:
            klass: the class to fetch, or its name as a string.

        Returns:
            The instance stored for the requested class.

        Raises:
            UnRegisteredDependencyException: no instance is stored for
                ``klass``; the requested key is the exception's argument.
        """
        if isinstance(klass, str):
            graph = self._named_graph
        else:
            graph = self._klass_graph
        # Look up by key presence rather than truthiness so instances that
        # evaluate falsy (e.g. ``__len__`` returning 0) are still returned.
        try:
            return graph[klass]
        except KeyError:
            raise UnRegisteredDependencyException(klass) from None
# The one application-wide container; ``@component`` registers classes here.
container = AppContainer()
def component(klass):
    """
    Class decorator that registers ``klass`` with the module-level
    ``container`` and hands the class back unchanged.

    Examples:
        >>> @component
        ... class SomeService:
        ...     pass

    Args:
        klass: the class to register.

    Returns:
        ``klass`` itself, so the decorated name still refers to the class.
    """
    register = container.register
    register(klass)
    return klass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment