Skip to content

Instantly share code, notes, and snippets.

@moriyoshi
Created May 6, 2020 16:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save moriyoshi/8396839e81224e8459d7a16618b37b29 to your computer and use it in GitHub Desktop.
Save moriyoshi/8396839e81224e8459d7a16618b37b29 to your computer and use it in GitHub Desktop.
CPython bytecode instrumentation (convert ordinary methods to async methods)
import bisect
import opcode
import dis
import inspect
import sys
import types
from collections import defaultdict
import httpx
# Per-module copies of function globals built by instrument(), keyed by the
# module name of the functions being rebuilt.
alternative_globals = {}

# Opcode numbers used to recognise and emit bytecode, looked up by name so they
# match the running interpreter.  dis.opmap raises KeyError at import time if
# any of them is missing (several, e.g. DUP_TOP, YIELD_FROM and the
# CALL_FUNCTION/CALL_METHOD family, are absent from CPython 3.11+).
(
    DUP_TOP,
    POP_TOP,
    BUILD_TUPLE,
    LOAD_FAST,
    LOAD_ATTR,
    LOAD_CONST,
    LOAD_METHOD,
    LOAD_GLOBAL,
    STORE_FAST,
    GET_AWAITABLE,
    YIELD_FROM,
    EXTENDED_ARG,
    CALL_METHOD,
    CALL_FUNCTION,
    CALL_FUNCTION_KW,
    CALL_FUNCTION_EX,
) = (
    dis.opmap[_opname]
    for _opname in (
        "DUP_TOP",
        "POP_TOP",
        "BUILD_TUPLE",
        "LOAD_FAST",
        "LOAD_ATTR",
        "LOAD_CONST",
        "LOAD_METHOD",
        "LOAD_GLOBAL",
        "STORE_FAST",
        "GET_AWAITABLE",
        "YIELD_FROM",
        "EXTENDED_ARG",
        "CALL_METHOD",
        "CALL_FUNCTION",
        "CALL_FUNCTION_KW",
        "CALL_FUNCTION_EX",
    )
)
def render_ins(opcode, arg=None):
    """Encode a single instruction and return it as a new ``bytearray``.

    An ``arg`` wider than 8 bits is split across leading EXTENDED_ARG
    instructions, most significant byte first, followed by the opcode and the
    low byte of ``arg``.  When ``arg`` is None the opcode is followed by a zero
    argument byte on Python >= 3.6 (wordcode) and emitted alone otherwise.
    """
    encoded = bytearray()
    if arg is None:
        encoded.append(opcode)
        if sys.version_info >= (3, 6):
            encoded.append(0)
        return encoded

    # (shift, threshold): emit an EXTENDED_ARG carrying ``arg >> shift`` when
    # ``arg`` exceeds ``threshold``.  Only the lower prefixes are masked to a
    # byte; the topmost one is emitted as-is.
    for shift, threshold in ((24, 0x00FFFFFF), (16, 0x0000FFFF), (8, 0x000000FF)):
        if arg > threshold:
            high = arg >> shift
            encoded.append(EXTENDED_ARG)
            encoded.append(high if shift == 24 else high & 0xFF)
    encoded.append(opcode)
    encoded.append(arg & 0xFF)
    return encoded
def append_ins(b, opcode, arg=None):
    """Encode an instruction with ``render_ins`` and append it to ``b`` in place."""
    encoded = render_ins(opcode, arg)
    b.extend(encoded)
def rewrite_to_async_call(fn, verbs):
    """Rewrite ``fn``'s bytecode so that its HTTP-style calls are awaited.

    The bytecode is scanned with a small state machine.  Right after each
    matched call an await sequence (``GET_AWAITABLE; LOAD_CONST None;
    YIELD_FROM``) is spliced in.  Matched calls are:

    * ``self.<verb>(...)`` and ``super(...).<verb>(...)`` for ``<verb>`` in
      *verbs*, in both the ``LOAD_ATTR`` + ``CALL_FUNCTION*`` form and the
      ``LOAD_METHOD`` + ``CALL_METHOD`` form;
    * ``self.refresh_token(...)`` / ``super(...).refresh_token(...)`` in the
      ``LOAD_ATTR`` + ``CALL_FUNCTION*`` form only.

    When the awaited result of a verb call is stored straight into a local
    (``STORE_FAST``), a later ``<local>.body`` / ``<local>.text`` /
    ``<local>.json`` attribute load is preceded by ``await <local>.aread()``.
    Existing jumps are relocated to account for the inserted bytes.

    Args:
        fn: Function whose ``__code__`` is rewritten (``fn`` is not modified).
        verbs: Container of attribute names whose calls should be awaited.

    Returns:
        ``fn.__code__`` itself when it already has ``CO_COROUTINE`` set or when
        nothing matched; otherwise a new code object with ``CO_COROUTINE`` set.

    Raises:
        ValueError: if a relocated jump argument no longer fits in the
            EXTENDED_ARG prefixes of its original encoding, so the jump cannot
            be patched in place without overwriting the next instruction.

    NOTE(review): the emitted instructions and the jump arithmetic assume the
    CPython 3.8/3.9 wordcode format (byte-offset jump arguments, DUP_TOP,
    YIELD_FROM, LOAD_METHOD/CALL_METHOD) -- confirm before using elsewhere.
    NOTE: the line-number table is not remapped, so line numbers reported for
    a rewritten function may be off.
    """
    code = fn.__code__
    if code.co_flags & inspect.CO_COROUTINE:
        return code

    # Scanner states:
    #   0  idle
    #   1  ``self`` (local slot 0) or a ``super(...)`` result was just pushed
    #   2  ``<obj>.<verb>`` loaded via LOAD_ATTR; waiting for CALL_FUNCTION*
    #   3  ``<obj>.<verb>`` loaded via LOAD_METHOD; waiting for CALL_METHOD
    #   4  verb call done: await its result before the current instruction
    #   5  a local holding an awaited verb result was just loaded
    #   6  ``<obj>.refresh_token`` loaded via LOAD_ATTR; waiting for its call
    #   7  refresh_token call done: await its result
    #   8  global ``super`` loaded; waiting for the ``super(...)`` call
    state = 0
    rewritten = False
    none_const = None   # index of None in consts, resolved on first use
    aread_name = None   # index of "aread" in names, resolved on first use
    resp_store = set()  # local slots assigned straight from an awaited verb call
    consts = code.co_consts
    names = code.co_names
    # Parallel, increasing lists: code was inserted right before the original
    # instruction starting at offset_addr_map[i]; in the rewritten bytecode that
    # instruction starts at offset_map[i].
    offset_addr_map = []
    offset_map = []
    jumps = []  # (new address, Instruction, is_relative, encoded size)
    result = bytearray()
    # Running stack depth; for branching opcodes stack_effect() reports the
    # maximum of both paths, so this is approximate across jumps.
    sp = 0
    ssp = 0  # stack depth when the pattern currently being matched started
    prefix_start = None  # original offset of a pending EXTENDED_ARG prefix

    for ins in dis.Bytecode(code):
        if ins.opcode == EXTENDED_ARG:
            # dis folds the prefix into the next instruction's arg and
            # render_ins() re-emits it; just remember where the prefix began.
            if prefix_start is None:
                prefix_start = ins.offset
            continue
        # Original offset of the first byte of this (possibly prefixed)
        # instruction; insertions must be recorded against it so the prefix
        # stays attached to its instruction in the offset map.
        start = ins.offset if prefix_start is None else prefix_start
        prefix_start = None
        sp += opcode.stack_effect(ins.opcode, ins.arg)

        if state == 0:
            if ins.opcode == LOAD_FAST and ins.arg == 0:
                state = 1
            elif ins.opcode == LOAD_FAST and ins.arg in resp_store:
                state = 5
            elif ins.opcode == LOAD_GLOBAL and ins.argval == "super":
                state = 8
                ssp = sp
        elif state == 1:
            if ins.opcode == LOAD_ATTR and ins.argval in verbs:
                state = 2
                ssp = sp
            elif ins.opcode == LOAD_ATTR and ins.argval == "refresh_token":
                # NOTE(review): only the LOAD_ATTR form is matched, so a call
                # compiled as LOAD_METHOD/CALL_METHOD (no keywords or *args)
                # is not awaited -- confirm this is intended.
                state = 6
                ssp = sp
            elif ins.opcode == LOAD_METHOD and ins.argval in verbs:
                state = 3
                ssp = sp
            else:
                state = 0
        elif state == 2:
            # Back at the depth recorded after LOAD_ATTR: this call popped the
            # loaded attribute (plus its arguments) and pushed the result.
            if sp == ssp and ins.opcode in (CALL_FUNCTION, CALL_FUNCTION_KW, CALL_FUNCTION_EX):
                state = 4
        elif state == 3:
            if sp == ssp and ins.opcode == CALL_METHOD:
                state = 4
        elif state in (4, 7):
            # The previous instruction was the matched call: await its result
            # before the current instruction runs.
            if none_const is None:
                try:
                    none_const = consts.index(None)
                except ValueError:
                    none_const = len(consts)
                    consts = consts + (None,)
            append_ins(result, GET_AWAITABLE)
            append_ins(result, LOAD_CONST, none_const)
            append_ins(result, YIELD_FROM)
            offset_addr_map.append(start)
            offset_map.append(len(result))
            if state == 4 and ins.opcode == STORE_FAST:
                resp_store.add(ins.arg)
            state = 0
            rewritten = True
        elif state == 5:
            # NOTE(review): ``resp.json()`` compiles to LOAD_METHOD, so only a
            # bare ``resp.json`` attribute load matches here -- confirm.
            if ins.opcode == LOAD_ATTR and ins.argval in ("body", "text", "json"):
                if aread_name is None:
                    try:
                        aread_name = names.index("aread")
                    except ValueError:
                        aread_name = len(names)
                        names = names + ("aread",)
                # Stack: resp -> resp, resp -> ... -> resp, after awaiting
                # resp.aread() and discarding its result.  none_const is
                # already resolved: resp_store is only filled in state 4.
                append_ins(result, DUP_TOP)
                append_ins(result, LOAD_METHOD, aread_name)
                append_ins(result, CALL_METHOD, 0)
                append_ins(result, GET_AWAITABLE)
                append_ins(result, LOAD_CONST, none_const)
                append_ins(result, YIELD_FROM)
                append_ins(result, POP_TOP)
                offset_addr_map.append(start)
                offset_map.append(len(result))
                rewritten = True
            state = 0
        elif state == 6:
            if sp == ssp and ins.opcode in (CALL_FUNCTION, CALL_FUNCTION_KW, CALL_FUNCTION_EX):
                state = 7
        elif state == 8:
            if sp == ssp and ins.opcode == CALL_FUNCTION:
                state = 1

        addr = len(result)
        append_ins(result, ins.opcode, ins.arg)
        is_rel = ins.opcode in dis.hasjrel
        if is_rel or ins.opcode in dis.hasjabs:
            jumps.append((addr, ins, is_rel, len(result) - addr))

    if not rewritten:
        return code

    def relocate(orig_offset):
        # Shift by the total size of everything inserted at or before it.
        i = bisect.bisect_right(offset_addr_map, orig_offset) - 1
        if i < 0:
            return orig_offset
        return orig_offset + offset_map[i] - offset_addr_map[i]

    # Backpatching: re-encode every jump with its relocated argument.
    for addr, ins, is_rel, size in jumps:
        if is_rel:
            # Relative jumps count from the end of the jump instruction, which
            # in the rewritten code is simply addr + size.
            new_arg = relocate(ins.argval) - (addr + size)
        else:
            new_arg = relocate(ins.arg)
        encoded = render_ins(ins.opcode, new_arg)
        if len(encoded) != size:
            # A longer encoding would overwrite the following instruction.
            raise ValueError(
                "cannot relocate %s at offset %d in %s: argument %d needs a "
                "wider encoding" % (ins.opname, ins.offset, code.co_name, new_arg)
            )
        result[addr:addr + size] = encoded

    return code.replace(
        co_code=bytes(result),
        co_flags=code.co_flags | inspect.CO_COROUTINE,
        co_consts=consts,
        co_names=names,
        # Generous headroom for the extra values the inserted sequences push.
        co_stacksize=code.co_stacksize + 96,
    )
class CompatibleRequest(httpx.Request):
    """``httpx.Request`` that additionally exposes its payload as ``body``."""

    @property
    def body(self):
        """Return the request's ``content`` after calling ``read()`` on it."""
        # read() is invoked first so that ``content`` is populated.
        self.read()
        payload = self.content
        return payload
class CompatibleAsyncClient(httpx.AsyncClient):
    """``httpx.AsyncClient`` accepting a requests-style calling convention.

    Every verb helper is a coroutine that funnels into :meth:`request`.  That
    method discards per-call ``verify`` and ``proxies`` keyword arguments
    (presumably because httpx only takes them when the client is built) and
    switches the response's request object to ``CompatibleRequest`` so that
    ``response.request.body`` is available.
    """

    async def request(self, *args, **kwargs):
        for unsupported in ("verify", "proxies"):
            kwargs.pop(unsupported, None)
        response = await super().request(*args, **kwargs)
        # Swap the class in place to add the ``body`` property.
        response.request.__class__ = CompatibleRequest
        return response

    async def get(self, *args, **kwargs):
        return await self.request("GET", *args, **kwargs)

    async def options(self, *args, **kwargs):
        return await self.request("OPTIONS", *args, **kwargs)

    async def head(self, *args, **kwargs):
        # HEAD defaults to allow_redirects=False unless the caller sets it.
        kwargs.setdefault("allow_redirects", False)
        return await self.request("HEAD", *args, **kwargs)

    async def post(self, *args, **kwargs):
        return await self.request("POST", *args, **kwargs)

    async def put(self, *args, **kwargs):
        return await self.request("PUT", *args, **kwargs)

    async def patch(self, *args, **kwargs):
        return await self.request("PATCH", *args, **kwargs)

    async def delete(self, *args, **kwargs):
        return await self.request("DELETE", *args, **kwargs)
def instrument(type_):
    """Return an async variant of *type_* whose methods await their HTTP calls.

    A new class with the same name is created.  Its only base is
    ``CompatibleAsyncClient`` (the original bases are not kept).  Its namespace
    is a copy of ``type_.__dict__`` minus ``__module__``, ``__dict__`` and
    ``__weakref__``.  Each ``types.FunctionType`` attribute is rebuilt:

    * its code goes through ``rewrite_to_async_call`` for the HTTP verb names,
      so matching ``self.<verb>(...)`` calls are awaited and the method becomes
      a coroutine function when anything was rewritten;
    * its globals are a per-module copy cached in ``alternative_globals``, in
      which every binding to *type_* refers to the new class instead;
    * closure cells holding *type_* (e.g. the ``__class__`` cell behind
      zero-argument ``super()``) are replaced by cells holding the new class;
    * positional and keyword-only defaults, ``__qualname__``, ``__doc__``,
      ``__annotations__`` and the function ``__dict__`` are carried over.

    Other attributes (staticmethod/classmethod/property objects, plain values)
    are copied unchanged.

    NOTE: the cached globals are a snapshot taken the first time a module is
    seen; module-level names bound later are not visible to rebuilt methods.

    Returns:
        The new class.
    """
    new_type = type(
        type_.__name__,
        (CompatibleAsyncClient,),
        {
            name: value
            for name, value in type_.__dict__.items()
            if name not in ("__module__", "__dict__", "__weakref__")
        },
    )
    verbs = ("request", "get", "options", "head", "post", "put", "patch", "delete")

    def holds_original(cell):
        # Reading an empty cell raises ValueError; such a cell cannot hold type_.
        try:
            return cell.cell_contents is type_
        except ValueError:
            return False

    # Iterate over a snapshot because the loop reassigns class attributes.
    for attr, fn in list(new_type.__dict__.items()):
        if not isinstance(fn, types.FunctionType):
            continue

        closure = fn.__closure__
        if closure is not None:
            closure = tuple(
                types.CellType(new_type) if holds_original(cell) else cell
                for cell in closure
            )

        fn_globals = alternative_globals.get(fn.__module__)
        if fn_globals is None:
            fn_globals = alternative_globals[fn.__module__] = dict(fn.__globals__)
        # The dict is shared by every class instrumented from the same module,
        # so substitute this class even when the dict was created earlier.
        for name, value in list(fn_globals.items()):
            if value is type_:
                fn_globals[name] = new_type

        new_fn = types.FunctionType(
            rewrite_to_async_call(fn, verbs),
            fn_globals,
            fn.__name__,
            fn.__defaults__,
            closure,
        )
        # The constructor call above only passes positional defaults; carry
        # the rest over so keyword-only defaults and introspection still work.
        new_fn.__kwdefaults__ = (
            None if fn.__kwdefaults__ is None else dict(fn.__kwdefaults__)
        )
        new_fn.__qualname__ = fn.__qualname__
        new_fn.__doc__ = fn.__doc__
        new_fn.__annotations__ = dict(fn.__annotations__)
        new_fn.__dict__.update(fn.__dict__)
        setattr(new_type, attr, new_fn)
    return new_type
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment