Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import asyncio
import functools
import inspect
import os
import unittest
from unittest import mock
class AsyncTestCaseMeta(type(unittest.TestCase)):
    """Metaclass that lets ``async def test_*`` methods run under unittest.

    When a class is created, every attribute of the class body whose name
    starts with ``test_`` and which is a coroutine function is replaced by a
    synchronous wrapper. The wrapper drives the coroutine to completion on an
    event loop, so a regular unittest runner can call it like any other test.
    """

    def __new__(mcls, name, bases, ns, **kwargs):
        """Create the class, swapping coroutine tests for sync wrappers.

        Args:
            name: Name of the class being created.
            bases: Tuple of base classes.
            ns: Class body namespace; matching entries are replaced in it.
            **kwargs: Extra class keywords, forwarded to the parent
                metaclass unchanged.

        Returns:
            The newly created class.
        """
        # Collect the replacements first so the namespace is never written
        # to while it is being iterated.
        wrapped = {
            attrname: mcls._sync_wrap(attr)
            for attrname, attr in ns.items()
            if attrname.startswith('test_')
            and inspect.iscoroutinefunction(attr)
        }
        ns.update(wrapped)
        return super().__new__(mcls, name, bases, ns, **kwargs)

    @staticmethod
    def _sync_wrap(func):
        """Return a synchronous callable that runs coroutine function *func*.

        Args:
            func: A coroutine function (``async def``).

        Returns:
            A function with the same metadata as *func* (via
            ``functools.wraps``) that executes the coroutine to completion
            and returns its result.
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # asyncio.run() gives every call a fresh event loop and closes
            # it afterwards. The previous asyncio.get_event_loop() call is
            # deprecated when no loop is running, fails on recent Python
            # versions when no loop has been set, and never closed the loop.
            return asyncio.run(func(*args, **kwargs))
        return wrapper
class AsyncTestCase(unittest.TestCase, metaclass=AsyncTestCaseMeta):
    """Base test case; subclasses are created through AsyncTestCaseMeta."""
class MyTest(AsyncTestCase):
    # Example test case mixing a plain test, a coroutine test, and a
    # coroutine test decorated with mock.patch.

    def test_sync(self):
        # Ordinary synchronous test; intentionally empty.
        pass

    async def test_async(self):
        # Coroutine test that yields to the event loop once.
        await asyncio.sleep(0.0)

    @mock.patch("os.path", new={})
    async def test_async_with_mock(self):
        # Same as test_async, but os.path is replaced by an empty dict
        # through the mock.patch decorator.
        await asyncio.sleep(0.0)
if __name__ == "__main__":
    # Executed as a script: let unittest discover and run this module's tests.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.