Skip to content

Instantly share code, notes, and snippets.

@odashi
Last active February 19, 2021 23:47
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save odashi/813810a5bc06724ea3643456f8d3942d to your computer and use it in GitHub Desktop.
Save odashi/813810a5bc06724ea3643456f8d3942d to your computer and use it in GitHub Desktop.
Augmented dataclass that is automatically registered as a JAX pytree.
import dataclasses as dc
from jax import tree_util as jt
def register_jax_dataclass(cls):
    """Registers a dataclass as a JAX pytree.

    Every field reported by ``dataclasses.fields(cls)`` becomes a pytree
    child, in declaration order; no auxiliary data is stored.

    Args:
      cls: A dataclass *type* (not an instance).

    Returns:
      ``cls`` itself, so this function can also be used as a class decorator.

    Raises:
      TypeError: If ``cls`` is not a dataclass type.
    """
    # ``dc.is_dataclass`` is also true for dataclass *instances*; only a class
    # can be registered as a pytree node type.
    if not (isinstance(cls, type) and dc.is_dataclass(cls)):
        # f-string instead of ``%``: ``'%s' % cls`` misbehaves when cls is a tuple.
        raise TypeError(f'{cls} is not a dataclass.')

    all_fields = dc.fields(cls)
    keys = tuple(field.name for field in all_fields)
    # Fields declared with ``init=False`` cannot be passed to ``__init__``;
    # they are restored after construction instead.
    init_keys = tuple(field.name for field in all_fields if field.init)
    non_init_keys = tuple(field.name for field in all_fields if not field.init)

    def _flatten(obj):
        return [getattr(obj, key) for key in keys], None

    def _unflatten(_, children):
        values = dict(zip(keys, children))
        obj = cls(**{key: values[key] for key in init_keys})
        for key in non_init_keys:
            # object.__setattr__ bypasses the guard on frozen dataclasses.
            object.__setattr__(obj, key, values[key])
        return obj

    jt.register_pytree_node(cls, _flatten, _unflatten)
    return cls
def jax_dataclass(cls=None, **kwargs):
    """Decorator function to define a dataclass with JAX bindings.

    Usable bare (``@jax_dataclass``) or with ``dataclasses.dataclass``
    options (``@jax_dataclass(frozen=True)``).

    Args:
      cls: The class to convert, or None when only options are given.
      **kwargs: Forwarded unchanged to ``dataclasses.dataclass``.

    Returns:
      The dataclass registered as a JAX pytree, or, when ``cls`` is None, a
      decorator that performs that conversion.
    """
    def wrap(target):
        return register_jax_dataclass(dc.dataclass(target, **kwargs))

    if cls is None:
        # Called with options only, e.g. ``@jax_dataclass(frozen=True)``.
        return wrap
    return wrap(cls)
@jax_dataclass
class Data:
    """Small example record registered as a JAX pytree.

    Both fields are pytree children, flattened in declaration order.
    """

    foo: int  # first child
    bar: float  # second child
if __name__ == '__main__':
    # Round-trip a Data instance through flatten/unflatten and check that the
    # reconstruction compares equal to the original.
    a = Data(1, 2.3)
    leaves, treedef = jt.tree_flatten(a)
    b = jt.tree_unflatten(treedef, leaves)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if a != b:
        raise AssertionError(f'pytree round trip changed {a!r} into {b!r}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment