Skip to content

Instantly share code, notes, and snippets.

@skrungly
Created July 16, 2018 02:45
Show Gist options
  • Save skrungly/5ef65ab3fb414df3857074872837c953 to your computer and use it in GitHub Desktop.
Save skrungly/5ef65ab3fb414df3857074872837c953 to your computer and use it in GitHub Desktop.
C++ style I/O implementation in Python. Oh dear.
import dis
import functools
from types import CodeType
def _patch_code(code: CodeType, **kwargs):
    """Create a new CodeType object with modified attributes.

    Args:
        code: The code object to copy. It is never modified.
        **kwargs: Replacement values keyed by attribute name, for
            example ``co_code=...`` or ``co_names=...``.

    Returns:
        A new ``CodeType`` object that equals ``code`` except for the
        attributes named in ``kwargs``.
    """
    # The positional ``CodeType`` constructor changed its signature in
    # Python 3.8, so use the copy-with-changes API whenever it exists
    # instead of hard-coding an argument list.
    if hasattr(code, "replace"):
        return code.replace(**kwargs)

    # Legacy interpreters: fall back to the old positional constructor.
    # The two trailing fields keep closure information intact.
    legacy_fields = (
        "co_argcount",
        "co_kwonlyargcount",
        "co_nlocals",
        "co_stacksize",
        "co_flags",
        "co_code",
        "co_consts",
        "co_names",
        "co_varnames",
        "co_filename",
        "co_name",
        "co_firstlineno",
        "co_lnotab",
        "co_freevars",
        "co_cellvars",
    )
    return CodeType(
        *(kwargs.get(name, getattr(code, name)) for name in legacy_fields)
    )
def _assemble(*instructions):
    """
    Assemble CPython bytecode into a byte-string given
    any amount of two-tuples containing: (op_name, arg_value)

    Each instruction is emitted as exactly two bytes: the opcode looked
    up in ``dis.opmap`` followed by its argument. An argument of ``None``
    is written as 0.

    Raises:
        KeyError: If ``op_name`` is not present in ``dis.opmap``.
        ValueError: If an argument does not fit in a single byte.
    """
    assembled = bytearray()

    for op_name, arg in instructions:
        # Some instructions don't take arguments, so just null it.
        if arg is None:
            arg = 0

        # Only one argument byte is written per instruction, so larger
        # values cannot be represented by this assembler.
        if not 0 <= arg <= 0xFF:
            raise ValueError(
                f"argument {arg!r} for {op_name} does not fit in one byte"
            )

        # Write the raw integers directly; going through `str` would
        # need an encoding step that mangles values above 127.
        assembled.append(dis.opmap[op_name])
        assembled.append(arg)

    return bytes(assembled)
def _safe_search(tup, *items):
    """Search a sequence for items, appending any that are missing.

    Args:
        tup: The sequence to search (a tuple or a list). It is not
            modified.
        *items: The values to locate.

    Returns:
        A flat tuple holding every element of the (possibly extended)
        sequence, followed by one index per requested item, in the order
        the items were given. Callers unpack it as
        ``*extended, idx_a, idx_b = _safe_search(seq, a, b)``.
    """
    pool = list(tup)
    indices = []

    for item in items:
        for position, existing in enumerate(pool):
            # Compare the type as well as the value: `False == 0` and
            # `True == 1`, so a plain `in` / `index` lookup would hand
            # back the slot of an unrelated numeric entry.
            if type(existing) is type(item) and existing == item:
                break
        else:
            # Not present yet, so it goes at the end.
            position = len(pool)
            pool.append(item)

        indices.append(position)

    return (*pool, *indices)
def _exec_code(code: CodeType, *args, **kwargs):
    """Run a CodeType object as a function call and return its result.

    Args:
        code: The code object to execute.
        *args: Positional arguments forwarded to the call.
        **kwargs: Keyword arguments forwarded to the call.

    Returns:
        Whatever the executed code returns.
    """
    # A throwaway function serves as the carrier: once its `__code__`
    # is swapped for the given object, calling it runs that code through
    # the normal function-call machinery.
    def _carrier(*_args, **_kwargs):
        return None

    _carrier.__code__ = code
    return _carrier(*args, **kwargs)
def cpp_stdio(func):
    """
    Modifies the bytecode of a function so that C++ style `cin` and
    `cout` calls are possible in place of `input` and `print`, then
    executes it.

    Example:

        >>> @cpp_stdio
        ... def hello_name():
        ...     cout << "Please enter your name: ";
        ...     cin >> name;
        ...
        ...     cout << "Hello, " << name << "!" << endl;
        ...

    Args:
        func: The function whose `cin >> name` and `cout << ...`
            statements should be rewritten.

    Returns:
        A wrapper that patches `func`'s code object on the first call,
        runs the patched code with the given arguments and returns its
        result.

    Raises:
        ValueError: (on first call) if a `cin` or `cout` use does not
            have the statement shape this patcher understands.

    Limitations:
        * Every operand of `cin`/`cout` must compile to a single
          two-byte instruction (a plain name or constant).
        * NOTE(review): the rewrites change the bytecode length but do
          not adjust jump targets or the line table, so functions with
          branches or loops around patched statements are presumably
          unsupported - confirm before relying on it.
        * NOTE(review): the opcodes assembled here (`BINARY_LSHIFT`,
          `CALL_FUNCTION_KW`, ...) must exist in `dis.opmap`; confirm
          which interpreter versions that holds for.

    Note: I will hurt you if you unironically use this. :D
    """

    def _find_aligned(code_bytes, pattern, start):
        """Find `pattern` at an instruction boundary (an even offset)."""
        # Instructions are two bytes wide, so a hit at an odd offset
        # straddles two instructions and is not a real match.
        pos = code_bytes.find(pattern, start)
        while pos >= 0 and pos % 2:
            pos = code_bytes.find(pattern, pos + 1)
        return pos

    def _patch_cin(code: CodeType):
        """Patch instances of C++ style `cin` in the function."""

        # The attributes themselves are read-only so we just
        # have to patch these copies onto the original object.
        varnames = code.co_varnames
        new_code = code.co_code

        # We need cin_num and input_num for instruction args.
        *names, cin_num, input_num = _safe_search(
            code.co_names,
            "cin", "input"
        )

        # This will be used to find where `cin` is called.
        cin_start = _assemble(
            ("LOAD_GLOBAL", cin_num)
        )

        load_global = dis.opmap["LOAD_GLOBAL"]
        load_fast = dis.opmap["LOAD_FAST"]
        rshift = dis.opmap["BINARY_RSHIFT"]
        pop_top = dis.opmap["POP_TOP"]

        # This list will contain all of the implicitly-declared
        # variables. This means we can declare them locally from
        # the `cin` call alone, which is the 'pythonic' twist. :P
        imp_decl = []

        start_pos = 0
        while True:
            # Attempt to find another `cin` in the function,
            # and stop looking if one couldn't be found.
            start_pos = _find_aligned(new_code, cin_start, start_pos)
            if start_pos < 0:
                break

            # `cin` calls are 4 instructions x 2 bytes = 8 bytes:
            # load `cin`, load the target, `>>`, discard the result.
            end_pos = start_pos + 8
            cin_call = new_code[start_pos:end_pos]

            # Refuse anything that is not a plain `cin >> name`
            # statement rather than emitting corrupt bytecode.
            well_formed = (
                len(cin_call) == 8
                and cin_call[2] in (load_global, load_fast)
                and cin_call[4] == rshift
                and cin_call[6] == pop_top
            )
            if not well_formed:
                raise ValueError(
                    "`cin` must be used as a statement of the form "
                    "`cin >> name`"
                )

            # The fourth byte (index 3) is the argument of the load
            # instruction for the variable that is being changed.
            store_num = cin_call[3]

            # Define the variable in the function's local
            # scope if it hasn't yet been declared locally.
            if cin_call[2] == load_global:
                # We'll need to keep track of this to replace it
                # if the variable is used later in the function.
                prev_store_num = store_num

                # Add the variable to the local scope declarations
                *varnames, store_num = _safe_search(
                    varnames,
                    names[store_num]
                )

                imp_decl.append(
                    (prev_store_num, store_num)
                )

            # This is the `var = input()` bytecode. It directly
            # replaces the `cin` calls, so that `cin` doesn't even
            # need to be defined at all for this to work. Snazzy!
            changed_code = _assemble(
                ("LOAD_GLOBAL", input_num),
                ("CALL_FUNCTION", 0),
                ("STORE_FAST", store_num)
            )

            new_code = new_code[:start_pos] + changed_code + new_code[end_pos:]

        # Stop the interpreter from treating implicitly-declared
        # local variables as potentially undefined global variables.
        patched = bytearray(new_code)
        for prev_store, new_store in imp_decl:
            # The global-loading bytecode
            wrong_decl = _assemble(
                ("LOAD_GLOBAL", prev_store),
            )

            # The (correct) local-loading bytecode
            new_decl = _assemble(
                ("LOAD_FAST", new_store),
            )

            # Walk instruction by instruction so that only whole
            # instructions are ever rewritten.
            for offset in range(0, len(patched) - 1, 2):
                if patched[offset:offset + 2] == wrong_decl:
                    patched[offset:offset + 2] = new_decl

        return _patch_code(
            code,
            co_code=bytes(patched),
            co_names=tuple(names),
            co_varnames=tuple(varnames),
            # Derive the count from the names themselves so that a
            # variable read by `cin` more than once is counted once.
            co_nlocals=len(varnames)
        )

    def _patch_cout(code: CodeType):
        """
        Patch instances of C++ style `cout` in the function.

        Note: I'm very aware that one can simply make a custom class
        which overrides the __lshift__ magic method and does a bunch
        of fancy stuff, and that's what I had originally.

        This is more fun though :D
        """

        # We need to patch this over the read-only attribute of `code`
        new_code = code.co_code

        # Find the const arg numbers for the two bools,
        # newline and empty strings, and the `print` kwargs.
        *consts, false_num, true_num, newln_num, empty_str, print_kws = _safe_search(
            code.co_consts,
            False, True, "\n", "", ("sep", "end", "flush")
        )

        # Find the global arg numbers of `cout`, `endl` and `print`
        *names, cout_num, endl_num, print_num = _safe_search(
            code.co_names,
            "cout", "endl", "print"
        )

        # `cout` calls will always begin with this instruction
        cout_start = _assemble(
            ("LOAD_GLOBAL", cout_num)
        )

        # Each value is separated by a `<<`, so we'll use this.
        separator = _assemble(
            ("BINARY_LSHIFT", None)
        )

        # `cout` calls always end with this instruction
        cout_end = _assemble(
            ("POP_TOP", None)
        )

        # And this is what `endl` will appear as.
        endl_value = _assemble(
            ("LOAD_GLOBAL", endl_num)
        )

        # Deepest stack any generated `print` call needs.
        extra_stack = 0

        start_pos = 0
        while True:
            # Find the boundaries and the bytecode of the `cout` call
            start_pos = _find_aligned(new_code, cout_start, start_pos)
            if start_pos < 0:
                break

            end_pos = _find_aligned(new_code, cout_end, start_pos)
            if end_pos < 0:
                # Without the trailing POP_TOP this is not a statement,
                # and slicing with -1 would silently corrupt the code.
                raise ValueError(
                    "`cout` must be used as a statement of the form "
                    "`cout << value << ...`"
                )

            cout_call = new_code[start_pos:end_pos]

            # Cut off the `cout` part, and remove the `<<` separators.
            out_values = cout_call[2:].replace(separator, b"")
            ends_with_endl = out_values.endswith(endl_value)

            # Push the print function onto the stack
            changed_code = _assemble(
                ("LOAD_GLOBAL", print_num)
            )

            # Add each value to be printed from the cout call
            changed_code += out_values

            if ends_with_endl:
                changed_code = changed_code[:-2]

            # `cout` typically doesn't have a separator.
            changed_code += _assemble(
                ("LOAD_CONST", empty_str)  # sep=''
            )

            # Each argument occupies 2 bytes, therefore we can
            # just divide the size by 2 to get the amount of them.
            args_length = len(out_values) // 2

            if ends_with_endl:
                # `endl` follows `print` defaults, but all print
                # calls will have the 3 kwargs for the sake
                # of consistency (and I'm lazy!)
                changed_code += _assemble(
                    ("LOAD_CONST", newln_num),  # end='\n'
                    ("LOAD_CONST", true_num),  # flush=True
                )

                # Account for the stripped off `endl` value
                args_length -= 1

            else:
                changed_code += _assemble(
                    ("LOAD_CONST", empty_str),  # end=''
                    ("LOAD_CONST", false_num),  # flush=False
                )

            # Loads the kwarg names and call the function
            changed_code += _assemble(
                ("LOAD_CONST", print_kws),
                ("CALL_FUNCTION_KW", 3 + args_length)
            )

            # The call holds `print`, every positional value, the three
            # keyword values and the keyword-name tuple at once.
            extra_stack = max(extra_stack, args_length + 5)

            new_code = new_code[:start_pos] + changed_code + new_code[end_pos:]

        return _patch_code(
            code,
            co_code=new_code,
            co_consts=tuple(consts),
            co_names=tuple(names),
            # The `<<` chain only ever held two values, so reserve room
            # for the wider `print` call on top of the original size.
            co_stacksize=code.co_stacksize + extra_stack
        )

    # The patched code object, built on the first call and then reused.
    patched_code = None

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal patched_code

        if patched_code is None:
            code = func.__code__

            # We only need to patch `cin` and `cout` if they're used.
            if "cin" in code.co_names:
                code = _patch_cin(code)

            if "cout" in code.co_names:
                code = _patch_cout(code)

            patched_code = code

        # Finally, execute the patched code as if nothing has happened :P
        return _exec_code(patched_code, *args, **kwargs)

    return wrapper
if __name__ == '__main__':
    # Demo: prompt for two numbers with `cout`/`cin` and print their sum.
    # `cout`, `cin`, `endl`, `x` and `y` are never defined anywhere; the
    # decorator rewrites these statements before the body is executed.
    @cpp_stdio
    def addition():
        cout << "Enter a number: ";
        cin >> x;
        cout << "And another: ";
        cin >> y;
        # The values read by `cin` are converted with `int` before adding.
        result = int(x) + int(y)
        cout << x << " + " << y << " = " << result << endl;
    addition()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment