Skip to content

Instantly share code, notes, and snippets.

@pabloferz
Created May 22, 2021 13:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pabloferz/27cdb6150b303bffe6aa5130ecd0931e to your computer and use it in GitHub Desktop.
Save pabloferz/27cdb6150b303bffe6aa5130ecd0931e to your computer and use it in GitHub Desktop.
from jaxlib.xla_extension import DeviceArray as JaxArray
from jax.tree_util import register_pytree_node
from plum import dispatch, parametric
from typing import Tuple
import jax.numpy as jnp
def register_pytree_parametric(cls):
    """Register ``cls`` as a JAX pytree node type and hand it back unchanged.

    Flattening:
    - Each value in the instance's ``__dict__`` becomes a child, in insertion order.
    - No auxiliary data is recorded; ``None`` is used.

    Unflattening calls ``cls`` with the whole children collection as its single
    positional argument. ``cls`` must therefore accept that one-argument form.

    Because the class itself is returned, this works as a class decorator.
    """

    def _flatten(node):
        # Attribute values are the pytree children; nothing goes in aux data.
        children = tuple(node.__dict__.values())
        return children, None

    def _unflatten(_aux, children):
        # Aux data is always None (see _flatten), so it is ignored here.
        return cls(children)

    register_pytree_node(cls, _flatten, _unflatten)
    return cls
@parametric
class Test:
    """Container holding a single array-valued attribute ``a``.

    The class is decorated with plum's ``parametric``. Construction goes through
    the two plum-dispatched ``__init__`` overloads below:
    - one that stores a pre-packed 1-tuple as-is;
    - one that converts an arbitrary value with ``jnp.asarray``.

    NOTE(review): ``register_pytree_parametric`` is defined in this file but is
    not applied to this class here. Unless that happens elsewhere, instances are
    not registered as JAX pytrees — confirm intent.
    """

    # Annotation only: the ``Tuple`` overload stores whatever element it is
    # given, so ``a`` is not enforced to be a JaxArray at runtime.
    a: JaxArray

    @dispatch
    def __init__(self, args: Tuple):
        """Store the single element of ``args`` as ``self.a`` without conversion.

        Raises ValueError (from tuple unpacking) unless ``args`` has exactly one
        element. This form matches how the unflatten function in
        ``register_pytree_parametric`` rebuilds an instance: it passes the
        children collection as one positional argument.
        """
        self.a, = args

    @dispatch
    def __init__(self, a):
        """Convert ``a`` with ``jnp.asarray``, then delegate to the ``Tuple`` overload.

        NOTE(review): a tuple argument is matched by the more specific ``Tuple``
        overload instead of this one. For example, ``Test((1.0,))`` stores the raw
        ``1.0`` rather than an array — confirm that is acceptable.
        """
        self.__init__((jnp.asarray(a),))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment