Skip to content

Instantly share code, notes, and snippets.

@hashbrowncipher
Created October 22, 2018 20:30
Show Gist options
  • Save hashbrowncipher/6fbcb3b6b3bde71af08f301e2b1cf3ee to your computer and use it in GitHub Desktop.
Save hashbrowncipher/6fbcb3b6b3bde71af08f301e2b1cf3ee to your computer and use it in GitHub Desktop.
Disassembles simple Python lambdas
import dis
from itertools import chain
from types import CodeType
def _get_func_args(stack, count):
return reversed([stack.pop() for i in range(count)])
def visit_func(stack, isn):
    """Collapse a callable and its positional arguments into a call string.

    ``isn.arg`` is used as the number of argument entries sitting above the
    callable on ``stack``.  Those entries and the callable are popped and a
    single ``"fn(arg1, arg2, ...)"`` string is pushed in their place.
    """
    rendered_args = ", ".join(_get_func_args(stack, isn.arg))
    callee = stack.pop()
    stack.append(f"{callee}({rendered_args})")
def visit_load_global(stack, isn):
    """Push ``isn.argval`` onto ``stack`` unchanged."""
    name = isn.argval
    stack.append(name)
def visit_load_const(stack, isn):
    """Push the constant carried by ``isn`` onto ``stack``.

    Tuples and code objects are pushed as-is so that later visitors can
    inspect them; every other constant is pushed as its ``repr`` string.
    """
    value = isn.argval
    keep_raw = isinstance(value, (tuple, CodeType))
    stack.append(value if keep_raw else repr(value))
def visit_func_kw(stack, isn):
    """Collapse a call with keyword arguments into a call string.

    The top of ``stack`` is taken as the sequence of keyword names; beneath
    it sit the keyword values, then the positional arguments, then the
    callable.  ``isn.arg`` is used as the total argument count (positional
    plus keyword).  Everything is popped and replaced by one
    ``"fn(pos, ..., name=value, ...)"`` string.
    """
    names = stack.pop()

    # Values come off the stack last-first, so walk the names backwards to
    # pair each name with its own value.
    value_by_name = {}
    for name in reversed(names):
        value_by_name[name] = stack.pop()

    # The dict was filled back-to-front; flip the rendered pairs again to
    # restore the order in which the keywords were written.
    pairs = [f"{name}={value}" for name, value in value_by_name.items()]
    positional = _get_func_args(stack, isn.arg - len(pairs))
    joined = ", ".join(chain(positional, pairs[::-1]))

    callee = stack.pop()
    stack.append(f"{callee}({joined})")
def visit_func_ex(stack, isn):
    """Collapse a call that uses ``*``/``**`` unpacking into a call string.

    Stack layout consumed (top first): the ``**`` operand (only when the low
    bit of ``isn.arg`` is set), the positional operand, then the callable.
    They are replaced by a single ``"fn(...)"`` string.

    The positional operand is interpreted as follows:

    * a plain string -- a single expression being unpacked, rendered as
      ``*expr`` (e.g. ``foo(*c)``);
    * an empty tuple -- no positional arguments at all (e.g. ``foo(**d)``);
    * a non-empty tuple -- the first element holds the ordinary positional
      arguments (a tuple, or a single string) and every following element
      is rendered with a leading ``*``.

    NOTE(review): a flat tuple of several strings is still read as "first
    element positional, the rest starred", exactly as before; the stack does
    not record which instruction built the tuple, so that case cannot be
    disambiguated here.
    """
    if isn.arg % 2 == 1:
        mapargs = stack.pop()
        if not isinstance(mapargs, tuple):
            mapargs = (mapargs,)
        mapargs = tuple("**" + arg for arg in mapargs)
    else:
        mapargs = ()

    args = stack.pop()
    if isinstance(args, str):
        # The operand is itself the iterable being unpacked.  Indexing or
        # slicing it would split the expression text into characters.
        posargs = ()
        listargs = ("*" + args,)
    elif not args:
        # Nothing positional; avoid indexing into an empty tuple.
        posargs = ()
        listargs = ()
    else:
        posargs = args[0]
        if not isinstance(posargs, tuple):
            posargs = (posargs,)
        listargs = tuple("*" + arg for arg in args[1:])

    rendered = ", ".join(posargs + listargs + mapargs)
    fn = stack.pop()
    stack.append("{}({})".format(fn, rendered))
def visit_return(stack, isn):
    """Pop the top entry of ``stack`` and return it to the caller."""
    result = stack.pop()
    return result
def visit_compare_op(stack, isn):
    """Replace the two topmost stack entries with ``"left op right"``.

    The operator text is taken from ``isn.argval``, which ``dis`` has already
    resolved to the comparison string.  Indexing ``dis.cmp_op`` with the raw
    ``isn.arg`` is only correct when the operand is the plain table index,
    which is not the case on interpreters that pack extra bits into it.
    """
    op = isn.argval
    right = stack.pop()
    left = stack.pop()
    stack.append("{} {} {}".format(left, op, right))
def visit_build_tuple(stack, isn):
    """Gather the top ``isn.arg`` stack entries into one tuple entry.

    The entries are popped and the resulting tuple, in push order, is pushed
    back as a single item.
    """
    items = tuple(_get_func_args(stack, isn.arg))
    stack.append(items)
def visit_load_fast(stack, isn):
    """Push ``isn.argval`` onto ``stack`` unchanged."""
    name = isn.argval
    stack.append(name)
def visit_load_method(stack, isn):
    """Replace the top of ``stack`` with ``"<top>.<isn.argval>"``."""
    owner = stack.pop()
    stack.append(f"{owner}.{isn.argval}")
def visit_unpack_sequence(stack, isn):
    """Print the current stack and instruction; the stack is left untouched.

    This handler performs no rendering -- it only dumps its inputs to
    standard output.
    """
    snapshot = (stack, isn)
    print(snapshot)
def visit_make_function(stack, isn):
    """Replace a function-construction sequence with its rendered lambda.

    Stack layout consumed (top first): the qualified name, the code object,
    then one optional operand for each flag bit set in ``isn.arg`` --
    ``0x08`` closure cells, ``0x04`` annotations, ``0x02`` keyword-only
    defaults, ``0x01`` positional defaults.  The code object is rendered via
    ``unlambda`` and the resulting string is pushed.

    All optional operands are discarded.  Leaving any of them behind would
    make the enclosing call pick them up as if they were the callable or an
    argument.  Default values are therefore not reproduced in the rendered
    lambda.
    """
    stack.pop()  # qualified name; not needed for rendering
    code = stack.pop()
    for flag in (0x08, 0x04, 0x02, 0x01):
        if isn.arg & flag:
            stack.pop()
    stack.append(unlambda(code))
# Dispatch table mapping an opcode name to the handler that applies its
# effect to the symbolic stack.  Several opcodes share a handler.
visitors = {
    "COMPARE_OP": visit_compare_op,
    "LOAD_GLOBAL": visit_load_global,
    "LOAD_CONST": visit_load_const,
    "LOAD_FAST": visit_load_fast,
    "LOAD_DEREF": visit_load_fast,
    "LOAD_CLOSURE": visit_load_fast,
    "LOAD_METHOD": visit_load_method,
    "BUILD_LIST": visit_build_tuple,
    "BUILD_MAP_UNPACK_WITH_CALL": visit_build_tuple,
    "BUILD_TUPLE": visit_build_tuple,
    "BUILD_TUPLE_UNPACK_WITH_CALL": visit_build_tuple,
    "CALL_METHOD": visit_func,
    "CALL_FUNCTION": visit_func,
    "CALL_FUNCTION_KW": visit_func_kw,
    "CALL_FUNCTION_EX": visit_func_ex,
    "MAKE_FUNCTION": visit_make_function,
    "RETURN_VALUE": visit_return,
    "UNPACK_SEQUENCE": visit_unpack_sequence,
}
def unlambda(code):
    """Reconstruct a ``lambda`` source string from a function or code object.

    The bytecode is walked instruction by instruction; each opcode's handler
    from ``visitors`` updates a symbolic stack of rendered sub-expressions.
    The first handler that returns a value ends the walk, and that value
    becomes the lambda body.

    Args:
        code: a function (its ``__code__`` is used) or a code object.

    Returns:
        A string of the form ``"lambda <params>: <body>"``.

    Raises:
        KeyError: if an opcode has no entry in ``visitors``; the current
            stack and the offending instruction are printed first.
        ValueError: if the instructions run out before any handler returns
            a value.
    """
    # Accept either a function or a bare code object.
    code = getattr(code, "__code__", code)
    isns = dis.Bytecode(code)

    params = ", ".join(code.co_varnames[: code.co_argcount])
    if params:
        params = " " + params

    stack = []
    for isn in isns:
        try:
            visitor = visitors[isn.opname]
        except KeyError:
            print(stack)
            print(isn)
            raise
        ret = visitor(stack, isn)
        # Compare against None rather than truthiness: a falsy body such as
        # the empty tuple must still terminate the walk.
        if ret is not None:
            return "lambda{}: {}".format(params, ret)

    raise ValueError("no return value found in bytecode")
def assertEquals(a, b):
    """Print ``"a != b"`` when the values differ, otherwise ``"a == b"``.

    Nothing is raised on a mismatch; the outcome is only reported on
    standard output.
    """
    template = "{} != {}" if a != b else "{} == {}"
    print(template.format(a, b))
def test():
    """Run ``unlambda`` over sample lambdas and print each comparison.

    Every case renders a lambda and hands the result, together with the
    expected source text, to ``assertEquals``, which prints whether the two
    match.  The lambdas are never called, so the free names they mention
    (``foo``, ``baz``, ``a`` ...) do not need to exist.
    """
    d = unlambda
    assertEquals(
        d(lambda: foo("bar", "baz", 5)), "lambda: foo('bar', 'baz', 5)"
    )
    assertEquals(d(lambda: foo(var, "baz", 5)), "lambda: foo(var, 'baz', 5)")
    assertEquals(
        d(lambda: foo(baz, "bar", baz=4, quux=5)),
        "lambda: foo(baz, 'bar', baz=4, quux=5)",
    )
    assertEquals(d(lambda: foo(a, *c)), "lambda: foo(a, *c)")
    assertEquals(d(lambda: foo(a, *c, **d)), "lambda: foo(a, *c, **d)")
    assertEquals(d(lambda: foo(a, **d, **e)), "lambda: foo(a, **d, **e)")
    assertEquals(
        d(lambda: foo(a, 6, 7, *b, *c)), "lambda: foo(a, 6, 7, *b, *c)"
    )
    assertEquals(d(lambda: foo(a, 6, 7, *b)), "lambda: foo(a, 6, 7, *b)")
    assertEquals(d(lambda: 5), "lambda: 5")
    assertEquals(d(lambda: foo()), "lambda: foo()")
    assertEquals(d(lambda: foo(a=5, b=6, c=7)), "lambda: foo(a=5, b=6, c=7)")
    # A nested call used as a keyword value; the expected text was previously
    # left empty, so this case could never report a match.
    assertEquals(
        d(lambda: foo(bar=baz(5, 8, 8), b=6, c=7)),
        "lambda: foo(bar=baz(5, 8, 8), b=6, c=7)",
    )
    assertEquals(
        d(lambda: foo(bar=lambda f: baz(f, 8, 8), b=6, c=7)),
        "lambda: foo(bar=lambda f: baz(f, 8, 8), b=6, c=7)",
    )
    assertEquals(d(lambda x, y: x != y), "lambda x, y: x != y")
    assertEquals(d(lambda x, y: 4), "lambda x, y: 4")
    assertEquals(d(lambda x, y: "foo"), "lambda x, y: 'foo'")
    assertEquals(d(lambda x, y: range(20)), "lambda x, y: range(20)")
# Run the sample comparisons only when executed as a script, not on import.
if __name__ == "__main__":
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment