@sklam
Created October 23, 2023 22:10
LoopNest, Shape, ndim inference
@saulshanabrook

saulshanabrook commented Feb 8, 2024

I am coming back to this after finishing some refactoring work on egglog, and I wanted to check that I understand things properly before trying to implement the features this needs.

Thank you again for posting this example! I am excited about getting this sort of thing to work and it's essential to have real use cases like this to drive the development...


Let's say we implement an API like the one you proposed on top of the existing array API bindings:

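# Deferred annotation evaluation, so the classes below can name themselves and each other
from __future__ import annotations

from typing import Callable, ClassVar, Iterator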

import numpy as np

import egglog.exp.array_api as enp
from egglog import Expr, convert, rewrite, ruleset


class ShapeAPI(Expr):
    def __init__(self, dims: enp.TupleInt) -> None:
        ...

    def deselect(self, axis: enp.TupleInt) -> ShapeAPI:
        ...

    def select(self, axis: enp.TupleInt) -> ShapeAPI:
        ...

    def to_tuple(self) -> enp.TupleInt:
        ...

class OptionalLoopNestAPI(Expr):
    def __init__(self, value: LoopNestAPI) -> None:
        ...

    NONE: ClassVar[OptionalLoopNestAPI]

    def unwrap(self) -> LoopNestAPI:
        ...


class LoopNestAPI(Expr):
    def __init__(self, dim: enp.Int, inner: OptionalLoopNestAPI = OptionalLoopNestAPI.NONE) -> None:
        ...

    @classmethod
    def from_tuple(cls, args: enp.TupleInt) -> OptionalLoopNestAPI:
        ...

    def __iter__(self) -> Iterator[enp.TupleInt]:
        return iter(self.indices)

    @property
    def indices(self) -> enp.TupleTupleInt:
        ...

    def get_dims(self) -> enp.TupleInt:
        ...

    def reduce(self, fn: Callable[[enp.NDArray, enp.TupleInt], enp.NDArray], init: enp.NDArray) -> enp.NDArray:
        ...

With that, a user could write a loopnest function with it:

def linalg_norm_loopnest_egglog(X: enp.NDArray, axis: enp.TupleInt) -> enp.NDArray:
    # peel off the outer shape for result array
    outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
    # get only the inner shape for reduction
    reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()

    return enp.NDArray.from_fn(
        outshape,
        X.dtype,
        lambda k: enp.sqrt(
            LoopNestAPI.from_tuple(reduce_axis)
            .unwrap()
            .reduce(lambda carry, i: carry + enp.real(enp.conj(x := X[i + k]) * x), init=0.0)
        ).to_value(),
    )

By applying some rewrite rules (included below in a dropdown), we could translate away the custom APIs and end up with something like this expression:

def linalg_norm_array_api(X: enp.NDArray, axis: enp.TupleInt) -> enp.NDArray:
    outdim = enp.range_(X.ndim).filter(lambda x: ~axis.contains(x))
    outshape = convert(convert(X.shape, enp.NDArray)[outdim], enp.TupleInt)
    row_axis, col_axis = axis
    return enp.NDArray.from_fn(
        outshape,
        X.dtype,
        lambda k: enp.sqrt(
            enp.int_product(enp.range_(X.shape[row_axis]), enp.range_(X.shape[col_axis]))
            .map_to_ndarray(lambda rc: enp.real(enp.conj(x := X[rc + k]) * x))
            .sum()
        ).to_value(),
    )

Then, using the existing source generation, we could create this Python source from that function:

def linalg_norm_low_level(
    X: np.ndarray[tuple, np.dtype[np.float64]], axis: tuple[int, int]
) -> np.ndarray[tuple, np.dtype[np.float64]]:
    # Requires X.ndim >= 3 and axis to be a 2-tuple
    assert X.ndim >= 3
    assert len(axis) == 2
    # The output has X.ndim - 2 dimensions
    outdim = [dim for dim in range(X.ndim) if dim not in axis]

    outshape = tuple(np.asarray(X.shape)[outdim])

    res = np.zeros(outshape, dtype=X.dtype)
    row_axis, col_axis = axis
    for k in np.ndindex(outshape):
        tmp = 0.0
        for row in range(X.shape[row_axis]):
            for col in range(X.shape[col_axis]):
                # Assumes axis == (0, 1), matching X[rc + k] above
                idx = (row, col, *k)
                x = X[idx]
                tmp += (x.conj() * x).real
        res[k] = np.sqrt(tmp)
    return res
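
As a sanity check on the target semantics, here is a minimal sketch of the same loop nest specialized to axis == (0, 1) and compared against np.linalg.norm, which computes the Frobenius norm when given a 2-tuple axis. The helper name norm_leading_two_axes is hypothetical, and the sketch always returns float64 rather than X.dtype:

```python
import numpy as np


def norm_leading_two_axes(X: np.ndarray) -> np.ndarray:
    """Frobenius norm over axes (0, 1), written as an explicit loop nest.

    Hypothetical helper for illustration; assumes X.ndim >= 3.
    """
    outshape = X.shape[2:]
    res = np.zeros(outshape, dtype=np.float64)
    for k in np.ndindex(outshape):
        tmp = 0.0
        for row in range(X.shape[0]):
            for col in range(X.shape[1]):
                x = X[(row, col, *k)]
                # |x|^2, valid for real and complex inputs
                tmp += (x.conj() * x).real
        res[k] = np.sqrt(tmp)
    return res


rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4, 2, 5))

result = norm_leading_two_axes(X)
expected = np.linalg.norm(X, axis=(0, 1))
matches = bool(np.allclose(result, expected))
```

If matches is True, the loop nest agrees with NumPy for this input. It says nothing about other axis choices, since the index is built as (row, col, *k).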

Does that seem about right?

Implementation

I haven't tried any of this code, and implementing it would require adding functions as values to egglog. I believe this could be done with no (or minimal) changes to the core, but before starting on it, I wanted to check whether this use case is accurate enough to give it a go.

Implementing generic functions and types would also go a long way toward making the code more fluid, but the examples here assume we don't have that feature yet. It should all still be possible without it, just with more repetition.

I believe I could add them as a Python pre-processor, and this use case would also be a good model for showing how they help with readability.


Rewrite rules:
@ruleset
def shape_api_ruleset(dims: enp.TupleInt, axis: enp.TupleInt):  # noqa: ANN201
    yield rewrite(ShapeAPI(dims).deselect(axis)).to(
        ShapeAPI(enp.range_(dims.length()).filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i]))
    )
    yield rewrite(ShapeAPI(dims).select(axis)).to(
        ShapeAPI(enp.range_(dims.length()).filter(lambda i: axis.contains(i)).map(lambda i: dims[i]))
    )
    yield rewrite(ShapeAPI(dims).to_tuple()).to(dims)


@ruleset
def loopnest_api_ruleset(
    head: enp.Int,
    tail: enp.TupleInt,
    lna: LoopNestAPI,
    fn: Callable[[enp.NDArray, enp.TupleInt], enp.NDArray],
    init: enp.NDArray,
    dim: enp.Int,
):
    # from_tuple
    yield rewrite(LoopNestAPI.from_tuple(enp.TupleInt.EMPTY)).to(OptionalLoopNestAPI.NONE)
    yield rewrite(
        LoopNestAPI.from_tuple(enp.TupleInt.some(head, tail)),
    ).to(
        OptionalLoopNestAPI(LoopNestAPI(head, LoopNestAPI.from_tuple(tail))),
    )
    # reduce
    yield rewrite(lna.reduce(fn, init)).to(lna.indices.reduce_ndarray(fn, init))
    # get_dims
    yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI.NONE).get_dims()).to(enp.TupleInt(dim))
    yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI(lna)).get_dims()).to(enp.TupleInt(dim) + lna.get_dims())
    # indices
    yield rewrite(lna.indices).to(lna.get_dims().map_tuple_int(enp.range_).product())
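
To make the intent of these rules concrete, here is a rough plain-Python model of what they are meant to compute once everything is a concrete tuple. This is only a sketch: the function names are hypothetical stand-ins for the egglog methods, and nothing here uses egglog itself.

```python
from itertools import product


def deselect(dims, axis):
    # Mirrors the deselect rule: keep dims whose position is NOT in axis
    return tuple(dims[i] for i in range(len(dims)) if i not in axis)


def select(dims, axis):
    # Mirrors the select rule: keep dims whose position IS in axis,
    # in the order of dims (not the order of axis)
    return tuple(dims[i] for i in range(len(dims)) if i in axis)


def loop_indices(dims):
    # Mirrors the indices rule: map each dim to a range, then take the product
    return list(product(*(range(d) for d in dims)))


def reduce_loopnest(dims, fn, init):
    # Mirrors the reduce rule: fold fn over every index tuple of the nest
    carry = init
    for idx in loop_indices(dims):
        carry = fn(carry, idx)
    return carry


shape = (3, 4, 2, 5)
axis = (0, 1)

outshape = deselect(shape, axis)    # (2, 5)
reduce_dims = select(shape, axis)   # (3, 4)
small_nest = loop_indices((2, 2))   # [(0, 0), (0, 1), (1, 0), (1, 1)]
count = reduce_loopnest(reduce_dims, lambda carry, idx: carry + 1, 0)  # 12
```

The fold visits the indices in row-major order, the same order the nested for loops in the generated source would.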
