Working around operations with data-dependent shapes in MLX

Ops with Data-Dependent Shapes

This is a short article on a common type of not-yet-supported operation in MLX: ops where the output shape depends on the input data. Here's an outline:

  1. An introduction to these operations, followed by an explanation of why they are challenging to implement efficiently.
  2. A discussion of when and how to work around these missing operations, with a couple of examples.

The Ops

A common question about MLX is why it doesn't support boolean indices.

What's a boolean index? In frameworks like NumPy and PyTorch, you can index an array with a boolean mask array. Here's an example:

import numpy as np

a = np.arange(4)
idx = np.array([True, False, False, True])

# Prints: array([0, 3])
print(a[idx])

The challenge with this operation is that the output shape depends on the input data. Why is that a challenge? Most accelerators, like Nvidia or Apple GPUs, use an asynchronous programming model. You schedule operations to run without waiting for previous operations to finish.

Imagine idx were the output of another computation. In that case, when the operation a[idx] is ready to be scheduled, idx is not yet safe to inspect, so we don't know how big the output a[idx] will be. That in turn means we can't easily allocate memory for the output, determine the number of GPU threads to launch in the kernel, and so on.

We could always synchronize with the accelerator before scheduling the indexing operation, but synchronizations are expensive. Rather than implicitly eating this cost, MLX does not yet support these operations.
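As an aside, you can see this scheduling model reflected in MLX's lazy evaluation: operations build a graph immediately and the actual work runs later. A minimal sketch of the lazy model:

import mlx.core as mx

a = mx.ones((4,))
b = a * 2    # returns immediately; no computation has run yet
mx.eval(b)   # forces the pending work to be scheduled and completed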

Boolean masks are not the only example of operations for which the output shape depends on the input data. Other operations include np.nonzero and the unary version of np.where.
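For instance, both of the following calls return outputs whose length depends on the values in a, not just its shape:

import numpy as np

a = np.array([0, 3, 0, 5])

# Prints (array([1, 3]),) twice: the unary np.where is equivalent
# to np.nonzero, and both outputs grow with the nonzero count.
print(np.nonzero(a))
print(np.where(a))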

JAX has a nice compromise for these types of operations: it adds an explicit size parameter to, for example, jax.numpy.nonzero and the unary jax.numpy.where. The size parameter specifies the maximum size of the output. Outputs larger than size are truncated, and outputs smaller than size are padded with a user-specified fill value.
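Here's a minimal sketch of how that looks (the input values are made up for illustration):

import jax.numpy as jnp

a = jnp.array([0, 3, 0, 5])

# Two nonzero entries; the output is padded with fill_value up to
# size=3. Prints something like: (Array([1, 3, 0], dtype=int32),)
print(jnp.nonzero(a, size=3, fill_value=0))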

Work-around

So if these operations are not supported, how do you express the same computation with existing ops? Is it even possible?

The answer to the second question is pretty straightforward. If the final output of the computation has a fixed shape that is independent of the input data, then it should be possible to implement the computation with existing operations. Exactly how to do this can be tricky. It's a bit like vectorizing code: it often takes some careful thought, but with practice it gets easier.

Let's look at a simple example. Suppose you want to count the number of nonzero elements in one array at the indices where another array is exactly one. Here's one way:

counts = b[a == 1].sum()

That will work in NumPy or PyTorch (or non-compiled JAX). Notice that the final output shape, the shape of counts, is independent of the number of ones in a. The value counts is a scalar, so it has an empty shape. That means we can compute it with a fixed-shape sequence of operations. In this case it's simple: we can use the ternary mx.where.

counts = mx.where(a == 1, b, 0).sum()
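To make this concrete, here's a self-contained version of the same computation with some hypothetical inputs:

import mlx.core as mx

a = mx.array([1, 2, 1, 3])
b = mx.array([0, 5, 7, 0])

# Keep b where a == 1, substitute 0 elsewhere, then sum. Every
# intermediate has a fixed shape, so nothing has to synchronize.
counts = mx.where(a == 1, b, 0).sum()
print(counts)  # counts evaluates to 7 for these inputs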

Here's a slightly more involved example. Say you want to compute the negative entropy $\sum p_i \log p_i$. To do this in a numerically safe way, you need to avoid computing $\log p_i$ when $p_i = 0$. Here's one way:

p_nz = p[np.nonzero(p)]
entropy = np.sum(p_nz * np.log(p_nz))

Again, notice the final output has a fixed size: it's a scalar. To compute this in a numerically safe way in MLX you can instead do:

mask = p == 0
entropy = mx.sum(mx.where(mask, 0, p * mx.log(p)))
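As a runnable sketch with a hypothetical distribution p:

import mlx.core as mx

p = mx.array([0.5, 0.0, 0.25, 0.25])

# Where p is zero, select 0 instead of p * log(p), so the log of
# zero never reaches the output. The result has a fixed (scalar)
# shape regardless of how many zeros p contains.
mask = p == 0
entropy = mx.sum(mx.where(mask, 0, p * mx.log(p)))
print(entropy)  # approximately -1.0397 for these values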

You might notice a pattern here: the ternary mx.where, usually combined with a mask of some sort, does most of the work when working around these missing operations.

What if the final output actually has a data-dependent shape? In that case you may really need an operation like boolean masking, np.nonzero, or the unary np.where. One option is to convert to NumPy, do the data-dependent step there, and convert back to MLX if needed. In the future, MLX may provide size hints like JAX, or possibly allow these types of operations with implicit synchronizations like PyTorch. For now, here are a couple of useful guidelines:

  • Synchronize (e.g. convert to NumPy) as little as possible. For example, this loop, which converts on every iteration:

for a in b:
    c = np.array(a)
    ...

is worse than this one, which converts only once:

for c in np.array(b):
    ...
  • Do as much in MLX as you can. Try to delay conversions to NumPy to the very end of the computation, or use NumPy only at the beginning and then convert to MLX for the remainder, as in the sketch below.
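For example, here is a hypothetical helper (nonzero_indices is a made-up name) that keeps the fixed-shape work in MLX and pays for a single conversion at the data-dependent step:

import mlx.core as mx
import numpy as np

def nonzero_indices(x):
    # Fixed-shape work stays in MLX.
    mask = x != 0
    # One synchronizing conversion for the data-dependent step
    # (assumes x is 1-D for simplicity).
    idx = np.nonzero(np.array(mask))[0]
    # Back to MLX so downstream ops stay on the accelerator.
    return mx.array(idx)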