Skip to content

Instantly share code, notes, and snippets.

@amifalk
Created August 4, 2024 19:30
Show Gist options
  • Save amifalk/e21059da7f0c0ecb3db8240604413998 to your computer and use it in GitHub Desktop.
Save amifalk/e21059da7f0c0ecb3db8240604413998 to your computer and use it in GitHub Desktop.
from typing import Callable, Sequence
from penzai import pz
# This prototype version of map requires that `f`
# 1. is already nmapped
# 2. takes one named array argument without positional axes and
# 3. returns a named array without any positional axes.
# The `batch_axes` of `map` must not be modified by `f`.
def map(
    f: Callable,
    x: pz.nx.NamedArray,
    batch_axes: Sequence[pz.nx.AxisName],
    batch_size: int,
):
    """Apply ``f`` over ``batch_axes`` of ``x`` in sequential chunks.

    The axes named in ``batch_axes`` are flattened into one temporary axis.
    The longest prefix whose length is a multiple of ``batch_size`` is split
    into chunks of exactly ``batch_size`` elements, and ``f`` is applied to one
    chunk per ``pz.nx.scan`` step. Any leftover elements (fewer than
    ``batch_size``) are handled by one extra direct call to ``f``. The pieces
    are joined in the original element order and reshaped back to
    ``batch_axes``.

    NOTE: this function shadows the builtin ``map``; the name is kept so
    existing callers keep working.

    Args:
        f: Function taking one named array without positional axes and
            returning a named array without positional axes. It must not
            modify the temporary axis that carries the flattened batch
            elements.
        x: Named array (no positional axes) containing every axis listed in
            ``batch_axes``.
        batch_axes: Names of the axes of ``x`` to batch over.
        batch_size: Number of flattened batch elements passed to ``f`` per
            scan step.

    Returns:
        The result of ``f`` with ``batch_axes`` restored to their original
        sizes and element order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # `flat_axis` names the flattened batch elements (this is the axis `f`
    # sees in both the scanned and the remainder call); `scan_axis` indexes
    # the chunks that the scan iterates over.
    flat_axis = pz.nx.TmpPosAxisMarker()
    scan_axis = pz.nx.TmpPosAxisMarker()
    batch_axis_sizes = tuple(x.named_shape[axis] for axis in batch_axes)

    flat_x = x.untag(*batch_axes).flatten().tag(flat_axis)
    num_elems = flat_x.named_shape[flat_axis]
    num_batches = num_elems // batch_size
    num_scanned = num_batches * batch_size

    pieces = []
    if num_batches > 0:
        scanned_x = flat_x[{flat_axis: pz.slice[:num_scanned]}]
        # Row-major reshape to (num_batches, batch_size): row i holds the
        # contiguous elements [i * batch_size, (i + 1) * batch_size).
        batched_x = (
            scanned_x.untag(flat_axis)
            .reshape(num_batches, batch_size)
            .tag(scan_axis, flat_axis)
        )
        _, batched_y = pz.nx.scan(
            f=lambda _, chunk: ((), f(chunk)),
            axis=scan_axis,
            init=(),
            xs=batched_x,
        )
        # Untag in the same (scan, flat) order used when tagging so that the
        # flattening below restores the original element order.
        pieces.append(
            batched_y.untag(scan_axis, flat_axis).reshape(-1).tag(flat_axis)
        )

    # Handle the leftover elements directly. When there is nothing to scan at
    # all (including an empty input) this single call produces the output.
    if num_scanned < num_elems or not pieces:
        pieces.append(f(flat_x[{flat_axis: pz.slice[num_scanned:]}]))

    if len(pieces) == 1:
        flat_y = pieces[0]
    else:
        flat_y = pz.nx.concatenate(pieces, axis_name=flat_axis)

    return flat_y.untag(flat_axis).reshape(batch_axis_sizes).tag(*batch_axes)
# Example input with named axes a=10, b=3, c=6.
# NOTE(review): presumably the product broadcasts so that values vary only
# along "c" — if so, this input cannot reveal element reordering along "a"/"b".
arr = pz.nx.ones({"a": 10, "b": 3, "c": 6}) * pz.nx.arange("c", 6)
def my_func(arr):
    """Return ``arr`` with 2 added to it."""
    shifted = arr + 2
    return shifted
# Batch over "a" and "b": 10 * 3 = 30 flattened elements with batch_size=7 is
# not an exact multiple, so both the scanned part (28 elements) and the
# remainder (2 elements) of `map` are exercised. The result is not stored.
map(my_func, arr, batch_axes=["a", "b"], batch_size=7)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment