@erochest
Created October 5, 2011 19:54
Dependency Injection (DI) or Inversion of Control (IoC) in Python
#!/usr/bin/env python
"""\
A short explanation of dependency injection and how to use it to make testing
easier.

Maybe this should be a mocking guide, because everything's going to be a mock
object.

The example here involves accessing a database to get the average age of the
people in it, and asks how we can test that. We'll use doctest for the tests.
"""

from contextlib import closing
import os
import sqlite3

def get_mean_user_age_1(db_file):
    """\
    For the first pass, everything takes place in this function.

    Umm. How do I test this? I can't even use ':memory:' for the database,
    because the function opens its own connection, and every new connection to
    ':memory:' starts with a fresh, empty database. I'm forced to set up a
    temporary file and create the tables in it.

    Here's the set-up.

    >>> db_file = '/tmp/example'
    >>> cxn = sqlite3.connect(db_file)
    >>> c = cxn.cursor()
    >>> c.execute('CREATE TABLE person (name VARCHAR(50), age INTEGER);') # doctest: +ELLIPSIS
    <sqlite3.Cursor object at 0x...>
    >>> data = [('Able', 13), ('Betty', 14), ('Carlos', 15)]
    >>> c.executemany('INSERT INTO person (name, age) VALUES (?, ?);', data) # doctest: +ELLIPSIS
    <sqlite3.Cursor object at 0x...>
    >>> cxn.commit()
    >>> c.close() ; cxn.close()

    Now we can test.

    >>> get_mean_user_age_1(db_file)
    14.0

    And we have to tear it down.

    >>> os.remove(db_file)

    This would be marginally better using unittest, but the main pain would
    still be there. This function has a hard dependency on the database file
    existing and being set up properly, and it doesn't allow you to use any
    other database module.
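For comparison, here is roughly what the unittest version might look like. This is only a sketch: the test-case name MeanAgeFileTest is made up, get_mean_user_age_1 is repeated so the block runs on its own, and tempfile.mkstemp stands in for the hard-coded '/tmp/example' path. The set-up and tear-down move into setUp and tearDown, but the test still needs a real database file on disk.

```python
import os
import sqlite3
import tempfile
import unittest
from contextlib import closing


def get_mean_user_age_1(db_file):
    # Same logic as above: open the file, read every age, average them.
    with closing(sqlite3.connect(db_file)) as cxn:
        with closing(cxn.cursor()) as c:
            c.execute('SELECT age FROM person;')
            ages = [age for (age,) in c]
    return sum(ages) / float(len(ages))


class MeanAgeFileTest(unittest.TestCase):  # hypothetical test-case name
    def setUp(self):
        # The function insists on a file name, so create a real temporary file.
        fd, self.db_file = tempfile.mkstemp(suffix='.db')
        os.close(fd)
        with closing(sqlite3.connect(self.db_file)) as cxn:
            cxn.execute('CREATE TABLE person (name VARCHAR(50), age INTEGER);')
            cxn.executemany('INSERT INTO person (name, age) VALUES (?, ?);',
                            [('Able', 13), ('Betty', 14), ('Carlos', 15)])
            cxn.commit()

    def tearDown(self):
        # ...and remember to delete it afterwards.
        os.remove(self.db_file)

    def test_mean_age(self):
        self.assertEqual(get_mean_user_age_1(self.db_file), 14.0)
```

Run it with `python -m unittest` against the file that holds it. The boilerplate is tidier, but the file-system dependency is exactly the same.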
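    """
    # sqlite3's connection context manager only commits or rolls back; it
    # doesn't close the connection, so closing() is wrapped around it as well.
    with closing(sqlite3.connect(db_file)) as cxn:
        with closing(cxn.cursor()) as c:
            c.execute('SELECT age FROM person;')
            ages = [age for (age,) in c]
    return sum(ages) / float(len(ages))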

def get_mean_user_age_2(cursor):
    """\
    For the second pass, we'll use DI to pass in a database-cursor-like object.
    (Another option would be to pass in the connection and create the cursor
    here.)
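That alternative might look like the minimal sketch below, assuming the function receives any DB-API 2.0 style connection. The name get_mean_user_age_conn is hypothetical. Because the caller now owns the connection, a test can hand it an in-memory sqlite3 database, which wasn't possible in the first pass.

```python
import sqlite3
from contextlib import closing


def get_mean_user_age_conn(cxn):
    # Hypothetical variant: the caller injects the connection,
    # and the function only creates (and closes) its own cursor.
    with closing(cxn.cursor()) as c:
        c.execute('SELECT age FROM person;')
        ages = [age for (age,) in c]
    return sum(ages) / float(len(ages))


# ':memory:' now works, because the connection that was populated
# is the same one that gets passed in.
cxn = sqlite3.connect(':memory:')
cxn.execute('CREATE TABLE person (name VARCHAR(50), age INTEGER);')
cxn.executemany('INSERT INTO person (name, age) VALUES (?, ?);',
                [('Able', 13), ('Betty', 14), ('Carlos', 15)])
print(get_mean_user_age_conn(cxn))  # prints 14.0
cxn.close()
```

The SQL still lives inside the function, so this sits between the first and second passes: no temporary file, but the test still builds a real table.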

    >>> class Cursor(object):
    ...     def execute(self, sql):
    ...         return self
    ...     def __iter__(self):
    ...         return iter([(13,), (14,), (15,)])

    Now we have a mock database cursor that we can inject into the function.

    >>> get_mean_user_age_2(Cursor())
    14.0

    This is better. Production code still needs a database with a `person`
    table, but it can be any database whose module provides DB-API cursors,
    and the test doesn't need a database at all.

    It still doesn't abstract out data storage and retrieval very well,
    though, and you may have to write a fairly complete mock class. Depending
    on what you're doing, however, it may be good enough.
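If writing that mock class by hand gets tedious, the standard library's unittest.mock (Python 3.3 and later) can build one. This sketch repeats get_mean_user_age_2 so it runs on its own. MagicMock lets you configure __iter__, and it also records calls, so a test can check which SQL was sent.

```python
from unittest import mock


def get_mean_user_age_2(cursor):
    # Same as above: run the query, then average the single-column rows.
    cursor.execute('SELECT age FROM person;')
    ages = [age for (age,) in cursor]
    return sum(ages) / float(len(ages))


cursor = mock.MagicMock()
# Iterating over the mock yields these rows (one-element tuples, like a cursor).
cursor.__iter__.return_value = iter([(13,), (14,), (15,)])

print(get_mean_user_age_2(cursor))  # prints 14.0
# The mock remembers how it was used.
cursor.execute.assert_called_once_with('SELECT age FROM person;')
```

A hand-written class like Cursor is still perfectly reasonable; the mock mainly saves typing and adds interaction checks.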
    """
    cursor.execute('SELECT age FROM person;')
    ages = [age for (age,) in cursor]
    return sum(ages) / float(len(ages))

def get_mean_user_age_3(get_ages):
    """\
    For the third option, we'll use DI to pass in a function that returns a
    list of ages.

    >>> get_mean_user_age_3(lambda: [13, 14, 15])
    14.0

    This properly separates the computation from the data storage, and it's
    much easier to test. I'm also no longer testing the database, only what
    needs to be tested: my computation based on the data.
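In production, something still has to supply get_ages. A minimal sketch, assuming the same sqlite3 `person` table: the hypothetical helper ages_from_sqlite does the database work, and functools.partial binds it to a connection so it becomes the zero-argument callable get_mean_user_age_3 expects.

```python
import functools
import sqlite3


def get_mean_user_age_3(get_ages):
    # Same as above: only the arithmetic lives here.
    ages = get_ages()
    return sum(ages) / float(len(ages))


def ages_from_sqlite(cxn):
    # Hypothetical data-access helper: all the SQL stays in here.
    return [age for (age,) in cxn.execute('SELECT age FROM person;')]


cxn = sqlite3.connect(':memory:')
cxn.execute('CREATE TABLE person (name VARCHAR(50), age INTEGER);')
cxn.executemany('INSERT INTO person (name, age) VALUES (?, ?);',
                [('Able', 13), ('Betty', 14), ('Carlos', 15)])

# Production wiring: inject the real data source...
print(get_mean_user_age_3(functools.partial(ages_from_sqlite, cxn)))  # prints 14.0
# ...while tests keep injecting a plain lambda.
print(get_mean_user_age_3(lambda: [20, 30]))  # prints 25.0
```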

    You can see how to apply this in an object-oriented way. Essentially,
    instead of using inheritance, you use composition to put together your
    functionality. Define a class to handle data access (your model) and a
    class to handle the business logic. Either when creating the business
    logic object or afterward (using setters or properties), inject the data
    access object into it. Then all data access should go through that object.
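A minimal sketch of that shape, with hypothetical class names: PersonStore wraps a sqlite3 connection, AgeStats holds the business logic and receives a store through its constructor, and a property lets you swap the store afterward. FakeStore is a test double that only needs an ages() method.

```python
import sqlite3


class PersonStore(object):
    # Hypothetical data-access class (the "model"): the only place with SQL.
    def __init__(self, cxn):
        self.cxn = cxn

    def ages(self):
        return [age for (age,) in self.cxn.execute('SELECT age FROM person;')]


class AgeStats(object):
    # Hypothetical business-logic class: composed with a store, not derived from one.
    def __init__(self, store):
        self._store = store  # constructor injection

    @property
    def store(self):
        return self._store

    @store.setter
    def store(self, store):
        self._store = store  # setter/property injection after construction

    def mean_age(self):
        ages = self._store.ages()
        return sum(ages) / float(len(ages))


class FakeStore(object):
    # Test double: anything with an ages() method will do.
    def ages(self):
        return [20, 30]


stats = AgeStats(FakeStore())
print(stats.mean_age())  # prints 25.0

cxn = sqlite3.connect(':memory:')
cxn.execute('CREATE TABLE person (name VARCHAR(50), age INTEGER);')
cxn.executemany('INSERT INTO person (name, age) VALUES (?, ?);',
                [('Able', 13), ('Betty', 14), ('Carlos', 15)])
stats.store = PersonStore(cxn)  # swap in the real store
print(stats.mean_age())  # prints 14.0
```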

    For example, one way to do this in PHP would be to avoid calling
    $this->_request->getPost or $_POST directly. Instead, have your class
    accept the post data when it is initialized (or expose a property you can
    set). If it isn't set explicitly, you can default to something reasonable,
    like $_POST, but for testing you can pass in an array with exactly the
    values you want to test. (This is a little contrived, since
    $this->_request->setPost works well, but there might be situations where
    you don't want to clobber $_POST. Just the thought makes me a little
    nervous.)
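A rough Python analogue of that "default to the real thing, accept an override" pattern. Here os.environ plays the role of $_POST purely for illustration; the class name Greeter and the GREETING_NAME key are hypothetical.

```python
import os


class Greeter(object):
    # Hypothetical class that reads its input from a mapping.
    def __init__(self, params=None):
        # Default to the real global source; tests pass their own dict instead,
        # so os.environ is never modified.
        self.params = os.environ if params is None else params

    def greeting(self):
        return 'Hello, ' + self.params.get('GREETING_NAME', 'world') + '!'


print(Greeter({'GREETING_NAME': 'Betty'}).greeting())  # prints Hello, Betty!
print(Greeter({}).greeting())  # prints Hello, world!
```

Checking `params is None`, rather than writing `params or os.environ`, keeps an explicitly injected empty dict from silently falling back to the global.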

    The canonical article on this is probably Martin Fowler's
    http://www.martinfowler.com/articles/injection.html.
    """
    ages = get_ages()
    return sum(ages) / float(len(ages))

if __name__ == '__main__':
    import doctest
    doctest.testmod()