Skip to content

Instantly share code, notes, and snippets.

@enijkamp
Created December 12, 2021 19:04
Show Gist options
  • Save enijkamp/354f037bd1dd00df1a714d86f89ed0a7 to your computer and use it in GitHub Desktop.
Save enijkamp/354f037bd1dd00df1a714d86f89ed0a7 to your computer and use it in GitHub Desktop.
leave_names.py
def tree_flatten_with_names(pytree, is_leaf, path="", to_id=id):
    """Map ``to_id(leaf)`` to a slash-separated path for every leaf of ``pytree``.

    Walking rules:

    * Any node for which ``is_leaf`` returns True is recorded as a leaf; this
      is checked first, including for ``pytree`` itself.
    * ``str``/``bytes`` are recorded as leaves (iterating them would yield
      strings again and recurse forever).
    * Objects with an ``items`` attribute (mappings) are walked, each value
      getting ``f"{path}/{key}"`` as its path.
    * Other objects with ``__getitem__`` (lists, tuples, ...) are walked
      WITHOUT adding an index, so all elements share their parent's path.
    * Anything else is recorded as a leaf at the current path.

    Args:
        pytree: Nested structure of mappings/sequences to walk.
        is_leaf: Predicate telling whether a node should be recorded as a leaf
            instead of being walked.
        path: Path prefix for ``pytree`` itself ("" at the root).
        to_id: Produces the dictionary key for a leaf. It is applied at every
            nesting level.

    Returns:
        Dict mapping ``to_id(leaf)`` to the leaf's path. Leaves with equal
        ``to_id`` values (e.g. the same object referenced twice) collapse to a
        single entry holding the path visited last.
    """
    id_to_name = {}

    def visit(node, node_path):
        # Fill ``id_to_name`` in place rather than merging fresh dicts per
        # child, which copied the accumulated mapping over and over.
        if is_leaf(node) or isinstance(node, (str, bytes)):
            id_to_name[to_id(node)] = node_path
        elif getattr(node, "items", None):
            for key, child in node.items():
                visit(child, f"{node_path}/{key}")
        elif getattr(node, "__getitem__", None):
            for child in node:
                visit(child, node_path)
        else:
            id_to_name[to_id(node)] = node_path

    visit(pytree, path)
    return id_to_name
def tree_leaves_with_names(pytree, to_id=id):
    """Map ``to_id(leaf)`` to a slash-separated path for each leaf of ``pytree``.

    The leaves are exactly those returned by ``jax.tree_util.tree_leaves``.
    Their paths are computed by ``tree_flatten_with_names``.

    Args:
        pytree: Any JAX pytree.
        to_id: Key function for leaves (``id`` by default). Its results must be
            hashable, and it is used both to recognise leaves and as the key
            of the returned dict.

    Returns:
        Dict from ``to_id(leaf)`` to the leaf's path.
    """
    # Build the id set once. Rebuilding a list inside every is_leaf call made
    # the walk quadratic in the number of leaves.
    leaf_ids = {to_id(leaf) for leaf in jax.tree_util.tree_leaves(pytree)}

    def is_leaf(x):
        # Lists are always walked as containers, never matched as leaves.
        return not isinstance(x, list) and to_id(x) in leaf_ids

    # Forward to_id so the returned keys use the same key function.
    return tree_flatten_with_names(pytree, is_leaf, to_id=to_id)
def get_tree_leaves_names_reduced(pytree) -> List[str]:
    """Return one slash-separated name per leaf of ``pytree``, in leaf order.

    The order matches ``jax.tree_util.tree_leaves(pytree)``. Names are
    "reduced": mapping keys become path segments, but list/tuple positions do
    not, so all elements of a sequence share their parent's name.

    Leaves are matched by ``id()``. Leaves that are the same Python object
    (e.g. one array referenced twice) therefore all receive the name of the
    occurrence the name walk visited last.

    Raises:
        KeyError: if a leaf was not reached by the name walk, e.g. because it
            sits inside a custom pytree node that exposes neither ``items``
            nor ``__getitem__``.
    """
    names_by_id = tree_leaves_with_names(pytree, to_id=id)
    # ``jax.tree_leaves`` is a deprecated alias; ``jax.tree_util.tree_leaves``
    # is the stable spelling.
    leaves = jax.tree_util.tree_leaves(pytree)
    try:
        return [names_by_id[id(leaf)] for leaf in leaves]
    except KeyError as err:
        raise KeyError(
            "a pytree leaf was not reached while collecting names; custom "
            "pytree nodes without items()/__getitem__ are not supported"
        ) from err
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment