Skip to content

Instantly share code, notes, and snippets.

@michaelhelvey
Last active March 13, 2020 23:10
Show Gist options
  • Save michaelhelvey/659d10ef3ccdf5ab8d1620efb5b0b1e3 to your computer and use it in GitHub Desktop.
Save michaelhelvey/659d10ef3ccdf5ab8d1620efb5b0b1e3 to your computer and use it in GitHub Desktop.
Single file dependency injection for Python
"""
Simple dependency injection package that builds up a global graph of
singleton objects based on __init__ params.

Assumes that __init__ params are named as the snake_case equivalents of the
required class names: if a class needs an instance of UserService,
AppContainer assumes the param name will be `user_service`.

Because a subclass without its own __init__ inherits its parent's, the
dependencies of such a subclass are resolved from the nearest ancestor's
__init__ signature.

Examples:
    >>> @component
    ... class AThing:
    ...     def __init__(self, b_thing):
    ...         self.b_thing = b_thing  # noqa
    >>> @component
    ... class BThing:
    ...     pass
    >>> a = container.build_graph().provide(AThing)
    >>> assert isinstance(a.b_thing, BThing)
"""
import inspect
class UnRegisteredDependencyException(Exception):
    """Raised when a requested dependency has no registered class or built instance."""
# ASCII capitals; membership in this string marks a character as uppercase
# when deciding where word boundaries fall in a class name.
_uppercase_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def _klass_string_to_camel_case(name):
    """
    Convert a PascalCase class name to the snake_case parameter name.

    Despite the historical function name, the output is snake_case:
    ``UserService`` -> ``user_service``, ``HTTPServer`` -> ``http_server``,
    ``MyHTTPServer`` -> ``my_http_server``, ``ServiceA`` -> ``service_a``.

    An underscore is inserted before an uppercase letter (other than the
    first character) when that letter starts a new word, i.e. when the
    previous character or the next character is lowercase. Runs of capitals
    (acronyms) therefore stay together. No underscore is added when the
    name already has one at that position.

    Args:
        name: The class name to convert.

    Returns:
        The lowercase, underscore-separated form of ``name``.
    """
    res = []
    last_index = len(name) - 1
    for index, c in enumerate(name):
        if index > 0 and c in _uppercase_letters:
            prev_letter = name[index - 1]
            next_letter = name[index + 1] if index < last_index else ""
            # "aB" (lowercase -> upper) or "ABc" (acronym followed by a word)
            starts_word = prev_letter.islower() or next_letter.islower()
            # don't double up on an underscore the class name already has
            if starts_word and prev_letter != "_":
                res.append("_")
        res.append(c.lower())
    return "".join(res)


def klass_to_camel_case(klass):
    """
    Return the snake_case parameter name used to inject ``klass``.

    Args:
        klass: A class object; only its ``__name__`` is used.

    Returns:
        The snake_case form of ``klass.__name__`` (e.g. ``user_service``).
    """
    return _klass_string_to_camel_case(klass.__name__)
class AppContainer:
    """
    Registry of component classes that builds one shared instance per class.

    Classes are registered under the snake_case form of their name (see
    `klass_to_camel_case`). `build_graph` instantiates every registered
    class. Each positional ``__init__`` parameter (after ``self``) is
    resolved by looking up the registered class with that name. Any instance
    already built is reused, so every dependent receives the same object.
    """

    def __init__(self):
        # klass -> the single instance built for it
        self._graph = {}
        # snake_case parameter name -> registered klass
        self._registered_klasses = {}

    def register(self, klass):
        """
        Register ``klass`` under its snake_case name.

        A later class whose name maps to the same snake_case name replaces
        the earlier registration.
        """
        self._registered_klasses[klass_to_camel_case(klass)] = klass

    def _get_klass_by_arg(self, arg):
        """
        Return the class registered under the parameter name ``arg``.

        Raises:
            UnRegisteredDependencyException: if nothing is registered as ``arg``.
        """
        dep = self._registered_klasses.get(arg)
        if dep is None:
            raise UnRegisteredDependencyException(arg)
        return dep

    def _build_klass_instance(self, klass_name, _chain=()):
        """
        Return the singleton instance for ``klass_name``, building it if needed.

        Dependencies are built (or fetched from the cache) recursively before
        the class itself is instantiated with them as positional arguments.
        The resulting instance is stored in the graph.

        Args:
            klass_name: snake_case name the class was registered under.
            _chain: Internal. Names currently being built further up the
                recursion; used only to report circular dependencies.

        Returns:
            The cached instance for the class.

        Raises:
            UnRegisteredDependencyException: if the class or any of its
                dependencies is not registered.
            RecursionError: if the dependencies form a cycle.
        """
        klass = self._get_klass_by_arg(klass_name)
        # reuse the existing singleton rather than constructing a second one
        if klass in self._graph:
            return self._graph[klass]
        if klass_name in _chain:
            cycle = " -> ".join(_chain + (klass_name,))
            raise RecursionError("circular dependency: {}".format(cycle))
        chain = _chain + (klass_name,)
        # skip `self`; every remaining positional parameter name is resolved
        # to the registered class with that snake_case name
        arg_names = inspect.getfullargspec(klass.__init__).args[1:]
        dependencies = [
            self._build_klass_instance(arg_name, chain) for arg_name in arg_names
        ]
        instance = klass(*dependencies)
        self._graph[klass] = instance
        return instance

    def build_graph(self):
        """
        Ensure every registered class and its dependency tree is built.

        Returns:
            An `ObjectGraph` over the built instances, keyed both by class
            and by class ``__name__``.

        Raises:
            UnRegisteredDependencyException: if a dependency is not registered.
            RecursionError: if the dependencies form a cycle.
        """
        for klass_name in self._registered_klasses:
            # no-op for classes already built as someone else's dependency
            self._build_klass_instance(klass_name)
        named_graph = {
            klass.__name__: instance for klass, instance in self._graph.items()
        }
        return ObjectGraph(self._graph, named_graph)
class ObjectGraph:
    """
    Lookup over the instances built by `AppContainer.build_graph`.
    """

    def __init__(self, klass_graph, named_graph):
        """
        Args:
            klass_graph: klass -> instance dict
            named_graph: klass ``__name__`` (str) -> instance dict
        """
        self._klass_graph = klass_graph
        self._named_graph = named_graph

    def provide(self, klass):
        """
        Get an instance from the dependency graph.

        Args:
            klass: The class to fetch, or its ``__name__`` as a str.

        Returns:
            The instance built for the requested class.

        Raises:
            UnRegisteredDependencyException: if no instance was built for it.
        """
        graph = self._named_graph if isinstance(klass, str) else self._klass_graph
        # membership (not truthiness) decides presence: a registered
        # instance may legitimately be falsy, e.g. an empty container type
        try:
            return graph[klass]
        except KeyError:
            raise UnRegisteredDependencyException(klass) from None
# Process-wide default container that the `component` decorator registers
# classes into.
container: AppContainer = AppContainer()
def component(klass):
    """
    Class decorator that registers ``klass`` with the global ``container``.

    The class is returned unchanged, so decorating it has no effect other
    than the registration.

    Examples:
        >>> @component
        ... class SomeService:
        ...     pass

    Args:
        klass: The class to register.

    Returns:
        ``klass`` itself, so the function can be used as a decorator.
    """
    register = container.register
    register(klass)
    return klass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment