Skip to content

Instantly share code, notes, and snippets.

@yonatanzunger
Created August 11, 2022 00:05
Show Gist options
  • Save yonatanzunger/9a995e850d1035e56afcc8cae5027798 to your computer and use it in GitHub Desktop.
Advanced Python: High Performance with Codegen (Example)
import dis
import io
import warnings
from types import CodeType, FunctionType
from typing import Any, Callable, List, NamedTuple, Tuple
# Opcode numbers of the CPython instructions a hand-assembled comparator needs.
# They are looked up with dis.opmap.get() so that an opcode the running interpreter
# lacks yields None instead of making this whole module fail to import.
# For example, JUMP_IF_FALSE_OR_POP no longer exists in CPython 3.12+, and there
# dis.opname.index() would raise ValueError at import time.
_LOAD_FAST = dis.opmap.get('LOAD_FAST')
_LOAD_CONST = dis.opmap.get('LOAD_CONST')
_COMPARE_OP = dis.opmap.get('COMPARE_OP')
_JUMP_IF_FALSE_OR_POP = dis.opmap.get('JUMP_IF_FALSE_OR_POP')
_RETURN_VALUE = dis.opmap.get('RETURN_VALUE')
class Comparison(NamedTuple):
    """One test of the form ``args[variableIndex] <op> testValue``.

    The operator is given as an index into ``dis.cmp_op`` rather than as operator text,
    e.g. ``dis.cmp_op.index('==')``, which is 2. Which indices exist depends on the
    running Python version's ``dis.cmp_op``.
    """

    # Position (0-based) of the argument on the left-hand side of the comparison.
    variableIndex: int
    # Index into dis.cmp_op selecting the comparison operator.
    comparisonOp: int
    # Right-hand operand of the comparison.
    testValue: Any
# Operator spellings that may sit between an argument and a constant in generated source.
# On some Python versions dis.cmp_op also contains entries that are not source operators
# (before 3.9: 'exception match' and the sentinel 'BAD'); those are rejected.
_SOURCE_COMPARISON_OPERATORS = frozenset(
    ('<', '<=', '==', '!=', '>', '>=', 'in', 'not in', 'is', 'is not')
)


def makeFastComparator(
    numArguments: int,
    comparisons: 'List[Comparison]',
    filename: str = '__generated-code__',
    funcName: str = '_compiled_CompareKey',
) -> Callable:
    """Generate a function which evaluates a fixed conjunction of comparisons very efficiently.

    The returned function takes exactly ``numArguments`` positional arguments and is
    equivalent to

        def funcName(arg0, arg1, ..., argN_1):
            return (arg[i1] OP1 test1) and (arg[i2] OP2 test2) and ...

    Each ``(i, OP, test)`` comes from one entry of ``comparisons``:

    - ``i`` is the argument position.
    - ``OP`` is ``dis.cmp_op[comparisonOp]``.
    - ``test`` is the test value.

    For example,

        makeFastComparator(2, [Comparison(0, dis.cmp_op.index('=='), b'1234'),
                               Comparison(1, dis.cmp_op.index('<='), b'3333')])

    returns a function equivalent to

        def myComparator(arg0, arg1):
            return arg0 == b'1234' and arg1 <= b'3333'

    With no comparisons, the result simply returns True. Like any ``and`` chain, the
    function returns the first falsy comparison result, or else the last one.

    How it works:

    - The expression is written out as Python source and compiled by the running
      interpreter, with a unique placeholder string literal standing in for each test value.
    - The placeholders are then replaced by the real test values in the compiled code's
      constant table. Each test value is therefore loaded as a constant (the very object
      passed in, not a copy), with no per-call global or attribute lookups.
    - Because the interpreter's own compiler emits the bytecode, this works on CPython 3.8
      and later (it needs ``CodeType.replace``). Hand-assembled bytecode would break
      whenever the bytecode format changes between releases.

    The result can therefore be used in an innermost loop.

    DEBUGGING NOTES:
    - The function's ``__name__`` (and ``__qualname__``) and its code object's ``co_name``
      are set to ``funcName``. Its ``co_filename`` and ``__module__`` are set to
      ``filename``. Profilers report these names, so give different comparators different
      names if you want them profiled as separate items.
    - In the generated source, line 1 holds the ``def`` and comparison number k (counting
      from 1) is on line k + 1. A traceback's line number therefore identifies the
      comparison that failed, e.g. because a bogus argument was passed in.

    Args:
      numArguments: Number of positional parameters of the generated function; they are
        named ``arg0`` ... ``arg{numArguments - 1}``.
      comparisons: ``(variableIndex, comparisonOp, testValue)`` triples (e.g. Comparison
        instances), evaluated left to right with short-circuiting.
      filename: Value for the generated code's ``co_filename`` and the function's
        ``__module__``.
      funcName: Name given to the generated function and its code object.

    Returns:
      The generated function.

    Raises:
      ValueError: If a comparison's index is outside ``range(numArguments)``, or its
        operation index does not name a binary comparison operator in ``dis.cmp_op``.
    """
    argNames = [f'arg{i}' for i in range(numArguments)]
    # Maps each placeholder literal used in the generated source to the real test value.
    placeholders = {}
    clauses = []
    for opNum, (index, cmpOp, testValue) in enumerate(comparisons):
        if index < 0 or index >= numArguments:
            raise ValueError(
                f'Got bad index {index} for comparison request with only '
                f'{numArguments} arguments!'
            )
        validOperator = (
            0 <= cmpOp < len(dis.cmp_op)
            and dis.cmp_op[cmpOp] in _SOURCE_COMPARISON_OPERATORS
        )
        if not validOperator:
            raise ValueError(f'Got bad comparison operation {cmpOp}')
        placeholder = f'__fast_comparator_constant_{opNum}__'
        placeholders[placeholder] = testValue
        clauses.append(f'    ({argNames[index]} {dis.cmp_op[cmpOp]} {placeholder!r})')

    # Line 1 is the def line; comparison k (counting from 1) lands on line k + 1.
    source = '\n'.join(
        [
            f'def _fast_comparator({", ".join(argNames)}): return (',
            ' and\n'.join(clauses) if clauses else '    True',
            ')',
            '',
        ]
    )
    with warnings.catch_warnings():
        # 'x is <literal>' draws a SyntaxWarning on Python 3.8+. Here the literal is only a
        # placeholder that is replaced by the caller's object below, so the warning is moot.
        warnings.simplefilter('ignore', SyntaxWarning)
        moduleCode = compile(source, filename, 'exec')
    # The module's only nested code object is the function defined by the source above.
    templateCode = next(c for c in moduleCode.co_consts if isinstance(c, CodeType))

    compiledStrings = {c for c in templateCode.co_consts if isinstance(c, str)}
    unresolved = placeholders.keys() - compiledStrings
    if unresolved:
        # Should not happen; guards against a compiler that folds or rewrites the literals,
        # which would otherwise silently compare against the placeholder text.
        raise RuntimeError(f'Placeholders missing from compiled constants: {sorted(unresolved)}')
    newConsts = tuple(
        placeholders[c] if isinstance(c, str) and c in placeholders else c
        for c in templateCode.co_consts
    )

    changes = {'co_consts': newConsts, 'co_name': funcName}
    if hasattr(templateCode, 'co_qualname'):  # Python 3.11+
        changes['co_qualname'] = funcName
    code = templateCode.replace(**changes)
    # FunctionType takes __module__ from the globals' '__name__'. The generated code itself
    # reads no globals.
    return FunctionType(code, {'__name__': filename})
