

@makdoudN
Forked from willwhitney/tree_stack.py
Last active April 1, 2021 07:47
Utility: apply a reduce function leaf-wise across a list of pytrees (or a pytree of pytrees).
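```python
from jax import numpy as jnp
from jax import tree_util


def reduce(trees, op=jnp.stack, op_kwargs=None):
    """Takes a list of trees and applies `op` to every group of corresponding leaves.

    For example, assuming op=`jnp.stack`,
    given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')).

    # Started as a fork of: https://gist.github.com/willwhitney/dd89cac6a5b771ccff18b06b33372c75
    """
    # Avoid a shared mutable default argument.
    op_kwargs = {} if op_kwargs is None else op_kwargs
    leaves_list = []
    treedef_list = []
    for tree in trees:
        # Public replacement for the old `jax.lib.pytree.flatten`.
        leaves, treedef = tree_util.tree_flatten(tree)
        leaves_list.append(leaves)
        treedef_list.append(treedef)
    # Leaves only correspond if every tree has the same structure.
    if any(td != treedef_list[0] for td in treedef_list[1:]):
        raise ValueError("All trees must have the same structure.")
    # zip(*...) groups the i-th leaf of every tree into one tuple.
    grouped_leaves = zip(*leaves_list)
    result_leaves = [op(l, **op_kwargs) for l in grouped_leaves]
    return treedef_list[0].unflatten(result_leaves)
```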
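A minimal sketch of the same leaf-wise reduction without manual flattening, assuming a JAX version where `jax.tree_util.tree_map` accepts several trees. `leafwise_reduce` is a hypothetical name chosen for illustration, not part of the gist or of JAX. As in `reduce`, `op` receives a tuple of corresponding leaves, so it must accept a sequence (as `jnp.stack` and Python's built-in `sum` do).

```python
from jax import numpy as jnp
from jax import tree_util


def leafwise_reduce(trees, op=jnp.stack, op_kwargs=None):
    """Hypothetical helper: apply `op` to each tuple of corresponding leaves."""
    op_kwargs = {} if op_kwargs is None else op_kwargs
    if not trees:
        raise ValueError("`trees` must contain at least one pytree.")
    # tree_map follows the first tree's structure and passes the matching
    # leaf from every tree as separate positional arguments.
    return tree_util.tree_map(lambda *leaves: op(leaves, **op_kwargs), *trees)


# Two pytrees with identical structure (dicts of arrays).
trees = [
    {"w": jnp.ones((3,)), "b": jnp.zeros(())},
    {"w": jnp.full((3,), 3.0), "b": jnp.ones(())},
]

stacked = leafwise_reduce(trees)  # new leading axis of size len(trees)
stacked_last = leafwise_reduce(trees, op_kwargs={"axis": -1})  # stack on last axis
summed = leafwise_reduce(trees, op=sum)  # elementwise sum across trees
```

Delegating to `tree_map` leaves the structure handling to JAX; the explicit flatten, group, and unflatten steps in `reduce` keep the grouping visible, which may be easier to adapt.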