@mwitiderrick
Created July 11, 2022 12:30
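import jax
import jax.numpy as jnp

# A few example pytrees: nested lists, tuples, dicts, and an array.
example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Let's see how many leaves they have:
for pytree in example_trees:
    # jax.tree_util.tree_leaves is the stable home of the older jax.tree_leaves alias.
    leaves = jax.tree_util.tree_leaves(pytree)
    print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")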
# Output as recorded (older JAX release). Newer releases print Array(...) instead of
# DeviceArray(...), and the dtype is int32 unless 64-bit mode (jax_enable_x64) is on.
# [1, 'a', <object object at 0x7f280a01f6d0>] has 3 leaves: [1, 'a', <object object at 0x7f280a01f6d0>]
# (1, (2, 3), ()) has 3 leaves: [1, 2, 3]
# [1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]
# {'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]
# DeviceArray([1, 2, 3], dtype=int64) has 1 leaves: [DeviceArray([1, 2, 3], dtype=int64)]
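To show where those leaf counts come from, here is a minimal pure-Python sketch of the flattening rule for built-in containers. It is a hypothetical helper (`naive_tree_leaves` is not part of JAX). It assumes only plain lists, tuples, dicts and None act as containers: lists and tuples are walked in order, dicts in sorted-key order, None contributes no leaves, and everything else (numbers, strings, arbitrary objects, arrays) counts as a single leaf. The real jax.tree_util also handles OrderedDicts, namedtuples and custom-registered node types, which this sketch does not model faithfully.

```python
from typing import Any, List


def naive_tree_leaves(tree: Any) -> List[Any]:
    """Hypothetical, simplified stand-in for jax.tree_util.tree_leaves.

    Only plain list, tuple, dict and None are treated as containers.
    """
    if tree is None:
        # None is an empty subtree: it contributes no leaves.
        return []
    if isinstance(tree, (list, tuple)):
        # Sequences: visit children left to right.
        leaves: List[Any] = []
        for child in tree:
            leaves.extend(naive_tree_leaves(child))
        return leaves
    if isinstance(tree, dict):
        # Dicts: visit values in sorted-key order, not insertion order.
        leaves = []
        for key in sorted(tree):
            leaves.extend(naive_tree_leaves(tree[key]))
        return leaves
    # Anything else is a leaf.
    return [tree]


examples = [
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    {'b': 1, 'a': 2},
    {'b': 1, 'a': None},
]
for t in examples:
    print(repr(t), "->", naive_tree_leaves(t))
# (1, (2, 3), ()) -> [1, 2, 3]
# [1, {'k1': 2, 'k2': (3, 4)}, 5] -> [1, 2, 3, 4, 5]
# {'a': 2, 'b': (2, 3)} -> [2, 2, 3]
# {'b': 1, 'a': 2} -> [2, 1]
# {'b': 1, 'a': None} -> [1]
```

The first three results match the leaf lists printed by JAX above; the last two illustrate the sorted-key order for dicts and the empty-subtree treatment of None.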