import dis
import unittest
from humu.storage.impl.fast_comparator import Comparison, makeFastComparator
class testFastComparator(unittest.TestCase):
    """Tests for makeFastComparator."""

    def testSimpleComparator(self):
        # This should make a function
        #   def compareTwo(x1, x2):
        #       return x2 == b'1234' and x1 > b'0000'
        compareTwo = makeFastComparator(
            2,
            [
                Comparison(1, dis.cmp_op.index('=='), b'1234'),
                Comparison(0, dis.cmp_op.index('>'), b'0000'),
            ],
        )
        self.assertTrue(compareTwo(b'5555', b'1234'))
        # Fail the first test
        self.assertFalse(compareTwo(b'5555', b'1000'))
        # Fail the second test
        self.assertFalse(compareTwo(b'****', b'1234'))

    def testCheckNothing(self):
        # With no comparisons the generated function accepts anything and returns True.
        noTestsCompare = makeFastComparator(2, [])
        self.assertTrue(noTestsCompare('asdfa', 'dfhgdf'))

    def testCompareThree(self):
        compareThree = makeFastComparator(
            3,
            [
                Comparison(0, dis.cmp_op.index('=='), b'1234'),
                Comparison(1, dis.cmp_op.index('=='), b'2345'),
                Comparison(2, dis.cmp_op.index('=='), b'3456'),
            ],
        )
        self.assertTrue(compareThree(b'1234', b'2345', b'3456'))
        self.assertFalse(compareThree(b'0000', b'2345', b'3456'))
        self.assertFalse(compareThree(b'1234', b'0000', b'3456'))
        self.assertFalse(compareThree(b'1234', b'2345', b'0000'))

    def testHalfOpenRange(self):
        # Checks min <= x < max. '>=' is the last entry of dis.cmp_op on Python 3.9+, so
        # this also catches off-by-one errors in operator validation.
        inRange = makeFastComparator(
            1,
            [
                Comparison(0, dis.cmp_op.index('>='), b'1000'),
                Comparison(0, dis.cmp_op.index('<'), b'2000'),
            ],
        )
        self.assertTrue(inRange(b'1000'))
        self.assertTrue(inRange(b'1999'))
        self.assertFalse(inRange(b'0999'))
        self.assertFalse(inRange(b'2000'))

    def testRejectsBadRequests(self):
        # Argument index out of range, in either direction.
        with self.assertRaises(ValueError):
            makeFastComparator(1, [Comparison(1, dis.cmp_op.index('=='), b'1234')])
        with self.assertRaises(ValueError):
            makeFastComparator(1, [Comparison(-1, dis.cmp_op.index('=='), b'1234')])
        # Operation index that is not in dis.cmp_op at all.
        with self.assertRaises(ValueError):
            makeFastComparator(1, [Comparison(0, len(dis.cmp_op), b'1234')])

    # Python 3.9 moved 'in' / 'not in' out of COMPARE_OP, so dis.cmp_op no longer lists them.
    @unittest.skipUnless('in' in dis.cmp_op, "dis.cmp_op has no 'in' on this Python version")
    def testCompareList(self):
        compareThree = makeFastComparator(
            3,
            [
                Comparison(0, dis.cmp_op.index('in'), [b'1234', b'2345', b'3456']),
                Comparison(1, dis.cmp_op.index('in'), [b'1234']),
                Comparison(2, dis.cmp_op.index('in'), [b'2345', b'3456']),
            ],
        )
        self.assertTrue(compareThree(b'1234', b'1234', b'2345'))
        self.assertTrue(compareThree(b'2345', b'1234', b'3456'))
        self.assertFalse(compareThree(b'0000', b'1234', b'3456'))
        self.assertFalse(compareThree(b'1234', b'1234', b'1234'))
        self.assertFalse(compareThree(b'1234', b'1234', b'0000'))
if __name__ == '__main__':
    # Executed directly rather than imported: discover and run this module's test cases.
    unittest.main(module=__name__)
from dis import cmp_op
from typing import List, NamedTuple, Optional
from fast_comparator import Comparison
class KeyRange(NamedTuple):
    """A constraint on one encoded key: either an exact value or a half-open range.

    When ``exact_value`` is set it takes precedence and the bounds are ignored. Otherwise a
    key must satisfy ``min_value <= key < max_value``, where a ``None`` bound does not
    constrain anything.
    """

    exact_value: Optional[bytes]
    min_value: Optional[bytes]
    max_value: Optional[bytes]

    def matches(self, value: bytes) -> bool:
        """Return whether ``value`` satisfies this constraint."""
        if self.exact_value is not None:
            return value == self.exact_value
        if self.min_value is not None and not (value >= self.min_value):
            return False
        return self.max_value is None or value < self.max_value

    def asComparators(self, index: int, comparisons: List[Comparison]) -> None:
        """Append to ``comparisons`` the checks equivalent to ``matches`` for argument ``index``.

        An exact value yields one '==' check. Otherwise a '>=' check is appended for a lower
        bound and then a '<' check for an upper bound. Nothing is appended for a fully
        unbounded range.
        """
        if self.exact_value is not None:
            bounds = (('==', self.exact_value),)
        else:
            bounds = (('>=', self.min_value), ('<', self.max_value))
        for symbol, bound in bounds:
            if bound is not None:
                comparisons.append(Comparison(index, cmp_op.index(symbol), bound))
Then, in the query object's class:
# NOTE: excerpt from the query class's constructor, not runnable as written. The `...` in
# the parameter list and the `....` line stand for code elided from this example.
def __init__(self, ...) -> None:
    ....
    # Choose how encoded keys are matched against this query's key ranges. Either way,
    # self._encodedKeyMatcher ends up as a callable that takes the encoded keys as
    # positional arguments.
    if not USE_SLOW_COMPARE:
        # Precompile a matcher function which compares a List[bytes] against the specified key
        # ranges. We do this because matchesEncodedKeys() is often called in a tight inner
        # loop, and its performance therefore matters a lot.
        comparators = []
        # Key number `index` becomes argument number `index` of the generated matcher. Each key
        # range appends its own comparisons for that argument (none for an unbounded range).
        for index, keyInfo in enumerate(self._keys):
            keyInfo.keyRange.asComparators(index, comparators)
        # NOTE(review): the generated function requires exactly len(self._keys) positional
        # arguments, whereas _match accepts any number (zip stops at the shorter side).
        # Confirm that callers always pass one encoded key per entry of self._keys.
        self._encodedKeyMatcher = makeFastComparator(len(self._keys), comparators)
    else:
        # Use the plain-Python implementation instead of generated code.
        self._encodedKeyMatcher = self._match
def _match(self, *encodedKeys: bytes) -> bool:
    """The slow, default implementation: check each encoded key against its key range.

    Keys are paired positionally with ``self._keys``. Pairing stops when either side runs
    out, so surplus entries on either side are not checked. Returns True only if every
    checked pair matches, stopping at the first mismatch.
    """
    for info, candidate in zip(self._keys, encodedKeys):
        if not info.keyRange.matches(candidate):
            return False
    return True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment