Skip to content

Instantly share code, notes, and snippets.

@Edenhofer
Last active January 12, 2024 15:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Edenhofer/34207ad5b2b60e564e21bde9c350efd6 to your computer and use it in GitHub Desktop.
Save Edenhofer/34207ad5b2b60e564e21bde9c350efd6 to your computer and use it in GitHub Desktop.
Stupid/sequential map
# Copyright(C) 2022-2023 Gordian Edenhofer
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
from functools import partial
import jax
from jax import lax
from jax import numpy as jnp
def _int_or_none(x):
return isinstance(x, int) or x is None
def _fun_reord(_, mapped, *, fun, unmapped, unflatten, in_axes):
un, mapped = list(unmapped), list(mapped)
assert len(un) + len(mapped) == len(in_axes)
args = tuple(un.pop(0) if a is None else mapped.pop(0) for a in in_axes)
y = fun(*unflatten(args))
return None, y
def _moveaxis(a, source, destination):
# Ensure that arrays are never completely unnecessarily copied
if source == destination:
return a
return jnp.moveaxis(a, source, destination)
def _generic_smap(fun, in_axes, out_axes, unroll, *x, _scan=lax.scan, **k):
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
if k:
raise TypeError("keyword arguments are not allowed in map")
if isinstance(in_axes, int):
in_axes = tree_map(lambda _: in_axes, x)
elif isinstance(in_axes, tuple):
if len(in_axes) != len(x):
ve = f"`in_axes` {in_axes!r} and input {x!r} must be of same length"
raise ValueError(ve)
new_inax = []
for el, i in zip(x, in_axes):
new_inax.append(tree_map(lambda _: i, el) if _int_or_none(i) else i)
in_axes = tuple(new_inax)
else:
te = f"`in_axes` must be an int or tuple of pytrees/int; got {in_axes!r}"
raise TypeError(te)
x, x_td = tree_flatten(x)
in_axes, in_axes_td = tree_flatten(in_axes, is_leaf=_int_or_none)
if in_axes_td != x_td:
ve = f"`in_axes` {in_axes_td!r} incompatible with input `*x` {x_td!r}"
raise ValueError(ve)
unmapped = []
mapped = []
for i, el in zip(in_axes, x):
if i is None:
unmapped.append(el)
elif isinstance(i, int):
mapped.append(_moveaxis(el, i, 0))
else:
raise TypeError(f"expected `in_axes` index of type int; got {i!r}")
fun_reord = partial(
_fun_reord,
fun=fun,
unmapped=unmapped,
unflatten=partial(tree_unflatten, x_td),
in_axes=in_axes
)
_, y = _scan(fun_reord, None, mapped, unroll=unroll)
if out_axes is None:
out_axes, out_axes_td = tree_flatten(out_axes)
if isinstance(out_axes, int):
out_axes = tree_map(lambda el: out_axes if el is not None else el, y)
out_axes, out_axes_td = tree_flatten(out_axes)
else:
out_axes, out_axes_td = tree_flatten(out_axes, is_leaf=_int_or_none)
y, y_td = tree_flatten(y)
if y is not None and out_axes_td != y_td:
ve = f"`out_axes` {out_axes_td!r} incompatible with output {y_td!r}"
raise ValueError(ve)
out = []
for i, el in zip(out_axes, y):
if i is None:
out.append(unmapped.pop(0))
elif isinstance(i, int):
out.append(_moveaxis(el, 0, i))
else:
raise TypeError(f"expected `out_axes` index of type int; got {i!r}")
return tree_unflatten(y_td, out)
# The function over which to `scan` depends on the data. This leads to
# unnecessary recompiles. Ensure scan is compiled only once by compiling the
# whole data dependence.
_smap = jax.jit(
_generic_smap,
static_argnames=("fun", "in_axes", "out_axes", "unroll", "_scan")
)
def smap(fun, in_axes=0, out_axes=0, *, unroll=1):
"""Stupid/sequential map.
Many of JAX's control flow logic reduces to a simple `jax.lax.scan`. This
function is one of these. In contrast to `jax.lax.map` or
`jax.lax.fori_loop`, it behaves much like `jax.vmap`. In fact, it
re-implements `in_axes` and `out_axes` and can be used in much the same way
as `jax.vmap`. However, instead of batching the input, it works through it
sequentially.
This implementation makes no claim on being efficient. It explicitly swaps
around axis in the input and output, potentially allocating more memory
than strictly necessary and worsening the memory layout.
For the semantics of `in_axes` and `out_axes` see `jax.vmap`. For the
semantics of `unroll` see `jax.lax.scan`.
"""
return partial(_smap, fun, in_axes, out_axes, unroll)
@partial(jax.jit, donate_argnames=("x", ))
def _unsafe_index_update_inplace(x, idx, y):
return x.at[idx].set(y)
def _lscan(f, init, xs, length=None, unroll=1):
if unroll != 1:
raise NotImplementedError()
if xs is None:
xs = [None] * length
carry = init
ys = None
first_leave = jax.tree_util.tree_leaves(xs)[0]
like_device = first_leave.device(
) if hasattr(first_leave, "device") else None
length = first_leave.shape[0] if length is None else length
for i in range(length):
carry, y = f(carry, jax.tree_map(lambda x: x[i], xs))
if ys is None:
# NOTE, `empty_like` will always allocate on the primary device even
# if `y` is on a different device. Forcefully allocate on the same
# device as `y`.
with jax.default_device(like_device):
ys = jax.tree_map(
lambda x: jnp.
empty_like(x, shape=(length, ) + jnp.shape(x)), y
)
ys = jax.tree_map(
lambda ys, y: _unsafe_index_update_inplace(ys, i, y), ys, y
)
return carry, ys
def lmap(fun, in_axes=0, out_axes=0):
return partial(_generic_smap, fun, in_axes, out_axes, 1, _scan=_lscan)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment