This is a short article on a common type of not-yet-supported operation in MLX: ops where the output shape depends on the input data. Here's an outline:
- An introduction to these operations, followed by an explanation of why they are challenging to implement efficiently.
- A discussion of when and how to work around these missing operations, with a couple of examples.
A common question about MLX is why it doesn't support boolean indices.
What's a boolean index? In frameworks like NumPy and PyTorch, you can index an array with a boolean mask array. Here's an example:
import numpy as np

a = np.arange(4)
idx = np.array([True, False, False, True])
# Prints: array([0, 3])
print(a[idx])
The challenge with this operation is that the output shape depends on the input data. Why is that a challenge? Most accelerators, like Nvidia or Apple GPUs, use an asynchronous programming model. You schedule operations to run without waiting for previous operations to finish.
Imagine idx was the output of another computation. In that case, when the operation a[idx] is ready to be scheduled, idx is not yet safe to inspect. That means we don't know how big the output a[idx] is, which in turn means we can't easily allocate memory for the output, determine the number of GPU threads to use in the kernel, and so on.
We could always synchronize with the accelerator before scheduling the indexing operation, but synchronizations are expensive. Rather than implicitly eating this cost, MLX does not yet support these operations.
Boolean masks are not the only operations for which the output shape depends on the input data. Other examples include np.nonzero and the unary version of np.where.
JAX has a nice compromise for these types of operations. It adds an explicit size parameter to, for example, jax.numpy.nonzero and the unary jax.numpy.where. The size parameter specifies the maximum size of the output. Any outputs bigger than the maximum size are truncated, and any outputs smaller are padded with a user-specified value.
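To make those semantics concrete, here is a small NumPy sketch. The helper nonzero_with_size is hypothetical (it is not part of NumPy or JAX), but it truncates and pads the way JAX's size and fill_value parameters do:

```python
import numpy as np

def nonzero_with_size(a, size, fill_value=0):
    # Hypothetical helper emulating jax.numpy.nonzero(a, size=..., fill_value=...):
    # keep at most `size` indices, pad shorter results with `fill_value`.
    idx = np.nonzero(a)[0]
    if len(idx) >= size:
        return idx[:size]
    pad = np.full(size - len(idx), fill_value, dtype=idx.dtype)
    return np.concatenate([idx, pad])

x = np.array([0, 3, 0, 5, 7])
print(nonzero_with_size(x, size=2))  # [1 3]       (truncated)
print(nonzero_with_size(x, size=5))  # [1 3 4 0 0] (padded)
```

Either way, the caller gets an output with a shape known before the data is, which is exactly what an asynchronous scheduler needs.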
So if these operations are not supported, how do you express the same computation with existing ops? Is it even possible?
The answer to the second question is pretty straightforward. If the final output of the computation has a fixed shape that is independent of the input data, then it should be possible to implement the computation with existing operations. Exactly how to do so can be tricky. It's a bit like vectorizing code: it often takes some careful thought, but with practice it gets easier.
Let's look at a simple example. Suppose you want to count the number of nonzero elements in one array at the indices where another array is exactly one. Here's one way:
counts = b[a == 1].sum()
That will work in NumPy or PyTorch (or non-compiled JAX). Notice that the final output shape, the shape of counts, is independent of the number of ones in a. The value counts is a scalar, so it has an empty shape. That means we can compute it with a fixed-shape sequence of operations. In this case it's simple: we can use the ternary mx.where.
counts = mx.where(a == 1, b, 0).sum()
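The ternary mx.where mirrors NumPy's np.where, so you can sanity-check the rewrite in NumPy. A minimal sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(0, 3, size=10)
b = rng.integers(0, 5, size=10)

# Data-dependent intermediate shape: b[a == 1] depends on the values in a.
masked = b[a == 1].sum()

# Fixed-shape equivalent: every intermediate has the same shape as a and b.
fixed = np.where(a == 1, b, 0).sum()

print(masked == fixed)  # True
```

The two versions always agree, but only the second one keeps every intermediate shape independent of the data.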
Here's a slightly more involved example. Say you want to compute the negative entropy:
p_nz = p[np.nonzero(p)]
entropy = np.sum(p_nz * np.log(p_nz))
Again, notice the final output has a fixed shape: it's a scalar. To compute this in a numerically safe way in MLX you can instead do:
mask = p == 0
entropy = mx.sum(mx.where(mask, 0, p * mx.log(p)))
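You can check in NumPy that this masked, fixed-shape version matches the data-dependent one. In the sketch below, np.errstate merely silences the harmless log-of-zero warning for the entries the mask discards:

```python
import numpy as np

p = np.array([0.5, 0.0, 0.25, 0.25])

# Data-dependent version: the shape of p_nz depends on how many zeros p has.
p_nz = p[np.nonzero(p)]
neg_entropy = np.sum(p_nz * np.log(p_nz))

# Fixed-shape version, mirroring the mx.where code above.
mask = p == 0
with np.errstate(divide="ignore", invalid="ignore"):
    neg_entropy_fixed = np.sum(np.where(mask, 0.0, p * np.log(p)))

print(np.isclose(neg_entropy, neg_entropy_fixed))  # True
```

The unselected branch still evaluates p * np.log(p) at the zeros (producing NaN), but the where never propagates those values to the output.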
You might be noticing a pattern here: the ternary version of mx.where, often combined with a mask of some sort, is usually what's needed to work around the missing operations.
What if the final output actually has a data-dependent shape? In that case you may genuinely need an operation like boolean masking, np.nonzero, or the unary np.where. One option is then to convert to NumPy and back to MLX as needed. In the future, MLX may provide size hints like JAX or possibly allow these types of operations with implicit synchronizations like PyTorch. For now, here are a couple of useful guidelines:
- Synchronize (e.g. convert to NumPy) as little as possible. For example:
for a in b:
    c = np.array(a)
    ...
is worse than:
for c in np.array(b):
    ...
- Do as much in MLX as you can. Try to delay conversions to NumPy to the very end of the computation. Or use NumPy only at the beginning and then convert to MLX for the remainder.