# Source: GitHub Gist by @KeAWang (KeAWang/f420ba439a012d969b04211a42f6c9de),
# forked from willwhitney/tree_stack.py; last active June 5, 2023.
#
# Utilities for stacking and unstacking JAX pytrees to work with vmap.
import jax
import jax.numpy as jnp
def tree_stack(trees):
    """Stack every corresponding leaf of a list of identically-structured pytrees.

    For example, given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')).
    Useful for turning a list of objects into something you can feed to a
    vmapped function.

    Args:
        trees: a non-empty list of pytrees that all share one tree structure.

    Returns:
        A single pytree with that shared structure whose leaves are
        ``jnp.stack`` of the corresponding input leaves (a new leading axis of
        size ``len(trees)``). Trees with no leaves (e.g. ``[None, None]``)
        unflatten back to the leafless structure (``None`` in that example).

    Raises:
        AssertionError: if ``trees`` is not a list, or the trees do not all
            have the same structure.
        ValueError: if ``trees`` is empty (there is no structure to stack into).
    """
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``;
    # AssertionError is kept because existing callers catch that type.
    if not isinstance(trees, list):
        raise AssertionError("trees must be a list of pytrees")
    if not trees:
        raise ValueError("tree_stack needs at least one pytree to stack")
    # ``jax.tree_flatten`` was deprecated and later removed from the top-level
    # namespace; the ``jax.tree_util`` spelling works across JAX versions.
    leaves_list, treedef_list = zip(*map(jax.tree_util.tree_flatten, trees))
    if len(set(treedef_list)) != 1:
        raise AssertionError("all pytrees must be the same")
    # zip(*leaves_list) groups the i-th leaf of every tree together.
    result_leaves = [jnp.stack(group) for group in zip(*leaves_list)]
    return treedef_list[0].unflatten(result_leaves)
def tree_unstack(tree):
    """Split a pytree with batched leaves into a list of pytrees. Inverse of tree_stack.

    For example, given a tree ((a, b), c), where a, b, and c all have first
    dimension k, will make k trees
    [((a[0], b[0]), c[0]), ..., ((a[k-1], b[k-1]), c[k-1])]
    Useful for turning the output of a vmapped function into normal objects.

    A tree with no leaves at all (e.g. ``None`` or ``[None, None]``) has no
    leading axis to split, so it is returned unchanged as ``[tree]``.

    Raises:
        AssertionError: if the leaves disagree on the size of their leading axis.
    """
    # ``jax.tree_flatten`` was removed from the top-level namespace in newer
    # JAX; ``jax.tree_util.tree_flatten`` is the supported spelling.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    if not leaves:
        return [tree]
    leaf_lengths = {leaf.shape[0] for leaf in leaves}
    # Explicit raise (not ``assert``) so the check survives ``python -O``.
    if len(leaf_lengths) != 1:
        raise AssertionError(
            "All non-None pytrees leaves must be of the same size in the leading axis"
        )
    # Iterating an array walks its first axis, so zip(*leaves) yields, for each
    # index i, the tuple of every leaf's i-th slice.
    return [treedef.unflatten(slices) for slices in zip(*leaves)]
# Import-time self-checks for tree_stack / tree_unstack.

# Round trip: unstacking a stacked list must reproduce the original list.
# NOTE: ``inp == output`` relies on bool() of element-wise array comparisons,
# which only works because every array here has exactly one element.
inputs = [[None], [(None, jnp.ones(1)), (None, jnp.zeros(1))]]
for inp in inputs:
    output = tree_unstack(tree_stack(inp))
    assert inp == output

# Leafless trees have nothing to split, so tree_unstack wraps them unchanged.
assert tree_unstack(None) == [None]
assert tree_unstack([None]) == [[None]]
assert tree_unstack([None, None]) == [[None, None]]
assert tree_stack([None, None]) is None

# tree_stack must reject non-list input. The ``else`` branch makes this check
# fail if no AssertionError is raised. The previous
# ``assert tree_stack(None)`` form could never fail, because the assert's own
# AssertionError was caught by the same handler.
try:
    tree_stack(None)
except AssertionError:
    pass
else:
    raise AssertionError("tree_stack(None) should reject non-list input")
# End of gist KeAWang/f420ba439a012d969b04211a42f6c9de.