Skip to content

Instantly share code, notes, and snippets.

@hectorcanto
Last active February 9, 2024 12:09
Show Gist options
  • Star 9 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save hectorcanto/40a7ecbd9e02b0550840f850165bc162 to your computer and use it in GitHub Desktop.
Save hectorcanto/40a7ecbd9e02b0550840f850165bc162 to your computer and use it in GitHub Desktop.
Some unit test examples using pytest features

Pytest examples

These code snippets are self-contained examples of pytest features and unit-testing strategies collected over years of experience.

How to run

  1. [Optional but recommended] Create a virtualenv
  2. Install pytest, some plugins and some auxiliary packages: pip install pytest pytest-mock requests (the factory and database examples also need factory_boy and psycopg2)
  3. pytest $file_name or pytest .
num_list = []


def add_num(num):
    """Record *num* in the module-level ``num_list`` and always return ``True``."""
    num_list.extend([num])
    return True
def sum_three_numbers(num1, num2, num3):
    """Return ``num1 + num2 + num3``, evaluated left to right."""
    partial_total = num1 + num2
    return partial_total + num3
import psycopg2
import pytest

USER = "postgres"
PASS = None
HOST = "localhost"
PORT = 5432
TABLE = "my_table"
# ``%s`` is filled with ``psycopg2.extensions.AsIs(TABLE)``, which inserts the
# table name verbatim (unquoted): only use it with trusted, hard-coded names.
STATEMENT = "TRUNCATE TABLE %s CASCADE"


@pytest.fixture(scope="function", autouse=True)
def clear_tables():
    """Truncate ``TABLE`` after every test.

    Everything after ``yield`` runs as teardown, once the test has finished.
    With ``autouse=True`` it applies to every test in its scope; remove
    ``autouse`` and request the fixture explicitly to choose which tests use it.
    """
    yield
    connection = psycopg2.connect(user=USER, password=PASS, host=HOST, port=PORT)
    try:
        # autocommit so the TRUNCATE takes effect without an explicit commit.
        connection.autocommit = True
        with connection.cursor() as cursor:
            cursor.execute(STATEMENT, [psycopg2.extensions.AsIs(TABLE)])
    finally:
        # Close even if the TRUNCATE fails, so teardown never leaks connections.
        connection.close()
# -*- coding: utf-8 -*-
import random
import string
import unicodedata
from uuid import uuid4
from functools import partial
from _datetime import datetime, timezone
from factory import Factory, Sequence, LazyFunction as LF, LazyAttribute as LA
from factory.fuzzy import FuzzyChoice as Choice
# Pools that UserFactory samples from; accented entries are reduced to ASCII
# by normalize() when usernames and emails are built.
first_names = [
    "José", "Pedro", "Bruce", "Wade", "Robin",
    "Walter", "Diana", "James", "Monty",
]
last_names = [
    "Pérez", "Wilson", "Wayne", "Lee",
    "Kovacs", "Kent", "López", "Howlett",
]
def normalize(string):
    """Turn a display name into a lowercase ASCII slug.

    Surrounding whitespace is stripped, inner spaces become underscores, and
    accents are removed by NFKD-decomposing the text and dropping every
    non-ASCII code point.
    """
    slug = string.strip().replace(" ", "_").lower()
    decomposed = unicodedata.normalize("NFKD", slug)
    ascii_bytes = decomposed.encode("ascii", "ignore")
    return ascii_bytes.decode("utf-8")
def random_alphanumeric(length):
    """Return a random string of *length* uppercase ASCII letters and digits.

    Uses the non-cryptographic ``random`` module, which is fine for test data
    but not for secrets. A ``length`` of 0 or less yields ``""``.
    """
    alphabet = string.ascii_uppercase + string.digits
    return "".join(random.choice(alphabet) for _ in range(length))
class UserFactory(Factory):
    """factory_boy factory producing fake user records.

    If you have a model defined in an ORM, import it and set it as the base
    in ``Meta``, e.g. ``model = models.User``.
    """

    class Meta:
        abstract = False
        # NOTE(review): no ``model`` is set here; some factory_boy versions treat
        # a factory without a model as abstract. Confirm against the pinned version.

    id = LF(uuid4)
    member_number = Sequence(lambda n: n)
    first_name = Choice(first_names)
    last_name = Choice(last_names)
    # Username and email are derived from the chosen names, ASCII-normalized;
    # the username gets a 3-character random suffix.
    username = LA(
        lambda obj: f"{normalize(obj.first_name)}_{normalize(obj.last_name)}{random_alphanumeric(3)}"
    )
    email = LA(lambda obj: f"{normalize(obj.first_name)}_{normalize(obj.last_name)}@company.com")
    created_at = LF(lambda: datetime.now(tz=timezone.utc))
def dict_factory(factory, **kwargs):
    """Build a stub from *factory* and return its attributes as a plain dict.

    Going through ``factory.stub`` (rather than reading declarations directly)
    lets factory_boy's internal machinery resolve the declared values first.
    """
    stub = factory.stub(**kwargs)
    return vars(stub)
# Shortcut: ``UserDictFactory(**overrides)`` is ``dict_factory(UserFactory, **overrides)``,
# i.e. a fake user returned as a plain dict instead of an object.
UserDictFactory = partial(dict_factory, UserFactory)
pytest==4.1.0
pytest-mock==1.10.0
requests==2.21.0
import time
import logging
import sqlite3
import pytest
logger = logging.getLogger("ExampleDBClient")
# Retry policy for ExampleDBClient.connect: seconds to wait between attempts,
# and the attempt number at which it stops retrying and re-raises.
RECONNECT_SLEEP = 30
RECONNECT_ATTEMPTS = 3


def mock_return(*args, **kwargs):
    """Replacement for ``sqlite3.connect`` in tests: always raises ``KeyError("Mock Error")``."""
    raise KeyError("Mock Error")


def mock_sleep(*args):
    """Replacement for ``time.sleep`` in tests: returns immediately without waiting."""
    return None


class ExampleDBClient:
    """Example client that opens a connection on construction, retrying on failure.

    NOTE(review): ``sqlite3`` stands in for a real driver here. ``sqlite3.connect``
    does not accept ``user``/``host``/``port`` keywords, so an un-patched
    ``ExampleDBClient()`` fails every attempt with ``TypeError``.
    """

    def __init__(self, user="dummy", host="localhost", port="1234", database="example"):
        self.user = user
        self.host = host
        self.port = port
        self.database = database
        self.full_config = dict(user=user, host=host, port=port, database=database)
        self.connection = self.connect()

    def connect(self, attempts=1):
        """Connect with ``sqlite3.connect(**self.full_config)``, retrying on any error.

        :param attempts: number of the first attempt (1-based); used in log
            messages and to decide when to stop retrying.
        :return: the new connection, also stored in ``self.connection``.
        :raises Exception: the last connection error, re-raised once
            ``attempts`` reaches ``RECONNECT_ATTEMPTS``.

        Each failed attempt before the last logs an error, then logs the
        exception with its traceback, then sleeps ``RECONNECT_SLEEP`` seconds.
        """
        while True:
            try:
                self.connection = sqlite3.connect(**self.full_config)
            except Exception as e:  # broad on purpose: any driver error triggers a retry
                # ``>=`` (not ``==``) so a caller starting above the limit stops
                # instead of retrying forever.
                if attempts >= RECONNECT_ATTEMPTS:
                    logger.error("Tried to connect too many times (%u), stopping %s@%s:%s/%s",
                                 attempts, self.user, self.host, self.port, self.database)
                    raise
                logger.error("Error connecting to %s @ %s : %s / %s",
                             self.user, self.host, self.port, self.database)
                logger.exception(e)
                time.sleep(RECONNECT_SLEEP)  # back off before the next attempt
                attempts += 1
            else:
                logger.info("Connected to %s@%s:%s/%s in attempt %d.",
                            self.user, self.host, self.port, self.database, attempts)
                return self.connection
def test_reconnection_fails(monkeypatch, caplog):
    """Every connection attempt fails: the client retries, logs each failure, then re-raises."""
    monkeypatch.setattr(time, "sleep", mock_sleep)  # skip the real back-off wait
    monkeypatch.setattr(sqlite3, "connect", mock_return)  # always raises KeyError("Mock Error")
    with pytest.raises(KeyError):
        ExampleDBClient()
    records = caplog.records
    assert len(records) == 5
    # Attempts 1 and 2: an error line followed by the logged exception.
    assert records[0].msg == "Error connecting to %s @ %s : %s / %s"
    assert records[1].msg.args[0] == "Mock Error"
    assert records[2].msg == "Error connecting to %s @ %s : %s / %s"
    assert records[3].msg.args[0] == "Mock Error"
    # Attempt 3: the client gives up.
    assert records[4].msg == "Tried to connect too many times (%u), stopping %s@%s:%s/%s"
"""
Testing that your program responds as expected in negative situations is very important.
These tests show how to check that some code raises the right exception.
"""
# TODO BreakingPoint exception
import pytest
def raise_exception():
    """Always raise ``KeyError`` with a fixed message; used by the tests below."""
    message = "This function raises an exception"
    raise KeyError(message)
# Note that you should use a message CONSTANT instead of a direct string
def test_raise_exception():
    """Show three ways of asserting that code raises ``KeyError``."""
    # 1. Inline raise.
    with pytest.raises(KeyError):
        raise KeyError("Is expected")
    # 2. Raised from a called function.
    with pytest.raises(KeyError):
        raise_exception()
    # 3. Capture the exception to inspect it. The assertion must sit outside the
    #    ``with`` block: nothing after the raising call inside it ever runs.
    with pytest.raises(KeyError) as raised_exception:
        raise_exception()
    # ``ExceptionInfo`` has no ``msg``; the raised exception is ``.value``. Use
    # ``args[0]`` because ``str(KeyError(...))`` wraps the text in quotes.
    assert raised_exception.value.args[0] == "This function raises an exception"
@pytest.mark.xfail  # expected to fail: demonstrates the xfail mechanism
def test_raise_unexpected_exception():
    """Raise an exception nothing catches, so the test fails as expected.

    pytest adds it to the xfailed counter of the result line, something like:
    ``========== 1 passed, 2 xfailed in 0.08 seconds =================``
    """
    raise AttributeError()
@pytest.mark.xfail(raises=KeyError)
def test_expected_other_exception():
    """Keep a known-failing test instead of deleting it.

    Sometimes a test fails and no fix turns up after many hours. Rather than
    deleting it to get a green suite (and forgetting about it), keep the test,
    mark it as xfail (here: expected to fail with ``KeyError``) and tackle it later.
    """
    # raise_exception() raises KeyError, which pytest.raises(AttributeError)
    # does not catch, so the KeyError escapes and the xfail applies.
    expected = AttributeError
    with pytest.raises(expected):
        raise_exception()
"""
A typical mocking case is changing the output of time, date and datetime functions.
You may be tempted to call time.sleep for N seconds in a test. That is wasting your time.
In this case we test a function called decide_sleeping, which sleeps for the remainder of a desired interval depending on the
processing time. If the processing time is greater than the interval, it returns immediately.
This is useful for busy-waiting loops.
We want to test that the function works without waiting for the real interval to pass.
To do so we mock both time.time (to return the value we want) and time.sleep (to avoid waiting).
We will also use the mock that replaces time.sleep as a "spy", so we can assert how it was called.
"""
import time
# Both in seconds: decide_sleeping compares them with time.time() and passes
# the difference to time.sleep().
INTERVAL, START_TIME = 300, 1000
def decide_sleeping(start_time, interval):
    """Sleep for whatever is left of *interval* since *start_time*.

    Elapsed time is truncated to whole seconds. If nothing is left
    (remaining <= 0), return immediately without sleeping. Returns ``None``.
    """
    remaining = int(interval - int(time.time() - start_time))
    if remaining <= 0:
        return
    time.sleep(remaining)
def test_do_sleep(mocker):
    """Sleep only for what remains of the interval.

    ``mocker`` is the pytest-mock fixture wrapping ``unittest.mock``; every patch
    it applies is undone automatically when the test ends.
    See more at https://docs.python.org/3/library/unittest.mock.html#unittest.mock.patch
    """
    # "Now" is 200 s after START_TIME, so 300 - 200 = 100 s (seconds, not ms) are left.
    mocker.patch("time.time", return_value=START_TIME + INTERVAL - 100)
    sleeper = mocker.patch("time.sleep")
    decide_sleeping(START_TIME, interval=INTERVAL)
    # assert_called_once_with also fails if sleep was called more than once.
    sleeper.assert_called_once_with(100)
def test_no_sleep(mocker):
    """When the interval has already elapsed, decide_sleeping must not sleep."""
    now = START_TIME + INTERVAL + 200  # 200 s past the end of the interval
    mocker.patch("time.time", return_value=now)
    sleep_mock = mocker.patch("time.sleep")
    decide_sleeping(START_TIME, interval=INTERVAL)
    sleep_mock.assert_not_called()
def test_time_goes_backwards(mocker):
    """A clock reading before start_time gives a negative elapsed time and a longer sleep."""
    # This probably cannot happen in practice, but the arithmetic stays
    # predictable: 300 - (-100) = 400 s.
    mocker.patch("time.time", return_value=START_TIME - 100)
    sleeper = mocker.patch("time.sleep")
    decide_sleeping(START_TIME, interval=INTERVAL)
    sleeper.assert_called_once_with(INTERVAL + 100)
from unittest.mock import call
import pytest
import aux_functions
def sum_three_numbers(num1, num2, num3):
    """Return ``num1 + num2 + num3``; patched by dotted path in the test below."""
    first_two = num1 + num2
    return first_two + num3
def test_mock_interception(mocker):
    """Patch a function of an imported module and inspect how the mock was called.

    Use the ``assert_*`` methods or ``call_count``. Names like ``called_once()``,
    ``has_calls()`` or ``called_with()`` are not assertions. Before Python 3.12
    a Mock auto-creates them and returns a truthy Mock, so
    ``assert mocked.called_once()`` can never fail. Since 3.12 most of them
    raise AttributeError.
    """
    aux_functions.add_num(1)  # real call: reaches the real function
    # Mocking from an imported module; we can also mock without importing.
    mocked = mocker.patch.object(aux_functions, "add_num", return_value=True)

    aux_functions.add_num(2)
    mocked.assert_called_once_with(2)
    aux_functions.add_num(3)
    assert mocked.call_count == 2
    assert aux_functions.add_num(4) is True  # the mock's return_value

    assert aux_functions.num_list == [1]  # only the first, unpatched call ran the function
    mocked.assert_has_calls([call(2), call(3), call(4)])
    with pytest.raises(AssertionError):  # order matters unless any_order=True
        mocked.assert_has_calls([call(4), call(3), call(2)])
    mocked.assert_has_calls([call(4), call(3), call(2)], any_order=True)
    assert mocked.call_count == 3
    mocked.assert_any_call(3)  # called with 3 at some point
    mocked.assert_called_with(4)  # the *last* call was with 4
def test_mock_interception_multiple_parameters(mocker):
    """Patch a function of this very module by dotted path and check call order."""
    # Build the target from __name__ instead of hard-coding the module name, so
    # the patch still hits this module if the file is renamed or collected
    # inside a package.
    mocked = mocker.patch(f"{__name__}.sum_three_numbers", return_value=0)
    # These calls look the name up in module globals at call time, so they hit the mock.
    sum_three_numbers(1, 2, 3)
    sum_three_numbers(4, 5, 6)
    mocked.assert_has_calls([call(1, 2, 3)])  # a contiguous sub-sequence is enough
    mocked.assert_has_calls([call(1, 2, 3), call(4, 5, 6)])
    mocked.assert_has_calls([call(4, 5, 6), call(1, 2, 3)], any_order=True)
    with pytest.raises(AssertionError):  # wrong order without any_order
        mocked.assert_has_calls([call(4, 5, 6), call(1, 2, 3)])
import os
import time
import pytest
ENV_VAR_NAME = "DUMMY_VAR"
# Set at import time; test_unpatching_works checks it is restored after patching.
os.environ.update(CUSTOM_VAR="Unchanged")
my_dict = dict(a=11, b=22)


class MockClass:
    """Plain class whose attributes the monkeypatch tests delete and restore."""

    attribute1 = 1
    attribute2 = 2
def test_monkeypatch_environmentals(monkeypatch):
    """setenv can add a new variable and overwrite an existing one; both are undone afterwards."""
    # Use the shared constant rather than repeating its literal value.
    assert ENV_VAR_NAME not in os.environ
    monkeypatch.setenv(ENV_VAR_NAME, "123")
    monkeypatch.setenv("CUSTOM_VAR", "Changed")
    assert os.environ[ENV_VAR_NAME] == "123"
    assert os.environ["CUSTOM_VAR"] == "Changed"
def test_monkeypatch_function(monkeypatch):
    """Replace time.time with a stub that always returns the same value."""
    frozen = 12345
    monkeypatch.setattr(time, "time", lambda: frozen)
    # Repeated calls keep returning the frozen value.
    for _ in range(2):
        assert time.time() == frozen
def test_monkeypatch_delete_attribute(monkeypatch):
    """Deleting a class attribute hides it from an already-created instance too."""
    instance = MockClass()
    monkeypatch.delattr(MockClass, "attribute2")
    assert instance.attribute1 == 1
    with pytest.raises(AttributeError):
        _ = instance.attribute2
def test_monkeypatch_dicts(monkeypatch):
    """Add and remove keys of a module-level dict; both changes are undone after the test."""
    monkeypatch.setitem(my_dict, "c", 33)
    monkeypatch.delitem(my_dict, "b")
    expected = {"a": 11, "c": 33}
    assert my_dict == expected
def test_unpatching_works():
    """Check that the monkeypatch tests above left no trace.

    NOTE: relies on pytest running the tests in definition order (its default).
    """
    assert ENV_VAR_NAME not in os.environ
    assert os.environ["CUSTOM_VAR"] == "Unchanged"
    restored_instance = MockClass()
    assert restored_instance.attribute2 == 2
    assert my_dict == dict(a=11, b=22)
"""
Parametrize allows you to run the same test with different inputs and expectations.
Each input produces a separate test.
As the first argument of the mark, you name the variables in a single string, separated by commas.
As the second argument, you pass an iterable (e.g. a list) of tuples, one per case, holding the values of those variables.
"""
import pytest
def make_sum(a, b):
    """Add *a* and *b* with the built-in ``sum`` (start value 0), so non-numbers raise ``TypeError``."""
    return sum((a, b))
# Check the docs here: https://docs.pytest.org/en/latest/parametrize.html
@pytest.mark.parametrize(
    "first_summand, second_summand, expected",
    [
        (1, 1, 2),
        (1, 2, 3),
        (1, -1, 0),
        (12, 12, 24),
    ],
)
def test_parametrize(first_summand, second_summand, expected):
    """Each tuple above becomes its own test case."""
    assert make_sum(first_summand, second_summand) == expected
# An example of a test checking that an exception is raised. Negative tests are also important.
@pytest.mark.parametrize(
    "first_summand, second_summand, exception",
    [
        (1, "a", TypeError),
        (1, [2], TypeError),
    ],
)
def test_parametrize_exception(first_summand, second_summand, exception):
    """make_sum must raise the given exception for operands it cannot add."""
    with pytest.raises(exception):
        make_sum(first_summand, second_summand)
import pytest
class Simple_Class:
    """Minimal class used as a patch target by the tests below."""

    def method(self):
        """Return the constant 1."""
        return 1

    def another_method(self):
        """Return the constant 2."""
        return 2
def test_instance_patch(mocker):
    """Patching one instance leaves other instances of the class untouched."""
    patched, untouched = Simple_Class(), Simple_Class()
    mocker.patch.object(patched, "method", return_value=3)
    assert patched.method() == 3
    assert untouched.method() == 1
def test_class_patch(mocker):
    """Patching the method on the class affects every instance."""
    mocker.patch.object(Simple_Class, "method", return_value=3)
    instances = [Simple_Class(), Simple_Class()]
    for obj in instances:
        assert obj.method() == 3
def test_class_with_side_effect(mocker):
    """A ``side_effect`` exception is raised whenever the patched method is called."""
    mocker.patch.object(Simple_Class, "method", side_effect=AttributeError("Side effect"))
    simple_instance = Simple_Class()
    with pytest.raises(AttributeError) as exception:
        simple_instance.method()
    # Inspect ``.value`` (the raised exception) outside the ``with`` block, where
    # the code actually runs; ``ExceptionInfo`` has no ``msg`` attribute.
    assert str(exception.value) == "Side effect"
"""
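In this example we will spy on a method without obstructing it.
When we place a spy on a method (a mock that wraps the real one), calls still reach the real
implementation and return its real result, while the mock records how it was called.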
"""
import requests
from unittest.mock import call
URL1 = "https://www.python.org/"
URL2 = URL1 + "dev/peps/pep-0008/"  # PEP 8 page on the same host as URL1
def test_spy_request(mocker):
    """Spy on ``Session.get`` via ``wraps=``: real requests go out and are recorded.

    Needs network access to www.python.org.
    """
    timeout = 10  # seconds; without a timeout requests can wait forever
    # Use a session if you are going to hit the same server several times;
    # the ``with`` block closes its pooled connections afterwards.
    with requests.Session() as session:
        spy = mocker.patch.object(session, "get", wraps=session.get)
        response1 = session.get(URL1, timeout=timeout)
        response2 = session.get(URL2, timeout=timeout)
    assert response1.status_code == 200
    assert response2.status_code == 200
    assert spy.call_count == 2
    spy.assert_any_call(URL2, timeout=timeout)
    spy.assert_has_calls([call(URL1, timeout=timeout), call(URL2, timeout=timeout)])
    spy.assert_has_calls([call(URL2, timeout=timeout), call(URL1, timeout=timeout)], any_order=True)
def test_another_spy_request(mocker):
    """Same as test_spy_request, but the spy is created with ``mocker.spy``.

    Needs network access to www.python.org.
    """
    timeout = 10  # seconds; without a timeout requests can wait forever
    # Use a session if you are going to hit the same server several times;
    # the ``with`` block closes its pooled connections afterwards.
    with requests.Session() as session:
        spy = mocker.spy(session, "get")
        response1 = session.get(URL1, timeout=timeout)
        response2 = session.get(URL2, timeout=timeout)
    assert response1.status_code == 200
    assert response2.status_code == 200
    assert spy.call_count == 2
    spy.assert_any_call(URL2, timeout=timeout)
    spy.assert_has_calls([call(URL1, timeout=timeout), call(URL2, timeout=timeout)])
    spy.assert_has_calls([call(URL2, timeout=timeout), call(URL1, timeout=timeout)], any_order=True)
├── tests
│   │
│   ├── data
│   │   ├── some_input.json (input fixtures from files)
│   │   └── list_of_names.txt
│   │
│   ├── unitary
│   │   ├── __init__.py
│   │   ├── conftest.py (particular fixtures of this subfolder)
│   │   ├── test_api_basics.py
│   │   ├── test_resource_user.py
│   │   ├── test_other_resource.py
│   │   ├── test_business_logic.py
│   │   ├── test_validation.py
│   │   └── test_factories.py
│   │
│   ├── integration
│   │   ├── aws
│   │   │   ├── __init__.py
│   │   │   ├── conftest.py
│   │   │   ├── test_some_service.py
│   │   │   └── test_bucket_upload.py
│   │   │  
│   │   └── repositories
│   │      ├── __init__.py
│   │      ├── conftest.py
│   │      ├── test_interface_db.py
│   │      └── test_external2.py
│   │
│   ├── __init__.py
│   ├── common.py
│   ├── conftest.py (most fixtures are stored here)
│   ├── test_smoke.py
│   └── test_cli_aux_commands.py
├── reports/ (stores test results, test logs, and coverage reports)
├── pytest.ini (configures pytest runner)
├── .coveragerc (configures coverage plugin)
└── README.md
@gamesbook
Copy link

How can conftest.py be guaranteed to work when the test environment has no control over how / where a database server has been setup? Do you just accept that this test may fail and move on?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment