Last active
March 17, 2018 21:12
-
-
Save Fuyukai/12d433cfd882ebdb81d327d17e7c5902 to your computer and use it in GitHub Desktop.
Checking if a function is being awaited
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dis
import functools
import inspect
# Handle on the stock coroutine-function predicate, bound once at import time
# so that a later monkey-patch of ``inspect`` does not change what the
# await-state check below compares against.
from inspect import iscoroutinefunction as isasyncfn_original
def _is_calling_from_async(frameinfo): | |
""" | |
Checks if the frame is awaiting the next stack frame. | |
""" | |
frame = frameinfo.frame | |
code = frame.f_code | |
idx = frame.f_lasti // 2 | |
instructions = dis.get_instructions(code) | |
for ins in list(instructions)[idx:]: | |
# check if we modify the stack | |
# if we do, we didn't call as `await x` | |
if dis.stack_effect(ins.opcode, ins.arg) != 0: | |
return | |
# we didn't modify the stack, and we got an awaitable | |
# this means we tried calling as AWAIT | |
# so return true | |
if ins.opname == "GET_AWAITABLE": | |
return True | |
return False | |
def ensure_await_state(fn):
    """
    Wrap *fn* so that calling it in the wrong "await state" raises.

    A coroutine function (as judged by ``isasyncfn_original``) must be called
    as ``await fn(...)``; an ordinary function must not be. Whether the call
    is awaited is decided by inspecting the caller's bytecode with
    ``_is_calling_from_async``.

    :param fn: The function to guard.
    :return: A synchronous wrapper carrying *fn*'s metadata.
    :raises ValueError: From the wrapper, when whether the call is awaited
        does not match whether *fn* is a coroutine function.
    """
    @functools.wraps(fn)
    def inner(*args, **kwargs):
        # Only the caller's frame object is needed, not source context, so
        # pass context=0 to avoid reading source files for every stack frame.
        frameinfo = inspect.stack(0)[1]
        try:
            # Normalise to bool: a helper that returns None for "not awaited"
            # must still compare equal to False for ordinary functions.
            awaited = bool(_is_calling_from_async(frameinfo))
        finally:
            # Release the frame reference promptly to avoid a reference cycle.
            del frameinfo
        if awaited != bool(isasyncfn_original(fn)):
            raise ValueError("Function was called wrong")
        return fn(*args, **kwargs)
    return inner
@ensure_await_state
async def a_test():
    """Sample coroutine guarded by ``ensure_await_state``; returns ``"Test"``."""
    result = "Test"
    return result
async def test():
    """
    Demonstrate ``ensure_await_state`` on ``a_test``.

    The first call is awaited and succeeds. The second is a bare, un-awaited
    call, which the decorator rejects by raising ``ValueError`` before the
    second print runs.
    """
    # Correct usage: awaited, so no error.
    awaited_result = await a_test()
    print("1st", awaited_result)

    # Incorrect usage: not awaited, so this call raises.
    bare_result = a_test()
    print("2nd", bare_result)
def patch_all():
    """
    Placeholder for globally patching functions (presumably with
    ``ensure_await_state``); currently a no-op.

    Implementing it is left as an exercise to the reader. A full version would
    probably need to replace ``builtins.__build_class__`` and walk
    ``sys.modules`` to hot-patch every existing function. It would also install
    an import hook that patches each standalone function in newly imported
    modules, and perhaps deal with the C/Python boundary.
    """
if __name__ == "__main__":
    # Run the demo only as a script. Importing the module must not execute it,
    # because the demo deliberately raises ValueError on its second call.
    patch_all()
    gen = test()
    # Drive the coroutine by hand. It never suspends, so a single send() runs
    # it to completion, or until ensure_await_state raises.
    try:
        gen.send(None)
    except StopIteration as e:
        # A finished coroutine reports its return value via StopIteration.
        print("gen returned", e.value)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/usr/bin/python3.6 /home/laura/dev/misc/hackery.py
Test
Traceback (most recent call last):
  File "/home/laura/dev/misc/hackery.py", line 69, in <module>
    gen.send(None)
  File "/home/laura/dev/misc/hackery.py", line 53, in test
    x = a_test()
  File "/home/laura/dev/misc/hackery.py", line 36, in inner
    raise ValueError("Function was called wrong")
ValueError: Function was called wrong
Process finished with exit code 1
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.