@pmeier
Created April 1, 2021 08:26
A case for pytest in PyTorch

This is a short post about why I think it would be beneficial for PyTorch not only to use pytest as its test runner, but also to rely on the other features it provides.

Disclaimer

My experience with the PyTorch test suite is limited so far, so my view on things may well be too naive. If so, I'm happy to hear about examples where adopting pytest would make a use case significantly harder or outright impossible.

Setup

test_pytorch.py:

from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    onlyOnCPUAndCUDA,
)


class TestFoo(TestCase):
    @onlyCPU
    def test_bar(self, device):
        pass


instantiate_device_type_tests(TestFoo, globals())


class TestSpam(TestCase):
    @onlyCUDA
    def test_ham(self, device):
        pass

    @onlyOnCPUAndCUDA
    def test_eggs(self, device):
        if str(device) == "cpu":
            raise AssertionError


instantiate_device_type_tests(TestSpam, globals())

test_pytest.py:

import torch
import pytest


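def devices(*devices_):
    # Parametrize the decorated test over the given devices. The device names
    # double as ids, so the test IDs read e.g. test_eggs[cpu].
    return pytest.mark.parametrize(
        "device", [torch.device(device) for device in devices_], ids=devices_
    )


# Reusable decorators, analogous to onlyCPU, onlyCUDA and onlyOnCPUAndCUDA.
cpu = devices("cpu")
cuda = devices("cuda")
cpu_and_cuda = devices("cpu", "cuda")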


class TestFoo:
    @cpu
    def test_bar(self, device):
        pass


class TestSpam:
    @cuda
    def test_ham(self, device):
        pass

    @cpu_and_cuda
    def test_eggs(self, device):
        if str(device) == "cpu":
            raise AssertionError

Output

pytest -rA test_pytorch.py

PASSED test_pytorch.py::TestFooCPU::test_bar_cpu
PASSED test_pytorch.py::TestSpamCUDA::test_eggs_cuda
PASSED test_pytorch.py::TestSpamCUDA::test_ham_cuda
SKIPPED [2] test_pytorch.py:11: Only runs on cpu
SKIPPED [2] test_pytorch.py:20: Only runs on cuda
SKIPPED [1] test_pytorch.py:24: onlyOnCPUAndCUDA: doesn't run on meta
FAILED test_pytorch.py::TestSpamCPU::test_eggs_cpu - AssertionError

pytest -rA test_pytest.py

PASSED test_pytest.py::TestFoo::test_bar[cpu]
PASSED test_pytest.py::TestSpam::test_ham[cuda]
PASSED test_pytest.py::TestSpam::test_eggs[cuda]
FAILED test_pytest.py::TestSpam::test_eggs[cpu] - AssertionError

  • In PyTorch style, the device appears twice in every test ID: once in the class name (TestSpamCPU) and once in the test name (test_eggs_cpu).
  • In PyTorch style, every test is instantiated for every device type and then skipped on the devices it is not meant to run on. With only three tests, this already adds five skipped entries that clutter the output.
  • In PyTorch style, the skip messages are not expressive, since they do not include the name of the skipped test.
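
To make the difference in collection concrete, here is a deliberately simplified model of both schemes in plain Python. It is only a sketch: it does not use PyTorch's or pytest's actual code and merely reproduces the test IDs seen above. The device list (cpu, cuda, meta) is an assumption taken from the skip messages of this particular run.

```python
# Hypothetical, simplified model of the two collection schemes. It mimics the
# test IDs from the output above; it is not PyTorch's or pytest's real code.

# Device types the PyTorch-style run above instantiated tests for (assumption).
DEVICE_TYPES = ("cpu", "cuda", "meta")

# Source-level tests: class name -> {test name -> devices it should run on}.
TESTS = {
    "TestFoo": {"test_bar": ("cpu",)},
    "TestSpam": {"test_ham": ("cuda",), "test_eggs": ("cpu", "cuda")},
}


def pytorch_style(module, tests, device_types=DEVICE_TYPES):
    """Return (node_id, skipped) for every generated test."""
    items = []
    for cls, methods in tests.items():
        for device in device_types:
            # One class per device type, e.g. TestSpam -> TestSpamCUDA.
            device_cls = cls + device.upper()
            for name, allowed in methods.items():
                # The device is encoded a second time in the test name.
                node_id = f"{module}::{device_cls}::{name}_{device}"
                # Instantiated anyway, but skipped if the device is not allowed.
                items.append((node_id, device not in allowed))
    return items


def pytest_style(module, tests):
    """Return the node IDs produced by parametrizing over the allowed devices."""
    return [
        f"{module}::{cls}::{name}[{device}]"
        for cls, methods in tests.items()
        for name, allowed in methods.items()
        for device in allowed
    ]


if __name__ == "__main__":
    torch_items = pytorch_style("test_pytorch.py", TESTS)
    skipped = sum(is_skipped for _, is_skipped in torch_items)
    print(len(torch_items), "collected,", skipped, "skipped")
    for node_id in pytest_style("test_pytest.py", TESTS):
        print(node_id)
```

Running it prints "9 collected, 5 skipped" for the PyTorch style, which lines up with the 3 passed, 1 failed and 5 skipped tests above, followed by the four pytest-style IDs, which are exactly the tests that run.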

Selection

| Situation | PyTorch style | pytest style |
| --- | --- | --- |
| Select one test case and run it against all devices | test_pytorch.py -k "TestSpam" | test_pytest.py::TestSpam |
| Select one test case and run it against a specific device | test_pytorch.py::TestSpamCUDA | test_pytest.py::TestSpam -k cuda |
| Select one test and run it against all devices | test_pytorch.py -k "TestSpam and test_eggs" | test_pytest.py::TestSpam::test_eggs |
| Select one test and run it against a specific device | test_pytorch.py::TestSpamCPU::test_eggs_cpu | test_pytest.py::TestSpam::test_eggs[cpu] |
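
Several selections in the table rely on pytest's -k option, which matches substrings of test names and their parents. The sketch below is a hypothetical, simplified model of that matching, applied to the test IDs from the output above: it only understands "and" and "not" (no "or", no parentheses) and matches case-insensitively, as recent pytest versions do. It is meant for illustration, not as a reimplementation of pytest.

```python
# Hypothetical, simplified model of pytest's -k matching: terms joined by
# "and", optionally prefixed with "not"; no "or" and no parentheses.
# Each term is matched case-insensitively as a substring of the node ID parts.

PYTORCH_IDS = [
    f"test_pytorch.py::{cls}{device.upper()}::{test}_{device}"
    for cls, tests in (
        ("TestFoo", ["test_bar"]),
        ("TestSpam", ["test_ham", "test_eggs"]),
    )
    for device in ("cpu", "cuda", "meta")
    for test in tests
]
PYTEST_IDS = [
    "test_pytest.py::TestFoo::test_bar[cpu]",
    "test_pytest.py::TestSpam::test_ham[cuda]",
    "test_pytest.py::TestSpam::test_eggs[cpu]",
    "test_pytest.py::TestSpam::test_eggs[cuda]",
]


def matches(node_id, expression):
    """Return True if node_id satisfies the simplified -k expression."""
    names = [part.lower() for part in node_id.split("::")]
    for term in expression.split(" and "):
        term = term.strip()
        negate = term.startswith("not ")
        if negate:
            term = term[len("not "):].strip()
        found = any(term.lower() in name for name in names)
        # Fail on a missing required term or on a present excluded term.
        if found == negate:
            return False
    return True


def select(node_ids, expression):
    return [node_id for node_id in node_ids if matches(node_id, expression)]


if __name__ == "__main__":
    # Also picks up the META instance, which is skipped anyway.
    for node_id in select(PYTORCH_IDS, "TestSpam and test_eggs"):
        print(node_id)
    # Only the two CUDA parametrizations.
    for node_id in select(PYTEST_IDS, "cuda"):
        print(node_id)
```

In this model, -k "TestSpam and test_eggs" in PyTorch style also selects test_eggs_meta, which only ends up as another skip, whereas -k cuda in pytest style selects exactly the CUDA parametrizations.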

Popular IDEs such as PyCharm and VS Code can run or debug a single test without dropping into a terminal. They select the test (case) to run by the name it has in the source file. Since PyTorch style renames the tests during instantiation (for example, TestSpam::test_eggs becomes TestSpamCPU::test_eggs_cpu and TestSpamCUDA::test_eggs_cuda), this feature cannot be used.
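
Assuming, as described above, that the IDE builds the test ID from the names written in the source file, the lookup can be modeled as follows. This is a hypothetical sketch, not how PyCharm or VS Code are actually implemented; it relies only on the fact that passing a function's node ID without a [param] suffix makes pytest select all of its parametrizations. The helper names are made up for illustration.

```python
# Hypothetical model of an IDE mapping the test under the cursor to a pytest
# node ID by joining the file, class and function names from the source.


def source_node_id(path, cls, func):
    """Node ID built from the names as they appear in the source file."""
    return f"{path}::{cls}::{func}"


def collected_for(node_id, collected):
    """Items selected by a node ID without a [param] suffix."""
    return [
        item
        for item in collected
        if item == node_id or item.startswith(node_id + "[")
    ]


# Collected IDs for test_eggs, taken from the output above.
PYTORCH_COLLECTED = [
    "test_pytorch.py::TestSpamCPU::test_eggs_cpu",
    "test_pytorch.py::TestSpamCUDA::test_eggs_cuda",
    "test_pytorch.py::TestSpamMETA::test_eggs_meta",
]
PYTEST_COLLECTED = [
    "test_pytest.py::TestSpam::test_eggs[cpu]",
    "test_pytest.py::TestSpam::test_eggs[cuda]",
]

if __name__ == "__main__":
    # pytest style: the source name finds both parametrizations.
    print(collected_for(source_node_id("test_pytest.py", "TestSpam", "test_eggs"), PYTEST_COLLECTED))
    # PyTorch style: the source name finds nothing, since only renamed IDs exist.
    print(collected_for(source_node_id("test_pytorch.py", "TestSpam", "test_eggs"), PYTORCH_COLLECTED))
```

The pytest-style lookup returns both test_eggs parametrizations, while the PyTorch-style lookup comes back empty, because the collected IDs only contain the renamed, device-specific forms.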
