A minimal version of Yannakakis's algorithm for mostly plain Python
#!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23{(.*)$|<details><summary>\1</summary>\n| -e s|^\x20{4}\x23\x23}$|\n</details>| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|
license, imports
# Yannakakis.py by Paul Khuong
#
# To the extent possible under law, the person who associated CC0 with
# Yannakakis.py has waived all copyright and related or neighboring rights
# to Yannakakis.py.
#
# You should have received a copy of the CC0 legalcode along with this
# work.  If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.

import collections.abc

Linear-time analytical queries in plain Python

I didn't know Mihalis Yannakakis is about to turn 70, but that's a fun coïncidence.

This short hack shows how Yannakakis's algorithm lets us implement linear-time (wrt the number of input data rows) analytical queries in regular programming languages, without offloading joins to a specialised database query language, thus avoiding the associated impedance mismatch. There are restrictions on the queries we can express -- Yannakakis's algorithm relies on a hypertree width of 1, and on having a hypertree decomposition as a witness -- but that's kind of reasonable: a fractional hypertree width > 1 would mean there are databases for which the intermediate results could be superlinearly larger than the input database due to the AGM bound. The hypertree decomposition witness isn't a burden either: structured programs naturally yield a hypertree decomposition, unlike more declarative logic programs that tend to hide the structure implicit in the programmer's thinking.

The key is to mark function arguments as either negative (true inputs) or positive (possible interesting input values derived from negative arguments). In this hack, closed over values are positive, and the only negative argument is the current data row.

We also assume that these functions are always used in a map/reduce pattern, and thus we only memoise the result of map_reduce(function, input), with a group-structured reduction: the reduce function must be associative and commutative, and there must be a zero (neutral) value.

With these constraints, we can express joins in natural Python without incurring the poor runtime scaling of the nested loops we actually wrote. This Python file describes the building blocks to handle aggregation queries like the following

>>> id_skus = [(1, 2), (2, 2), (1, 3)]
>>> sku_costs = [(1, 10), (2, 20), (3, 30)]
>>> def sum_odd_or_even_skus(mod_two):
...     @map_reduce.over(id_skus, Sum())
...     def count_if_mod_two(id_sku):
...         id, sku = id_sku
...         if id % 2 == mod_two:
...             @map_reduce.over(sku_costs, Min(0))
...             def min_cost(sku_cost):
...                 if sku_cost[0] == sku:
...                     return Min(sku_cost[1])
...             return Sum(min_cost)
...     return count_if_mod_two
...

with linear scaling in the length of id_skus and sku_costs, and caching for similar queries.

At a small scale, everything's fast.

>>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
20
0.0007169246673583984
>>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
50
0.0002627372741699219

As we increase the scale by 1000x for both input lists, the runtime scales (sub) linearly for the first query, and is unchanged for the second:

>>> id_skus = id_skus * 1000
>>> sku_costs = sku_costs * 1000
>>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
20000
0.09455370903015137
>>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
50000
0.00025773048400878906

This still pretty much holds up when we multiply by another factor of 100:

>>> id_skus = id_skus * 100
>>> sku_costs = sku_costs * 100
>>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
2000000
6.946590185165405
>>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
5000000
0.00025200843811035156

The magic behind the curtains is memoisation (unsurprisingly), but a special implementation that can share work for similar closures: the memoisation key consists of the function without closed over bindings and the call arguments, while the memoised value is a data structure mapping tuples of closed over values to the map_reduce output.

This concrete representation of the function as a branching program is the core of Yannakakis's algorithm: we'll iterate over each datum in the input, run the function on it with logical variables instead of the closed over values, and generate a mapping from closed over values to result for all non-zero results. We'll then merge the mappings for all input data together (there is no natural ordering here, hence the group structure).
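
As a minimal sketch of that idea (using Sum and enumerate_supporting_values, both defined further down), running a needle-counting function backwards over three rows yields one entry per row, keyed by the closed over needle; merging entries with identical keys then gives the branching program needle = 1 -> 1, needle = 2 -> 2.

>>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
>>> [(key, result.value) for key, result in enumerate_supporting_values(count_eql(0), [1, 2, 2])]
[((1,), 1), ((2,), 1), ((2,), 1)]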

The output data structure controls the join we can implement. We show how a simple stack of nested key-value mappings handles equijoins, but a k-d range tree would handle inequalities, i.e., "theta" joins (the rest of the machinery already works in terms of less than/greater than constraints).

As long as we explore a bounded number of paths for each datum and a bounded number of (function, input) cache keys, we'll spend a bounded amount of time on each input datum, and thus linear time total. The magic of Yannakakis's algorithm is that this works even when there are nested map_reduce calls, which would naïvely result in polynomial time (degree equal to the nesting depth).

Memoising through a Python function's closed over values

Even if functions were hashable and comparable for extensional equality, directly using closures as memoisation keys in calls like map_reduce(function, input) would result in superlinear runtime for nested map_reduce calls: each outer row constructs a fresh closure over its own join key, so a closure-keyed cache would never hit and the inner scan would run once per outer row.

This extract_function_state accepts a function (with or without closed over state), and returns four values:

  1. The underlying code object
  2. The tuple of closed over values (current value for mutable cells)
  3. A function to rebind the closure with new closed over values
  4. The names of the closed over bindings

The third return value, the rebind function, silently fails on complicated cases; this is a hack, after all. In short, it only handles closing over immutable atomic values like integers or strings, but not, e.g., functions (yet), or mutable bindings.

def extract_function_state(function):
    """Accepts a function object and returns information about it: a
    hash key for the object, a tuple of closed over values, a function
    to return a fresh closure with a different tuple of closed over
    values, and a closure of closed over names

    >>> def test(x): return lambda y: x == y
    >>> extract_function_state(test)[1]
    ()
    >>> extract_function_state(test)[3]
    ()
    >>> fun = test(4)
    >>> extract_function_state(fun)[1]
    (4,)
    >>> extract_function_state(fun)[3]
    ('x',)
    >>>
    >>> fun(4)
    True
    >>> fun(5)
    False
    >>> rebind = extract_function_state(fun)[2]
    >>> rebound_4, rebound_5 = rebind([4]), rebind([5])
    >>> rebound_4(4)
    True
    >>> rebound_4(5)
    False
    >>> rebound_5(4)
    False
    >>> rebound_5(5)
    True
    """
    code = function.__code__
    names = code.co_freevars

    if function.__closure__ is None:  # Toplevel function
        assert names == ()

        def rebind(values):
            if len(values) != 0:
                raise RuntimeError(
                    f"Values must be empty for toplevel function. values={values}"
                )
            return function

        return code, (), rebind, names

    closure = tuple(cell.cell_contents for cell in function.__closure__)
    assert len(names) == len(closure), (closure, names)

    # TODO: rebind recursively (functions are also cells)
    def rebind(values):
        if len(values) != len(names):
            raise RuntimeError(
                f"Values must match names. names={names} values={values}"
            )
        return function.__class__(
            code,
            function.__globals__,
            function.__name__,
            function.__defaults__,
            tuple(
                cell.__class__(value)
                for cell, value in zip(function.__closure__, values)
            ),
        )

    return code, closure, rebind, names

Logical variables for closed-over values

We wish to enumerate the support of a function call (parameterised over closed over values), and the associated result value. We'll do that by rebinding the closure to point at instances of OpaqueValue and enumerating all the possible constraints on these OpaqueValues. These OpaqueValues work like logical variables that let us run a function in reverse: when we get a non-zero (non-None) return value, we look at the accumulated constraint set on the opaque values and use them to update the data representation of the function's result (we assume that we can represent the constraints on all OpaqueValues in our result data structure).

Currently, we only support nested dictionaries, so each OpaqueValue must be either fully unconstrained (a wildcard that matches any value), or constrained to be exactly equal to a value. There's no reason we can't use k-d range trees though, and it's not harder to track a pair of bounds (lower and upper) than a set of inequalities, so we'll handle the general ordered OpaqueValue case.

In the input program (the query), we assume closed over values are only used for comparisons (equality, inequality, relational operators, or conversion to bool, i.e., non-zero testing). Knowing the result of each (non-redundant) comparison tightens the range of potential values for the OpaqueValue... eventually down to a single point value that our hash-based indexes can handle.

Of course, if a comparison isn't redundant, there are multiple feasible results, so we need an external oracle to pick one. An external caller is responsible for injecting its logic as OpaqueValue.CMP_HANDLER, and driving the exploration of the search space.

N.B., the set of constraints we can handle is determined by the ground data structure to represent finitely supported functions.
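
As a hedged sketch of how the oracle drives things (using the class defined just below, whose default CMP_HANDLER always answers 0, i.e., "equal"):

>>> x = OpaqueValue("x")
>>> x == 3   # the default handler picks the "equal" branch...
True
>>> x.definite(), x.value()   # ...which collapses the range to exactly 3
(True, 3)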

class OpaqueValue:
    """An opaque value is a one-dimensional range of Python values,
    represented as a lower and an upper bound, each of which is
    potentially exclusive.

    `OpaqueValue`s are only used in queries for comparisons with
    ground values.  All comparisons are turned into three-way
    `__cmp__` calls; non-redundant `__cmp__` calls (which could return
    more than one value) are resolved by calling `CMP_HANDLER`
    and tightening the bound in response.

    >>> x = OpaqueValue("x")
    >>> x == True
    True
    >>> 1 if x else 2
    1
    >>> x.reset()
    >>> ### Not supported by our index data structure (yet)
    >>> # >>> x > 4
    >>> # False
    >>> # >>> x < 4
    >>> # False
    >>> # >>> x == 4
    >>> # True
    >>> # >>> x < 10
    >>> # True
    >>> # >>> x >= 10
    >>> # False
    >>> # >>> x > -10
    >>> # True
    >>> # >>> x <= -10
    >>> # False
    """

    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name):
        self.name = name
        # Upper and lower bounds
        self.lower = self.upper = None
        self.lowerExcl = self.upperExcl = True

    def __str__(self):
        left = "(" if self.lowerExcl else "["
        right = ")" if self.upperExcl else "]"
        return f"<OpaqueValue {self.name} {left}{self.lower}, {self.upper}{right}>"

    def reset(self):
        """Clears all accumulated constraints on this `OpaqueValue`."""
        self.lower = self.upper = None
        self.lowerExcl = self.upperExcl = True

    def indefinite(self):
        """Returns whether this `OpaqueValue` is still unconstrained."""
        return self.lower is None and self.upper is None

    def definite(self):
        """Returns whether is `OpaqueValue` is constrained to an exact value."""
        return (
            self.lower == self.upper
            and self.lower is not None
            and not self.lowerExcl
            and not self.upperExcl
        )

    def value(self):
        """Returns the exact value for this `OpaqueValue`, assuming there is one."""
        return self.lower

    def _contains(self, value, strictly):
        """Returns whether this `OpaqueValue`'s range includes `value`
        (maybe `strictly` inside the range).

        The first value is the containment truth value, and the second
        is the forced `__cmp__` value, if the `value` is *not*
        (strictly) contained in the range.
        """
        if self.lower is not None:
            if self.lower > value:
                return False, 1
            if self.lower == value and (strictly or self.lowerExcl):
                return False, 1
        if self.upper is not None:
            if self.upper < value:
                return False, -1
            if self.upper == value and (strictly or self.upperExcl):
                return False, -1
        return True, 0 if self.definite() else None

    def contains(self, value, strictly=False):
        """Returns whether `values` is `strictly` contained in the
        `OpaqueValue`'s range.
        """
        try:
            return self._contains(value, strictly)[0]
        except TypeError:
            return False

    def potential_mask(self, other):
        """Returns the set of potential `__cmp__` values for `other`
        that are compatible with the current range: bit 0 is set if
        `-1` is possible, bit 1 if `0` is possible, and bit 2 if `1`
        is possible.
        """
        if not self.contains(other):
            return 0
        if self.contains(other, strictly=True):
            return 7
        if self.definite() and self.value() == other:
            return 2
        # We have a non-strict inclusion, and inequality.
        if self.lower == other and self.lowerExcl:
            return 6
        assert self.upper == other and self.upperExcl
        return 3

    def __cmp__(self, other):
        """Three-way comparison between this `OpaqueValue` and `other`.

        When the result is known from the current bound, we just return
        that value.  Otherwise, we ask `CMP_HANDLER` what value to return
        and update the bound accordingly.
        """
        if isinstance(other, OpaqueValue) and self.definite() and not other.definite():
            # If we have a definite value and `other` is an indefinite
            # `OpaqueValue`, flip the comparison order to let the `other`
            # argument be a ground value.
            return -other.__cmp__(self.value())

        if isinstance(other, OpaqueValue) and not other.definite():
            raise RuntimeError(
                f"OpaqueValue may only be compared with ground values. self={self} other={other}"
            )
        if isinstance(other, OpaqueValue):
            other = other.value()  # Make sure `other` is a ground value

        if other is None:
            # We use `None` internally, and it doesn't compare well
            raise RuntimeError("OpaqueValue may not be compared with None")

        compatible, order = self._contains(other, False)
        if order is not None:
            return order
        order = OpaqueValue.CMP_HANDLER(self, other)
        if order < 0:
            self._add_bound(upper=other, upperExcl=True)
        elif order == 0:
            self._add_bound(lower=other, lowerExcl=False, upper=other, upperExcl=False)
        else:
            self._add_bound(lower=other, lowerExcl=True)
        return order

    def _add_bound(self, lower=None, lowerExcl=False, upper=None, upperExcl=False):
        """Updates the internal range for this new bound."""
        assert lower is None or self.contains(lower, strictly=lowerExcl)
        assert upper is None or self.contains(upper, strictly=upperExcl)
        if lower is not None:
            self.lower = lower
            self.lowerExcl = lowerExcl

        assert upper is None or self.contains(upper, strictly=upperExcl)
        if upper is not None:
            self.upper = upper
            self.upperExcl = upperExcl

    def __bool__(self):
        return self != 0

    def __eq__(self, other):
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        return self.__cmp__(other) != 0

    # No other comparator because we don't index ranges
    # (no range tree).

    # def __lt__(self, other):
    #     return self.__cmp__(other) < 0

    # def __le__(self, other):
    #     return self.__cmp__(other) <= 0

    # def __gt__(self, other):
    #     return self.__cmp__(other) > 0

    # def __ge__(self, other):
    #     return self.__cmp__(other) >= 0

Depth-first exploration of a function call's support

We assume a None result represents a zero value wrt the aggregate merging function (e.g., 0 for a sum). For convenience, we also treat tuples and lists of Nones identically.

We simply maintain a stack of CMP_HANDLER calls, where each entry in the stack consists of an OpaqueValue and a bitset of CMP_HANDLER results still to explore (-1, 0, or 1). This stack is filled on demand, and CMP_HANDLER returns the first result allowed by the bitset.

Once we have a result, we tweak the stack to force depth-first exploration of a different part of the solution space: we drop the first bit in the bitset of results to explore, and drop the entry wholesale if the bitset is now empty (all zero). When this tweaking leaves an empty stack, we're done.

This ends up enumerating all the paths through the function call with a non-recursive depth-first traversal.

We then do the same for each datum in our input sequence, and merge results for identical keys together.
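
A small sketch of the bitset convention used below (bit 0 stands for a -1 result, bit 1 for 0, bit 2 for 1): a fresh, unconstrained OpaqueValue allows all three comparison results, and mask &= mask - 1 drops the branch we just explored.

>>> OpaqueValue("x").potential_mask(4)   # -1, 0, and 1 are all still possible
7
>>> 7 & (7 - 1)   # drop the lowest bit: -1 has been explored, 0 and 1 remain
6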

def is_zero_result(value):
    """Checks if `value` is a "zero" aggregate value: either `None`,
    or an iterable of all `None`.

    >>> is_zero_result(None)
    True
    >>> is_zero_result(False)
    False
    >>> is_zero_result(True)
    False
    >>> is_zero_result(0)
    False
    >>> is_zero_result(-10)
    False
    >>> is_zero_result(1.5)
    False
    >>> is_zero_result("")
    False
    >>> is_zero_result("asd")
    False
    >>> is_zero_result((None, None))
    True
    >>> is_zero_result((None, 1))
    False
    >>> is_zero_result([])
    True
    >>> is_zero_result([None])
    True
    >>> is_zero_result([None, (None, None)])
    False
    """
    if value is None:
        return True
    if isinstance(value, (tuple, list)):
        return all(item is None for item in value)
    return False


def enumerate_opaque_values(function, values):
    """Explores the set of `OpaqueValue` constraints when calling
    `function`.

    Enumerates all constraints for the `OpaqueValue` instances in
    `values`, and yields a pair of equality constraints for the
    `value` and the corresponding result, for all non-zero results.

    This essentially turns `function()` into a branching program on
    `values`.

    >>> x, y = OpaqueValue("x"), OpaqueValue("y")
    >>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
    [((0, None), 1), ((1, 2), 2)]

    """
    explorationStack = []  # List of (value, bitmaskOfCmp)
    while True:
        for value in values:
            value.reset()

        stackIndex = 0

        def handle(value, other):
            nonlocal stackIndex
            if len(explorationStack) == stackIndex:
                explorationStack.append((value, value.potential_mask(other)))

            expectedValue, mask = explorationStack[stackIndex]
            assert value is expectedValue
            assert mask != 0

            if (mask & 1) != 0:
                ret = -1
            elif (mask & 2) != 0:
                ret = 0
            elif (mask & 4) != 0:
                ret = 1
            else:
                assert False, f"bad mask {mask}"

            stackIndex += 1
            return ret

        OpaqueValue.CMP_HANDLER = handle
        result = function()
        if not is_zero_result(result):
            for value in values:
                assert (
                    value.definite() or value.indefinite()
                ), f"partially constrained {value} temporarily unsupported"
            yield (tuple(key.value() if key.definite() else None for key in values),
                   result)

        # Drop everything that was fully explored, then move the next
        # top of stack to the next option.
        while explorationStack:
            value, mask = explorationStack[-1]
            assert 0 <= mask < 8

            mask &= mask - 1  # Drop first bit
            if mask != 0:
                explorationStack[-1] = (value, mask)
                break
            explorationStack.pop()
        if not explorationStack:
            break


def enumerate_supporting_values(function, args):
    """Lists the bag of mapping from closed over values to non-zero result,
    for all calls `function(arg) for args in args`.

    >>> def count_eql(needle): return lambda x: 1 if x == needle else None
    >>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
    [((1,), 1), ((2,), 1), ((4,), 1), ((4,), 1), ((2,), 1)]
    """
    _, _, rebind, names = extract_function_state(function)
    values = [OpaqueValue(name) for name in names]
    reboundFunction = rebind(values)
    for arg in args:
        yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)

Type driven merges

The interesting part of map/reduce is the reduction step. While some like to use first-class functions to describe reduction, in my opinion, it often makes more sense to define reduction at the type level: it's essential that merge operators be commutative and associative, so isolating the merge logic in dedicated classes makes sense to me.

This file defines two trivial mergeable value types, Sum and Min, but we could have different ones, e.g., hyperloglog unique counts, or streaming statistical moments... or even a list of row ids.
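
For instance, here's a hedged sketch of the "list of row ids" idea (RowIds is a hypothetical name, not used elsewhere in this file): the merge is a bag union, associative and commutative up to ordering, with the empty bag as the zero value.

class RowIds:
    """A bag of matching row ids (order is irrelevant)."""

    def __init__(self, ids=()):
        self.value = list(ids)

    def merge(self, other):
        assert isinstance(other, RowIds)
        self.value.extend(other.value)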

class Sum:
    """A counter for summed values."""

    def __init__(self, value=0):
        self.value = value

    def merge(self, other):
        assert isinstance(other, Sum)
        self.value += other.value


class Min:
    """A running `min` value tracker."""

    def __init__(self, value=None):
        self.value = value

    def merge(self, other):
        assert isinstance(other, Min)
        if not self.value:
            self.value = other.value
        elif other.value:
            self.value = min(self.value, other.value)

Nested dictionary with wildcard

There's a direct relationship between the data structure we use to represent the result of function calls as branching functions, and the constraints we can support on closed over values for non-zero results.

In a real implementation, this data structure would host most of the complexity: it's the closest thing we have to indexes.

For now, we support equality with a ground value as our only constraint. This means the only two cases we must look for at each level are an exact match, or a wildcard match.

We do have to check for both cases at each level, so the worst-case complexity for lookups is exponential in the depth (number of join variables). That's actually reasonable because we don't expect too many join variables (compare to range trees that are also exponential in the number of dimensions... but with a base of log(n) instead of 2).

A real implementation could maybe save work by memoising merged results for internal subtrees... it's tempting to broadcast wildcard values to the individual keyed entries, but I think that might explode the time complexity of our pre-computation phase.
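
A hedged usage sketch of the two classes defined below (the merge lambdas here simply ignore the fresh None entry): with two join variables, a lookup may visit up to 2^2 subtrees, one per exact-or-wildcard choice at each level.

>>> nd = NestedDict(2)
>>> nd.set((1, 2), lambda old: Sum(10))    # exact keys for both join variables
>>> nd.set((None, 2), lambda old: Sum(1))  # wildcard on the first join variable
>>> nd.visit((1, 2), lambda v: print(v.value))  # both entries match (1, 2)
1
10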

class NestedDictLevel:
    """One level in a nested dictionary index.  We may have a value
    for everything (leaf node), or a key-value dict *and a wildcard
    entry* for a specific index in the tuple key.
    """

    def __init__(self, depth):
        self.depth = depth
        self.value = None
        self.wildcard = None
        self.dict = dict()

    def visit(self, keys, visitor):
        """Passes the values for `keys` to `visitor`."""
        if self.value is not None:
            assert self.wildcard is None and not self.dict
            visitor(self.value)

        if self.depth >= len(keys):
            return

        if self.wildcard is not None:
            self.wildcard.visit(keys, visitor)

        next = self.dict.get(keys[self.depth], None)
        if next is not None:
            next.visit(keys, visitor)

    def set(self, keys, mergeFunction, depth=0):
        """Sets the value for `keys` in this level."""
        assert depth <= len(keys)
        if depth == len(keys):  # Leaf
            self.value = mergeFunction(self.value)
            return

        assert self.depth == depth
        key = keys[depth]
        
        if key is None:
            if self.wildcard is None:
                self.wildcard = NestedDictLevel(depth + 1)
            dst = self.wildcard
        else:
            dst = self.dict.get(key)
            if dst is None:
                dst = NestedDictLevel(depth + 1)
                self.dict[key] = dst
        dst.set(keys, mergeFunction, depth + 1)


class NestedDict:
    """A nested dict of a given `depth` maps tuples of `depth` keys to
    a value.  Each `NestedDictLevel` handles a different level.
    """

    def __init__(self, length):
        self.top = NestedDictLevel(0)
        self.length = length

    def visit(self, keys, visitor):
        """Gets the value associated with `keys`, or `default` if None."""
        assert len(keys) == self.length
        assert all(key is not None for key in keys)
        self.top.visit(keys, visitor)

    def set(self, keys, mergeFn):
        """Sets the value associated with `((index, key), ...)`."""
        assert len(keys) == self.length
        self.top.set(keys, mergeFn)

Identity key-value maps

class IdMap:
    def __init__(self):
        self.entries = dict()  # tuple of ids -> (keys, value)
        # Storing the keys in the value keeps them alive, so their ids stay stable.

    def get(self, keys, default=None):
        ids = tuple(id(key) for key in keys)
        return self.entries.get(ids, (None, default))[1]

    def __contains__(self, keys):
        ids = tuple(id(key) for key in keys)
        return ids in self.entries

    def __getitem__(self, keys):
        ids = tuple(id(key) for key in keys)
        return self.entries[ids][1]

    def __setitem__(self, keys, value):
        ids = tuple(id(key) for key in keys)
        self.entries[ids] = (keys, value)
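
IdMap exists because cache keys like the input sequences are typically unhashable lists; keying on id() sidesteps that, at the cost of only matching the exact same objects. A hedged usage sketch with made-up key objects:

>>> m = IdMap()
>>> code, rows = "fake code key", [(1, 2), (1, 3)]
>>> m[code, rows] = "cached result"
>>> (code, rows) in m        # same objects, same ids
True
>>> (code, list(rows)) in m  # a copy has a different id, hence a miss
False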

Cached map_reduce

As mentioned earlier, we assume reduce is determined implicitly by the reduced values' type. We also have enumerate_supporting_values to find all the closed over values that yield a non-zero result, for all values in a sequence.

We can thus accept a function and an input sequence, find the supporting values, and merge the result associated with identical supporting values.

Again, we only support ground equality constraints (see the `definite() or indefinite()` assertion in enumerate_opaque_values), i.e., only equijoins. There's nothing that stops a more sophisticated implementation from using range trees to support inequality or range joins.

We'll cache the precomputed values by code object (i.e., function without closed over values) and input sequence. If we don't have a precomputed value, we'll use enumerate_supporting_values to run the function backward for each input datum from the sequence, and accumulate the results in a NestedDict. Working backward to find closure values that yield a non-zero result (for each input datum) lets us precompute a branching program that directly yields the result. We represent these branching programs explicitly, so we can also directly update a branching program for the result of merging all the values returned by mapping over the input sequence, for a given closure.

This last map_reduce definition ties everything together, and I think it is really the general heart of Yannakakis's algorithm as an instance of bottom-up dynamic programming.

def _merge(dst, update):
    if dst is None:
        return update

    if isinstance(dst, (tuple, list)):
        assert len(dst) == len(update)
        for value, new in zip(dst, update):
            value.merge(new)
    else:
        dst.merge(update)
    return dst


def _extractValues(accumulator):
    if accumulator is None:
        return None
    if isinstance(accumulator, tuple):
        return tuple(item.value for item in accumulator)
    if isinstance(accumulator, list):
        return list(item.value for item in accumulator)

    return accumulator.value

def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    closes over, and an input iterable, generates a `NestedDict`
    representation for `reduce(map(function, inputIterable))`, where
    the reduction step simply calls `merge` on the return values
    (tuples are merged elementwise), and the `NestedDict` keys
    represent closed over values.

    >>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
    >>> nd = _precompute_map_reduce(count_eql(4), 1, [1, 2, 4, 4, 2])
    >>> nd.visit((0,), lambda sum: print(sum.value))
    >>> nd.visit((1,), lambda sum: print(sum.value))
    1
    >>> nd.visit((2,), lambda sum: print(sum.value))
    2
    >>> nd.visit((4,), lambda sum: print(sum.value))
    2
    """
    cache = NestedDict(depth)
    for indexKeyValues, result in enumerate_supporting_values(function, inputIterable):
        cache.set(indexKeyValues, lambda old: _merge(old, result))
    return cache


AGGREGATE_CACHE = IdMap()  # Map from function, input sequence -> NestedDict


def map_reduce(function, inputIterable, initialValue=None, *, extractResult=True):
    """Returns the result of merging `map(function, inputIterable)`
    into `initialValue`.

    `None` return values represent neutral elements (i.e., the result
    of mapping an empty `inputIterable`), and values are otherwise
    reduced by calling `merge` on a mutable accumulator.

    Assuming `function` is well-behaved, `map_reduce` runs in time
    linear wrt `len(inputIterable)`.  It's also always cached on a
    composite key that consists of the `function`'s code object (i.e.,
    without closed over values) and the `inputIterable`.

    These complexity guarantees let us nest `map_reduce` with
    different closed over values, and still guarantee a linear-time
    total complexity.

    This wrapper ties together all the components

    >>> INVOCATION_COUNTER = 0
    >>> data = (1, 2, 2, 4, 2, 4)
    >>> def count_eql(needle):
    ...     def count(x):
    ...         global INVOCATION_COUNTER
    ...         INVOCATION_COUNTER += 1
    ...         return Sum(x) if x == needle else None
    ...     return count
    >>> INVOCATION_COUNTER
    0
    >>> map_reduce(count_eql(4), data, Sum(), extractResult=False).value
    8
    >>> INVOCATION_COUNTER
    18
    >>> map_reduce(count_eql(2), data)
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ...     return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0))
    >>> def sum_odd_or_even_skus(mod_two):
    ...     def count_if_mod_two(id_sku):
    ...         id, sku = id_sku
    ...         if id % 2 == mod_two:
    ...             return Sum(sku_min_cost(sku))
    ...     return map_reduce(count_if_mod_two, id_skus, Sum())
    >>> sum_odd_or_even_skus(0)
    20
    >>> sum_odd_or_even_skus(1)
    50

    """
    assert isinstance(inputIterable, collections.abc.Iterable)
    assert not isinstance(inputIterable, collections.abc.Iterator)
    code, closure, *_ = extract_function_state(function)
    if (code, inputIterable) not in AGGREGATE_CACHE:
        AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
            function, len(closure), inputIterable
        )

    acc = [initialValue]
    def visitor(result):
        acc[0] = _merge(acc[0], result)

    AGGREGATE_CACHE[code, inputIterable].visit(closure, visitor)
    return _extractValues(acc[0]) if extractResult else acc[0]


map_reduce.over = \
    lambda inputIterable, initialValue=None, *, extractResult=True: \
        lambda fn: map_reduce(fn, inputIterable, initialValue, extractResult=extractResult)
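
For completeness, a hedged sketch of the decorator form (the same pattern as the introductory example): the decorated name is bound to the reduced value, not to a function.

>>> @map_reduce.over([1, 2, 2, 4], Sum())
... def total(x):
...     return Sum(x)
>>> total
9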

 
if __name__ == "__main__":
    import doctest

    doctest.testmod()

Is this actually a DB post?

Although the intro name-dropped Yannakakis, the presentation here has a very programming language / logic programming flavour. I think the logic programming point of view, where we run a program backwards with logical variables, is much clearer than the specific case of conjunctive equijoin queries in the usual presentation of Yannakakis's algorithm. In particular, I think there's a clear path to handle range or comparison joins: it's all about having an index data structure to handle range queries.

It should be clear how to write conjunctive queries as Python functions, given a hypertree decomposition. The reverse is much more complex, if only because Python is much more powerful than just CQ, and that's actually a liability: this hack will blindly try to convert any function to a branching program, instead of giving up noisily when the function is too complex.
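
For example (a hedged sketch with made-up relations R and S, reusing map_reduce and Sum from above), the acyclic conjunctive query Q(x) :- R(x, y), S(y, z), aggregated as a count of witnesses, follows the shape of its join tree directly:

R = [(1, 10), (2, 10), (2, 20)]
S = [(10, 100), (20, 200), (20, 201)]

def count_q(x):
    def per_r_row(xy):
        rx, y = xy
        if rx == x:  # the only closed over value is the join variable x
            # Count the S rows that join with this R row on y.
            matches = map_reduce(lambda yz: Sum(1) if yz[0] == y else None, S, Sum())
            return Sum(matches)
    return map_reduce(per_r_row, R, Sum())

# count_q(2) == 3: (2, 10) matches one S row, (2, 20) matches two.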

The other difference from classical CQs is that we focus on aggregates. That's because aggregates are the more general form: if we just want to avoid useless work while enumerating all join rows, we only need a boolean aggregate that tells us whether the join will yield at least one row. We could also special case types for which merges don't save space (e.g., set of row ids), and instead enumerate values by walking the branching program tree.
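
A hedged sketch of that boolean aggregate (Exists is a hypothetical name, in the same style as Sum and Min above): or is associative and commutative, with False as the zero value, so a semi-join style query would return Exists(True) for matching rows and start from Exists().

class Exists:
    """Tracks whether at least one row was produced."""

    def __init__(self, value=False):
        self.value = value

    def merge(self, other):
        assert isinstance(other, Exists)
        self.value = self.value or other.value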

The aggregate viewpoint also works for fun extensions like indexed access to ranked results: that extension ends up counting the number of output values up to a certain key.

I guess, in a way, we just showed a trivial way to decorrelate queries with a hypertree-width of 1. We just have to be OK with building one index for each loop in the nest... but it should be possible to pattern match on pre-defined indexes and avoid obvious redundancy.

Extensions and future work

Use a dedicated DSL

First, the whole idea of introspecting closures to stub in logical variables is a terrible hack (looks cool though ;). A real production implementation should apply CPS partial evaluation to a purely functional programming language, then bulk reverse-evaluate with a SIMD implementation of the logical program.

There'll be restrictions on the output traces, but that's OK: a different prototype makes me believe the restrictions correspond to deterministic logspace (L), and it makes sense to restrict our analyses to L. Just like grammars are easier to work with when restricted to LL(1), DSLs that only capture L tend to be easier to analyse and optimise... and L is reasonably large (a polynomial-time algorithm that's not in L would be a huge result).

Handle local functions

While we sin with the closure hack (extract_function_state), it should really be extended to cover local functions. This is mostly a question of going deeply into values that are mapped to functions, and of maintaining an id-keyed map from cell to OpaqueValue.

We could also add support for partial application objects, which may be easier for multiprocessing.

Parallelism

There is currently no support for parallelism, only caching. It should be easy to handle the return values (NestedDicts and aggregate classes like Sum or Min). Distributing the work in _precompute_map_reduce to merge locally is also not hard.

The main issue with parallelism is that we can't pass functions as work units, so we'd have to stick to the fork process pool.

There's also no support for moving (child) work forward when blocked waiting on a future. We'd have to spawn workers on the fly to oversubscribe when workers are blocked on a result (spawning on demand is already a given for fork workers), and to implement our own concurrency control to avoid wasted work, and probably internal throttling to avoid thrashing when we'd have more active threads than cores.

That being said, the complexity is probably worth the speed up on realistic queries.

Theta joins

At a higher level, we could support comparison joins (e.g., less than, greater than or equal, in range) if only we represented the branching programs with a data structure that supported these queries. A range tree would let us handle these "theta" joins, for the low, low cost of a polylogarithmic multiplicative factor in space and time.

Self-adjusting computation

Finally, we could update the indexed branching programs incrementally after small changes to the input data. This might sound like a job for streaming engines like timely dataflow, but I think viewing each _precompute_map_reduce call as a purely functional map/reduce job gives a better fit with self-adjusting computation.

Once we add logic to recycle previously constructed indexes, it will probably make sense to allow an initial filtering step before map/reduce, with a cache key on the filter function (with closed over values and all). We can often implement the filtering more efficiently than we can run functions backward, and we'll also observe that slightly different filter functions often result in not too dissimilar filtered sets. Factoring out this filtering can thus enable more reuse of partial precomputed results.

#!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23{(.*)$|<details><summary>\1</summary>\n| -e s|^\x20{4}\x23\x23}$|\n</details>| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|
##{license, imports
# Yannakakis.py by Paul Khuong
#
# To the extent possible under law, the person who associated CC0 with
# Yannakakis.py has waived all copyright and related or neighboring rights
# to Yannakakis.py.
#
# You should have received a copy of the CC0 legalcode along with this
# work. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.
import collections.abc
##}
## # Linear-time analytical queries in plain Python
##
## <small>I didn't know [Mihalis Yannakakis is about to turn 70](https://mihalisfest.cs.columbia.edu), but that's a fun coïncidence.</small>
##
## This short hack shows how [Yannakakis's algorithm](https://pages.cs.wisc.edu/~paris/cs784-f19/lectures/lecture4.pdf)
## lets us implement linear-time (wrt the number of input data rows)
## analytical queries in regular programming languages, without
## offloading joins to a specialised database query language, thus
## avoiding the associated impedance mismatch. There are restrictions
## on the queries we can express -- Yannakakis's algorithm relies on a
## hypertree width of 1, and on having hypertree decomposition as a witness
## -- but that's kind of reasonable: a fractional hypertree width > 1
## would mean there are databases for which the intermediate results
## could superlinearly larger than the input database due to the [AGM bound](https://arxiv.org/abs/1711.03860).
## The hypertree decomposition witness isn't a burden either:
## structured programs naturally yield a hypertree decomposition,
## unlike more declarative logic programs that tend to hide the
## structure implicit in the programmer's thinking.
##
## The key is to mark function arguments as either negative (true
## inputs) or positive (possible interesting input values derived from
## negative arguments). In this hack, closed over values are
## positive, and the only negative argument is the current data row.
##
## We also assume that these functions are always used in a map/reduce
## pattern, and thus we only memoise the result of
## `map_reduce(function, input)`, with a group-structured reduction:
## the reduce function must be associative and commutative, and there
## must be a zero (neutral) value.
##
## With these constraints, we can express joins in natural Python
## without incurring the poor runtime scaling of the nested loops we
## actually wrote. This Python file describes the building blocks
## to handle aggregation queries like the following
##
## >>> id_skus = [(1, 2), (2, 2), (1, 3)]
## >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
## >>> def sum_odd_or_even_skus(mod_two):
## ... @map_reduce.over(id_skus, Sum())
## ... def count_if_mod_two(id_sku):
## ... id, sku = id_sku
## ... if id % 2 == mod_two:
## ... @map_reduce.over(sku_costs, Min(0))
## ... def min_cost(sku_cost):
## ... if sku_cost[0] == sku:
## ... return Min(sku_cost[1])
## ... return Sum(min_cost)
## ... return count_if_mod_two
## ...
##
## with linear scaling in the length of `id_skus` and `sku_costs`, and
## caching for similar queries.
##
## At a small scale, everything's fast.
##
## >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
## 20
## 0.0007169246673583984
## >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
## 50
## 0.0002627372741699219
##
## As we increase the scale by 1000x for both input lists, the runtime
## scales (sub) linearly for the first query, and is unchanged for the
## second:
##
## >>> id_skus = id_skus * 1000
## >>> sku_costs = sku_costs * 1000
## >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
## 20000
## 0.09455370903015137
## >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
## 50000
## 0.00025773048400878906
##
## This still pretty much holds up when we multiply by another factor of 100:
##
## >>> id_skus = id_skus * 100
## >>> sku_costs = sku_costs * 100
## >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
## 2000000
## 6.946590185165405
## >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
## 5000000
## 0.00025200843811035156
##
## The magic behind the curtains is memoisation (unsurprisingly), but
## a special implementation that can share work for similar closures:
## the memoisation key consists of the function *without closed over
## bindings* and the call arguments, while the memoised value is a
## data structure from the tuple of closed over values to the
## `map_reduce_over` output.
##
## This concrete representation of the function as a branching program
## is the core of Yannakakis's algorithm: we'll iterate over each
## datum in the input, run the function on it with logical variables
## instead of the closed over values, and generate a mapping from
## closed over values to result for all non-zero results. We'll then
## merge the mappings for all input data together (there is no natural
## ordering here, hence the group structure).
##
## The output data structure controls the join we can implement. We
## show how a simple stack of nested key-value mappings handles
## equijoins, but a [k-d range tree](https://dl.acm.org/doi/10.1145/356789.356797)
## would handle inequalities, i.e., "theta" joins (the rest of the
## machinery already works in terms of less than/greater than
## constraints).
##
## As long as we explore a bounded number of paths for each datum and
## a bounded number of `function, input` cache keys, we'll spend a
## bounded amount of time on each input datum, and thus linear time
## total. The magic of Yannakakis's algorithm is that this works even
## when there are nested `map_reduce` calls, which would naïvely
## result in polynomial time (degree equal to the nesting depth).
##{<h2>Memoising through a Python function's closed over values</h2>
##
## Even if functions were hashable and comparable for extensional
## equality, directly using closures as memoisation keys in calls like
## `map_reduce(function, input)` would result in superlinear runtime
## for nested `map_reduce` calls.
##
## This `extract_function_state` accepts a function (with or without
## closed over state), and returns four values:
## 1. The underlying code object
## 2. The tuple of closed over values (current value for mutable cells)
## 3. A function to rebind the closure with new closed over values
## 4. The name of the closed over bindings
##
## The third return value, the `rebind` function, silently fails on
## complicated cases; this is a hack, after all. In short, it only
## handles closing over immutable atomic values like integers or
## strings, but not, e.g., functions (yet), or mutable bindings.
def extract_function_state(function):
"""Accepts a function object and returns information about it: a
hash key for the object, a tuple of closed over values, a function
to return a fresh closure with a different tuple of closed over
values, and a closure of closed over names
>>> def test(x): return lambda y: x == y
>>> extract_function_state(test)[1]
()
>>> extract_function_state(test)[3]
()
>>> fun = test(4)
>>> extract_function_state(fun)[1]
(4,)
>>> extract_function_state(fun)[3]
('x',)
>>>
>>> fun(4)
True
>>> fun(5)
False
>>> rebind = extract_function_state(fun)[2]
>>> rebound_4, rebound_5 = rebind([4]), rebind([5])
>>> rebound_4(4)
True
>>> rebound_4(5)
False
>>> rebound_5(4)
False
>>> rebound_5(5)
True
"""
code = function.__code__
names = code.co_freevars
if function.__closure__ is None: # Toplevel function
assert names == ()
def rebind(values):
if len(values) != 0:
raise RuntimeError(
f"Values must be empty for toplevel function. values={values}"
)
return function
return code, (), rebind, names
closure = tuple(cell.cell_contents for cell in function.__closure__)
assert len(names) == len(closure), (closures, names)
# TODO: rebind recursively (functions are also cells)
def rebind(values):
if len(values) != len(names):
raise RuntimeError(
f"Values must match names. names={names} values={values}"
)
return function.__class__(
code,
function.__globals__,
function.__name__,
function.__defaults__,
tuple(
cell.__class__(value)
for cell, value in zip(function.__closure__, values)
),
)
return code, closure, rebind, names
##}
## ## Logical variables for closed-over values
##
## We wish to enumerate the support of a function call (parameterised
## over closed over values), and the associated result value. We'll
## do that by rebinding the closure to point at instances of
## `OpaqueValue` and enumerating all the possible constraints on these
## `OpaqueValue`s. These `OpaqueValue`s work like logical variables
## that let us run a function in reverse: when get a non-zero
## (non-None) return value, we look at the accumulated constraint set
## on the opaque values and use them to update the data representation
## of the function's result (we assume that we can represent the
## constraints on all `OpaqueValue`s in our result data structure).
##
## Currently, we only support nested dictionaries, so each
## `OpaqueValue`s must be either fully unconstrained (wildcard that
## matches any value), or constrained to be exactly equal to a value.
## There's no reason we can't use k-d range trees though, and it's not
## harder to track a pair of bounds (lower and upper) than a set of
## inequalities, so we'll handle the general ordered `OpaqueValue`
## case.
##
## In the input program (the query), we assume closed over values are
## only used for comparisons (equality, inequality, relational
## operators, or conversion to bool, i.e., non-zero testing). Knowing
## the result of each (non-redundant) comparison tightens the range
## of potential values for the `OpaqueValue`... eventually down to a
## single point value that our hash-based indexes can handle.
##
## Of course, if a comparison isn't redundant, there are multiple
## feasible results, so we need an external oracle to pick one. An
## external caller is responsible for injecting its logic as
## `OpaqueValue.CMP_HANDLER`, and driving the exploration of the
## search space.
##
## N.B., the set of constraints we can handle is determined by the
## ground data structure to represent finitely supported functions.
class OpaqueValue:
"""An opaque value is a one-dimensional range of Python values,
represented as a lower and an upper bound, each of which is
potentially exclusive.
`OpaqueValue`s are only used in queries for comparisons with
ground values. All comparisons are turned into three-way
`__cmp__` calls; non-redundant `__cmp__` calls (which could return
more than one value) are resolved by calling `CMP_HANDLER`
and tightening the bound in response.
>>> x = OpaqueValue("x")
>>> x == True
True
>>> 1 if x else 2
1
>>> x.reset()
>>> ### Not supported by our index data structure (yet)
>>> # >>> x > 4
>>> # False
>>> # >>> x < 4
>>> # False
>>> # >>> x == 4
>>> # True
>>> # >>> x < 10
>>> # True
>>> # >>> x >= 10
>>> # False
>>> # >>> x > -10
>>> # True
>>> # >>> x <= -10
>>> # False
"""
# Resolving function for opaque values
CMP_HANDLER = lambda opaque, value: 0
def __init__(self, name):
self.name = name
# Upper and lower bounds
self.lower = self.upper = None
self.lowerExcl = self.upperExcl = True
def __str__(self):
left = "(" if self.lowerExcl else "["
right = ")" if self.upperExcl else "]"
return f"<OpaqueValue {self.name} {left}{self.lower}, {self.upper}{right}>"
def reset(self):
"""Clears all accumulated constraints on this `OpaqueValue`."""
self.lower = self.upper = None
self.lowerExcl = self.upperExcl = True
def indefinite(self):
"""Returns whether this `OpaqueValue` is still unconstrained."""
return self.lower == None and self.upper == None
def definite(self):
"""Returns whether is `OpaqueValue` is constrained to an exact value."""
return (
self.lower == self.upper
and self.lower is not None
and not self.lowerExcl
and not self.upperExcl
)
def value(self):
"""Returns the exact value for this `OpaqueValue`, assuming there is one."""
return self.lower
def _contains(self, value, strictly):
"""Returns whether this `OpaqueValue`'s range includes `value`
(maybe `strictly` inside the range).
The first value is the containment truth value, and the second
is the forced `__cmp__` value, if the `value` is *not*
(strictly) contained in the range.
"""
if self.lower is not None:
if self.lower > value:
return False, 1
if self.lower == value and (strictly or self.lowerExcl):
return False, 1
if self.upper is not None:
if self.upper < value:
return False, -1
if self.upper == value and (strictly or self.upperExcl):
return False, -1
return True, 0 if self.definite() else None
def contains(self, value, strictly=False):
"""Returns whether `values` is `strictly` contained in the
`OpaqueValue`'s range.
"""
try:
return self._contains(value, strictly)[0]
except TypeError:
return False
def potential_mask(self, other):
"""Returns the set of potential `__cmp__` values for `other`
that are compatible with the current range: bit 0 is set if
`-1` is possible, bit 1 if `0` is possible, and bit 2 if `1`
is possible.
"""
if not self.contains(other):
return 0
if self.contains(other, strictly=True):
return 7
if self.definite() and self.value == other:
return 2
# We have a non-strict inclusion, and inequality.
if self.lower == other and self.lowerExcl:
return 6
assert self.upper == other and self.upperExcl
return 3
def __cmp__(self, other):
"""Three-way comparison between this `OpaqueValue` and `other`.
When the result is known from the current bound, we just return
that value. Otherwise, we ask `CMP_HANDLER` what value to return
and update the bound accordingly.
"""
if isinstance(other, OpaqueValue) and self.definite() and not other.definite():
# If we have a definite value and `other` is an indefinite
# `OpaqueValue`, flip the comparison order to let the `other`
# argument be a ground value.
return -other.__cmp__(self.value())
if isinstance(other, OpaqueValue) and not other.definite():
raise RuntimeError(
f"OpaqueValue may only be compared with ground values. self={self} other={other}"
)
if isinstance(other, OpaqueValue):
other = other.value() # Make sure `other` is a ground value
if other is None:
# We use `None` internally, and it doesn't compare well
raise RuntimeError("OpaqueValue may not be compared with None")
compatible, order = self._contains(other, False)
if order is not None:
return order
order = OpaqueValue.CMP_HANDLER(self, other)
if order < 0:
self._add_bound(upper=other, upperExcl=True)
elif order == 0:
self._add_bound(lower=other, lowerExcl=False, upper=other, upperExcl=False)
else:
self._add_bound(lower=other, lowerExcl=True)
return order
def _add_bound(self, lower=None, lowerExcl=False, upper=None, upperExcl=False):
"""Updates the internal range for this new bound."""
assert lower is None or self.contains(lower, strictly=lowerExcl)
assert upper is None or self.contains(upper, strictly=upperExcl)
if lower is not None:
self.lower = lower
self.lowerExcl = lowerExcl
assert upper is None or self.contains(upper, strictly=upperExcl)
if upper is not None:
self.upper = upper
self.upperExcl = upperExcl
def __bool__(self):
return self != 0
def __eq__(self, other):
return self.__cmp__(other) == 0
def __ne__(self, other):
return self.__cmp__(other) != 0
# No other comparator because we don't index ranges
# (no range tree).
# def __lt__(self, other):
# return self.__cmp__(other) < 0
# def __le__(self, other):
# return self.__cmp__(other) <= 0
# def __gt__(self, other):
# return self.__cmp__(other) > 0
# def __ge__(self, other):
# return self.__cmp__(other) >= 0
## ## Depth-first exploration of a function call's support
##
## We assume a `None` result represents a zero value wrt the aggregate
## merging function (e.g., 0 for a sum). For convenience, we also treat
## tuples and lists of `None`s identically.
##
## We simply maintain a stack of `CMP_HANDLER` calls, where each entry
## in the stack consists of an `OpaqueValue` and bitset of `CMP_HANDLER`
## results still to explore (-1, 0, or 1). This stack is filled on demand,
## and `CMP_HANDLER` returns the first result allowed by the bitset.
##
## Once we have a result, we tweak the stack to force depth-first
## exploration of a different part of the solution space: we drop
## the first bit in the bitset of results to explore, and drop the
## entry wholesale if the bitset is now empty (all zero). When
## this tweaking leaves an empty stack, we're done.
##
## This ends up enumerating all the paths through the function call
## with a non-recursive depth-first traversal.
##
## We then do the same for each datum in our input sequence, and merge
## results for identical keys together.
def is_zero_result(value):
"""Checks if `value` is a "zero" aggregate value: either `None`,
or an iterable of all `None`.
>>> is_zero_result(None)
True
>>> is_zero_result(False)
False
>>> is_zero_result(True)
False
>>> is_zero_result(0)
False
>>> is_zero_result(-10)
False
>>> is_zero_result(1.5)
False
>>> is_zero_result("")
False
>>> is_zero_result("asd")
False
>>> is_zero_result((None, None))
True
>>> is_zero_result((None, 1))
False
>>> is_zero_result([])
True
>>> is_zero_result([None])
True
>>> is_zero_result([None, (None, None)])
False
"""
if value is None:
return True
if isinstance(value, (tuple, list)):
return all(item is None for item in value)
return False
def enumerate_opaque_values(function, values):
"""Explores the set of `OpaqueValue` constraints when calling
`function`.
Enumerates all constraints for the `OpaqueValue` instances in
`values`, and yields a pair of equality constraints for the
`value` and the corresponding result, for all non-zero results.
This essentially turns `function()` into a branching program on
`values`.
>>> x, y = OpaqueValue("x"), OpaqueValue("y")
>>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
[((0, None), 1), ((1, 2), 2)]
"""
explorationStack = [] # List of (value, bitmaskOfCmp)
while True:
for value in values:
value.reset()
stackIndex = 0
def handle(value, other):
nonlocal stackIndex
if len(explorationStack) == stackIndex:
explorationStack.append((value, value.potential_mask(other)))
expectedValue, mask = explorationStack[stackIndex]
assert value is expectedValue
assert mask != 0
if (mask & 1) != 0:
ret = -1
elif (mask & 2) != 0:
ret = 0
elif (mask & 4) != 0:
ret = 1
else:
assert False, f"bad mask {mask}"
stackIndex += 1
return ret
OpaqueValue.CMP_HANDLER = handle
result = function()
if not is_zero_result(result):
for value in values:
assert (
value.definite() or value.indefinite()
), f"partially constrained {value} temporarily unsupported"
yield (tuple(key.value() if key.definite() else None for key in values),
result)
        # Drop everything that was fully explored, then advance the
        # new top of stack to its next option.
while explorationStack:
value, mask = explorationStack[-1]
assert 0 <= mask < 8
            mask &= mask - 1  # Drop the lowest set bit
if mask != 0:
explorationStack[-1] = (value, mask)
break
explorationStack.pop()
if not explorationStack:
break
def enumerate_supporting_values(function, args):
"""Lists the bag of mapping from closed over values to non-zero result,
for all calls `function(arg) for args in args`.
>>> def count_eql(needle): return lambda x: 1 if x == needle else None
>>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
[((1,), 1), ((2,), 1), ((4,), 1), ((4,), 1), ((2,), 1)]
"""
_, _, rebind, names = extract_function_state(function)
values = [OpaqueValue(name) for name in names]
reboundFunction = rebind(values)
for arg in args:
yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)
## ## Type driven merges
##
## The interesting part of map/reduce is the reduction step. While
## some like to use first-class functions to describe reduction, in my
## opinion, it often makes more sense to define reduction at the type
## level: it's essential that merge operators be commutative and
## associative, so isolating the merge logic in dedicated classes
## makes sense to me.
##
## This file defines two trivial mergeable value types, `Sum` and
## `Min`, but we could define different ones, e.g., hyperloglog unique
## counts, or streaming statistical moments... or even a list of row
## ids.
class Sum:
"""A counter for summed values."""
def __init__(self, value=0):
self.value = value
def merge(self, other):
assert isinstance(other, Sum)
self.value += other.value
class Min:
"""A running `min` value tracker."""
def __init__(self, value=None):
self.value = value
def merge(self, other):
assert isinstance(other, Min)
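        # Falsy values (`None`, or the `Min(0)` used as a neutral element in
        # the examples) are treated as "no value seen yet".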
if not self.value:
self.value = other.value
elif other.value:
self.value = min(self.value, other.value)
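## As a sketch of how another mergeable type could slot in, an aggregate
## that collects matching row ids only needs the same `merge` contract
## (this `RowIds` class is purely illustrative and unused elsewhere):
class RowIds:
    """A set of matching row identifiers (illustrative sketch)."""
    def __init__(self, ids=()):
        self.value = set(ids)
    def merge(self, other):
        assert isinstance(other, RowIds)
        self.value |= other.value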
## ## Nested dictionary with wildcard
##
## There's a direct relationship between the data structure we use to
## represent the result of function calls as branching functions, and
## the constraints we can support on closed over values for non-zero
## results.
##
## In a real implementation, this data structure would host most of
## the complexity: it's the closest thing we have to indexes.
##
## For now, we only support equality *with a ground value* as a
## constraint. This means the only two cases we must look for at
## each level are an exact match, or a wildcard match.
##
## We do have to check for both cases at each level, so the worst-case
## complexity for lookups is exponential in the depth (number of join
## variables). That's actually reasonable because we don't expect too
## many join variables (compare to [range trees](https://en.wikipedia.org/wiki/Range_tree#Range_queries)
## that are also exponential in the number of dimensions... but with
## a base of log(n) instead of 2).
##
## A real implementation could maybe save work by memoising merged
## results for internal subtrees... it's tempting to broadcast
## wildcard values to the individual keyed entries, but I think that
## might explode the time complexity of our pre-computation phase.
class NestedDictLevel:
"""One level in a nested dictionary index. We may have a value
for everything (leaf node), or a key-value dict *and a wildcard
entry* for a specific index in the tuple key.
"""
def __init__(self, depth):
self.depth = depth
self.value = None
self.wildcard = None
self.dict = dict()
def visit(self, keys, visitor):
"""Passes the values for `keys` to `visitor`."""
if self.value is not None:
assert self.wildcard is None and not self.dict
visitor(self.value)
if self.depth >= len(keys):
return
if self.wildcard is not None:
self.wildcard.visit(keys, visitor)
next = self.dict.get(keys[self.depth], None)
if next is not None:
next.visit(keys, visitor)
def set(self, keys, mergeFunction, depth=0):
"""Sets the value for `keys` in this level."""
assert depth <= len(keys)
if depth == len(keys): # Leaf
self.value = mergeFunction(self.value)
return
assert self.depth == depth
key = keys[depth]
if key is None:
if self.wildcard is None:
self.wildcard = NestedDictLevel(depth + 1)
dst = self.wildcard
else:
dst = self.dict.get(key)
if dst is None:
dst = NestedDictLevel(depth + 1)
self.dict[key] = dst
dst.set(keys, mergeFunction, depth + 1)
class NestedDict:
"""A nested dict of a given `depth` maps tuples of `depth` keys to
a value. Each `NestedDictLevel` handles a different level.
"""
def __init__(self, length):
self.top = NestedDictLevel(0)
self.length = length
def visit(self, keys, visitor):
"""Gets the value associated with `keys`, or `default` if None."""
assert len(keys) == self.length
assert all(key is not None for key in keys)
self.top.visit(keys, visitor)
def set(self, keys, mergeFn):
"""Sets the value associated with `((index, key), ...)`."""
assert len(keys) == self.length
self.top.set(keys, mergeFn)
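## A quick sanity check of the wildcard behaviour (throwaway example
## names, not used by the rest of the file): a `None` key matches any
## probe at that position, and both exact and wildcard entries are
## visited.
_example_nd = NestedDict(2)
_example_nd.set((1, None), lambda old: Sum(10))  # wildcard on the second key
_example_nd.set((1, 2), lambda old: Sum(5))  # `old` is None for fresh entries
_example_hits = []
_example_nd.visit((1, 2), lambda agg: _example_hits.append(agg.value))
assert sorted(_example_hits) == [5, 10]  # exact and wildcard entries both match
_example_hits.clear()
_example_nd.visit((1, 3), lambda agg: _example_hits.append(agg.value))
assert _example_hits == [10]  # only the wildcard entry matches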
##{<h2>Identity key-value maps</h2>
class IdMap:
def __init__(self):
        self.entries = dict()  # tuple of id -> (keys, value)
        # Storing the keys alongside the value keeps them alive, so their ids stay stable.
def get(self, keys, default=None):
ids = tuple(id(key) for key in keys)
return self.entries.get(ids, (None, default))[1]
def __contains__(self, keys):
ids = tuple(id(key) for key in keys)
return ids in self.entries
def __getitem__(self, keys):
ids = tuple(id(key) for key in keys)
return self.entries[ids][1]
def __setitem__(self, keys, value):
ids = tuple(id(key) for key in keys)
self.entries[ids] = (keys, value)
##}
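## Keying on `id()` means two sequences with equal contents are still
## distinct cache keys, so we never hash or compare the (potentially
## large) input data itself; a throwaway sketch:
_example_idmap = IdMap()
_rows_a, _rows_b = [1, 2, 3], [1, 2, 3]
_example_idmap[(_rows_a,)] = "precomputed for rows_a"
assert _example_idmap[(_rows_a,)] == "precomputed for rows_a"
assert (_rows_b,) not in _example_idmap  # equal contents, different identity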
## ## Cached `map_reduce`
##
## As mentioned earlier, we assume `reduce` is determined implicitly
## by the reduced values' type. We also have
## `enumerate_supporting_values` to find all the closed over values
## that yield a non-zero result, for all values in a sequence.
##
## We can thus accept a function and an input sequence, find the
## supporting values, and merge the result associated with identical
## supporting values.
##
## Again, we only support ground equality constraints (see assertion
## on L568), i.e., only equijoins. There's nothing that stops a more
## sophisticated implementation from using range trees to support
## inequality or range joins.
##
## We'll cache the precomputed values by code object (i.e., function
## without closed over values) and input sequence. If we don't have a
## precomputed value, we'll use `enumerate_supporting_values` to run
## the function backward for each input datum from the sequence, and
## accumulate the results in a `NestedDict`. Working backward to find
## closure values that yield a non-zero result (for each input datum)
## lets us precompute a branching program that directly yields the
## result. We represent these branching programs explicitly, so we
## can also directly update a branching program for the result of
## merging all the values returned by mapping over the input sequence,
## for a given closure.
##
## This last `map_reduce` definition ties everything together, and
## I think is really the general heart of Yannakakis's algorithm
## as an instance of bottom-up dynamic programming.
def _merge(dst, update):
if dst is None:
return update
if isinstance(dst, (tuple, list)):
assert len(dst) == len(update)
for value, new in zip(dst, update):
value.merge(new)
else:
dst.merge(update)
return dst
def _extractValues(accumulator):
if accumulator is None:
return None
if isinstance(accumulator, tuple):
return tuple(item.value for item in accumulator)
if isinstance(accumulator, list):
return list(item.value for item in accumulator)
return accumulator.value
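## `_merge` and `_extractValues` also handle mapped functions that return
## a tuple (or list) of aggregates, merged elementwise; a throwaway sketch:
_example_pair = _merge((Sum(1), Min(5)), (Sum(2), Min(3)))
assert _extractValues(_example_pair) == (3, 3)  # Sum(1 + 2), min(5, 3)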
def _precompute_map_reduce(function, depth, inputIterable):
"""Given a function (a closure), the number of values the function
closes over, and an input iterable, generates a `NestedDict`
representation for `reduce(map(function, inputIterable))`, where
the reduction step simply calls `merge` on the return values
(tuples are merged elementwise), and the `NestedDict` keys
represent closed over values.
>>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
>>> nd = _precompute_map_reduce(count_eql(4), 1, [1, 2, 4, 4, 2])
>>> nd.visit((0,), lambda sum: print(sum.value))
>>> nd.visit((1,), lambda sum: print(sum.value))
1
>>> nd.visit((2,), lambda sum: print(sum.value))
2
>>> nd.visit((4,), lambda sum: print(sum.value))
2
"""
cache = NestedDict(depth)
for indexKeyValues, result in enumerate_supporting_values(function, inputIterable):
cache.set(indexKeyValues, lambda old: _merge(old, result))
return cache
AGGREGATE_CACHE = IdMap() # Map from function, input sequence -> NestedDict
def map_reduce(function, inputIterable, initialValue=None, *, extractResult=True):
"""Returns the result of merging `map(function, inputIterable)`
into `initialValue`.
`None` return values represent neutral elements (i.e., the result
of mapping an empty `inputIterable`), and values are otherwise
reduced by calling `merge` on a mutable accumulator.
Assuming `function` is well-behaved, `map_reduce` runs in time
linear wrt `len(inputIterable)`. It's also always cached on a
composite key that consists of the `function`'s code object (i.e.,
without closed over values) and the `inputIterable`.
These complexity guarantees let us nest `map_reduce` with
different closed over values, and still guarantee a linear-time
total complexity.
    This wrapper ties together all the components:
>>> INVOCATION_COUNTER = 0
>>> data = (1, 2, 2, 4, 2, 4)
>>> def count_eql(needle):
... def count(x):
... global INVOCATION_COUNTER
... INVOCATION_COUNTER += 1
... return Sum(x) if x == needle else None
... return count
>>> INVOCATION_COUNTER
0
>>> map_reduce(count_eql(4), data, Sum(), extractResult=False).value
8
>>> INVOCATION_COUNTER
18
>>> map_reduce(count_eql(2), data)
6
>>> INVOCATION_COUNTER
18
>>> id_skus = [(1, 2), (2, 2), (1, 3)]
>>> sku_costs = [(1, 10), (2, 20), (3, 30)]
>>> def sku_min_cost(sku):
... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0))
>>> def sum_odd_or_even_skus(mod_two):
... def count_if_mod_two(id_sku):
... id, sku = id_sku
... if id % 2 == mod_two:
... return Sum(sku_min_cost(sku))
... return map_reduce(count_if_mod_two, id_skus, Sum())
>>> sum_odd_or_even_skus(0)
20
>>> sum_odd_or_even_skus(1)
50
"""
assert isinstance(inputIterable, collections.abc.Iterable)
assert not isinstance(inputIterable, collections.abc.Iterator)
code, closure, *_ = extract_function_state(function)
if (code, inputIterable) not in AGGREGATE_CACHE:
AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
function, len(closure), inputIterable
)
acc = [initialValue]
def visitor(result):
acc[0] = _merge(acc[0], result)
AGGREGATE_CACHE[code, inputIterable].visit(closure, visitor)
return _extractValues(acc[0]) if extractResult else acc[0]
map_reduce.over = \
lambda inputIterable, initialValue=None, *, extractResult=True: \
lambda fn: map_reduce(fn, inputIterable, initialValue, extractResult=extractResult)
if __name__ == "__main__":
import doctest
doctest.testmod()
## ## Is this actually a DB post?
##
## Although the intro name-dropped Yannakakis, the presentation here
## has a very programming language / logic programming flavour. I
## think the logic programming point of view, where we run a program
## backwards with logical variables, is much clearer than the specific
## case of conjunctive equijoin queries in the usual presentation of
## Yannakakis's algorithm. In particular, I think there's a clear
## path to handle range or comparison joins: it's all about having an
## index data structure to handle range queries.
##
## It should be clear how to write conjunctive queries as Python
## functions, given a hypertree decomposition. The reverse is much
## more complex, if only because Python is much more powerful than
## just CQ, and that's actually a liability: this hack will blindly
## try to convert any function to a branching program, instead of
## giving up noisily when the function is too complex.
##
## The other difference from classical CQs is that we focus on
## aggregates. That's because aggregates are the more general form:
## if we just want to avoid useless work while enumerating all join
## rows, we only need a boolean aggregate that tells us whether the
## join will yield at least one row. We could also special case types
## for which merges don't save space (e.g., set of row ids), and
## instead enumerate values by walking the branching program tree.
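## The "at least one row" aggregate mentioned above fits the same
## `merge` contract in a few lines (an illustrative sketch, not used
## elsewhere in this file):
class Exists:
    """Tracks whether at least one matching row was seen."""
    def __init__(self, value=False):
        self.value = bool(value)
    def merge(self, other):
        assert isinstance(other, Exists)
        self.value = self.value or other.value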
##
## The aggregate viewpoint also works for
## [fun extensions like indexed access to ranked results](https://ntzia.github.io/download/Tractable_Orders_2020.pdf):
## that extension ends up counting the number of output values up to a
## certain key.
##
## I guess, in a way, we just showed a trivial way to decorrelate
## queries with a hypertree-width of 1. We just have to be OK with
## building one index for each loop in the nest... but it should be
## possible to pattern match on pre-defined indexes and avoid obvious
## redundancy.
##
## ## Extensions and future work
##
## ### Use a dedicated DSL
##
## First, the whole idea of introspecting closures to stub in logical
## variables is a terrible hack (looks cool though ;). A real
## production implementation should apply CPS partial evaluation to a
## purely functional programming language, then bulk reverse-evaluate
## with a SIMD implementation of the logical program.
##
## There'll be restrictions on the output traces, but that's OK: a
## different prototype makes me believe the restrictions correspond to
## deterministic logspace (L), and it makes sense to restrict our
## analyses to L. Just like grammars are easier to work with when
## restricted to LL(1), DSLs that only capture L tend to be easier to
## analyse and optimise... and L is reasonably large (a
## polynomial-time algorithm that's not in L would be a *huge*
## result).
##
## ### Handle local functions
##
## While we sin with the closure hack (`extract_function_state`), it
## should really be extended to cover local functions. This is mostly
## a question of going deeply into values that are mapped to
## functions, and of maintaining an id-keyed map from cell to
## `OpaqueValue`.
##
## We could also add support for partial application objects, which
## may be easier for multiprocessing.
##
## ### Parallelism
##
## There is currently no support for parallelism, only caching. It
## should be easy to handle the return values (`NestedDict`s and
## aggregate classes like `Sum` or `Min`). Distributing the work in
## `_precompute_map_reduce` to merge locally is also not hard.
##
## The main issue with parallelism is that we can't pass functions
## as work units, so we'd have to stick to the `fork` process pool.
##
## There's also no support for moving (child) work forward when
## blocked waiting on a future. We'd have to spawn workers on the fly
## to oversubscribe when workers are blocked on a result (spawning on
## demand is already a given for `fork` workers), and to implement our
## own concurrency control to avoid wasted work, and probably internal
## throttling to avoid thrashing when we'd have more active threads
## than cores.
##
## That being said, the complexity is probably worth the speed up on
## realistic queries.
##
## ### Theta joins
##
## At a higher level, we could support comparison joins (e.g., less
## than, greater than or equal, in range) if only we represented the
## branching programs with a data structure that supported these
## queries. A [range tree](https://dl.acm.org/doi/10.1145/356789.356797) would
## let us handle these "theta" joins, for the low low cost of a
## polylogarithmic multiplicative factor in space and time.
##
## ### Self-adjusting computation
##
## Finally, we could update the indexed branching programs
## incrementally after small changes to the input data. This might
## sound like a job for streaming engines like [timely dataflow](https://github.com/timelydataflow/timely-dataflow),
## but I think viewing each `_precompute_map_reduce` call as a purely
## functional map/reduce job gives a better fit with [self-adjusting computation](https://www.umut-acar.org/research#h.x3l3dlvx3g5f).
##
## Once we add logic to recycle previously constructed indexes, it
## will probably make sense to allow an initial filtering step before
## map/reduce, with a cache key on the filter function (with closed
## over values and all). We can often implement the filtering more
## efficiently than we can run functions backward, and we'll also
## observe that slightly different filter functions often result
## in not too dissimilar filtered sets. Factoring out this filtering
## can thus enable more reuse of partial precomputed results.