Skip to content

Instantly share code, notes, and snippets.

@willwhitney
Last active December 6, 2023 01:54
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save willwhitney/dd89cac6a5b771ccff18b06b33372c75 to your computer and use it in GitHub Desktop.
Save willwhitney/dd89cac6a5b771ccff18b06b33372c75 to your computer and use it in GitHub Desktop.
Utilities for stacking and unstacking JAX pytrees for use with vmap
import numpy as np
from jax import numpy as jnp
from jax import tree_util
from jax.lib import pytree  # internal API; jax.tree_util is the public equivalent
def tree_stack(trees):
    """Take a list of trees and stack every corresponding leaf.

    For example, given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')).
    Useful for turning a list of objects into something you can feed to a
    vmapped function.

    Args:
        trees: An iterable of pytrees that all share the same structure.

    Returns:
        A single pytree with that common structure whose every leaf is
        ``jnp.stack`` of the corresponding input leaves, i.e. it gains a new
        leading axis whose length is the number of input trees.

    Raises:
        ValueError: If ``trees`` is empty or the trees' structures differ.
    """
    # Use the public jax.tree_util API rather than the internal jax.lib.pytree.
    flattened = [tree_util.tree_flatten(tree) for tree in trees]
    if not flattened:
        raise ValueError("tree_stack() requires at least one tree")

    first_treedef = flattened[0][1]
    for index, (_, treedef) in enumerate(flattened):
        # Without this check zip() below would silently truncate mismatched
        # leaf lists and unflatten them with the wrong structure.
        if treedef != first_treedef:
            raise ValueError(
                f"tree {index} has structure {treedef}, "
                f"expected {first_treedef}"
            )

    # zip(*...) groups the i-th leaf of every tree together.
    grouped_leaves = zip(*(leaves for leaves, _ in flattened))
    stacked_leaves = [jnp.stack(group) for group in grouped_leaves]
    return tree_util.tree_unflatten(first_treedef, stacked_leaves)
def tree_unstack(tree):
    """Take a tree and turn it into a list of trees. Inverse of tree_stack.

    For example, given a tree ((a, b), c), where a, b, and c all have first
    dimension k, returns k trees
    [((a[0], b[0]), c[0]), ..., ((a[k-1], b[k-1]), c[k-1])].
    Useful for turning the output of a vmapped function into normal objects.

    Args:
        tree: A pytree whose leaves are arrays sharing the same size along
            their first axis.

    Returns:
        A list with one pytree per index along that first axis, each with the
        same structure as ``tree``.

    Raises:
        ValueError: If ``tree`` has no leaves, any leaf is 0-d, or the leaves
            disagree on the size of their first axis.
    """
    # Use the public jax.tree_util API rather than the internal jax.lib.pytree.
    leaves, treedef = tree_util.tree_flatten(tree)
    if not leaves:
        # With no leaves there is no leading axis to say how many trees to make.
        raise ValueError("tree_unstack() requires a tree with at least one leaf")

    shapes = [np.shape(leaf) for leaf in leaves]
    if any(len(shape) == 0 for shape in shapes):
        raise ValueError("tree_unstack() cannot split a 0-d (scalar) leaf")

    n_trees = shapes[0][0]
    # Validate every leaf up front. Indexing only up to n_trees would silently
    # drop rows of longer leaves, and JAX may clamp (rather than reject)
    # out-of-range indices into shorter ones.
    if any(shape[0] != n_trees for shape in shapes):
        sizes = sorted({shape[0] for shape in shapes})
        raise ValueError(
            f"all leaves must share the same leading dimension; got sizes {sizes}"
        )

    return [
        tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
        for i in range(n_trees)
    ]
if __name__ == "__main__":

    def make_tree():
        """Build a ((leaf, leaf), leaf) tree of uniformly random JAX arrays."""
        shapes = ((1, 2), (3, 1), (3,))

        def random_leaf(idx):
            # Sample on the host with NumPy, then convert to a JAX array.
            host_values = np.random.uniform(size=shapes[idx])
            return jnp.array(host_values)

        # The generator is consumed in order, so leaves are drawn as 0, 1, 2.
        first, second, third = (random_leaf(idx) for idx in range(3))
        return ((first, second), third)

    def report(title, compute):
        """Print a title, run ``compute``, print and return its result."""
        print(title)
        value = compute()
        print(value)
        return value

    trees = [make_tree() for _ in range(3)]
    print("Before")
    print(trees)
    stacked = report("\nStacked", lambda: tree_stack(trees))
    unstacked = report("\nUnstacked", lambda: tree_unstack(stacked))
@Habush
Copy link

Habush commented Oct 31, 2022

Thank you! This is very useful.

One small improvement to the API usage: use the public jax.tree_util.tree_flatten instead of the internal API pytree.flatten.

@v0lta
Copy link

v0lta commented Jan 11, 2023

Thank you very much. It seems that jax.tree_util.tree_flatten is now the only way to make this work.

@danielkelshaw
Copy link

Here is an alternative definition; I believe it does the same thing (note that zip(..., strict=True) requires Python 3.10+):

import jax.numpy as jnp
import jax.tree_util as jtu


def tree_stack(trees):
    """Stack structurally identical pytrees leaf by leaf along a new leading axis.

    Args:
        trees: An iterable of pytrees sharing one structure.

    Returns:
        A pytree of that structure whose leaves are ``jnp.stack`` of the
        corresponding input leaves.

    Raises:
        ValueError: If ``trees`` is empty. ``jtu.tree_map`` itself rejects
            trees whose structures differ.
    """
    trees = list(trees)
    if not trees:
        # jtu.tree_map(f, *[]) would otherwise fail with an opaque
        # missing-argument TypeError.
        raise ValueError("tree_stack() requires at least one tree")
    return jtu.tree_map(lambda *leaves: jnp.stack(leaves), *trees)

def tree_unstack(tree):
    leaves, treedef = jtu.tree_flatten(tree)
    return [treedef.unflatten(leaf) for leaf in zip(*leaves, strict=True)]

@ayaka14732
Copy link

@danielkelshaw Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment