Created May 22, 2021 13:54
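from jaxlib.xla_extension import DeviceArray as JaxArray  # removed in newer JAX; use jax.Array there
from jax.tree_util import register_pytree_node
from plum import dispatch, parametric
from typing import Tuple
import jax.numpy as jnp

def register_pytree_parametric(cls):
    # Children: instance attributes in __dict__ order; no auxiliary data.
    # Unflattening passes the tuple of children back to the constructor.
    register_pytree_node(
        cls,
        lambda xs: (tuple(xs.__dict__.values()), None),
        lambda _, xs: cls(xs),
    )
    return cls

@register_pytree_parametric
@parametric
class Test:
    a: JaxArray

    @dispatch
    def __init__(self, args: Tuple):  # used by the unflatten function
        self.a, = args

    @dispatch
    def __init__(self, a):  # direct construction
        self.a = a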
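The registration pattern can be exercised end to end with plain JAX. The sketch below assumes a recent JAX release (`jax.jit`, `jax.tree_util.tree_flatten` / `tree_unflatten`) and leaves plum out; `register_pytree_from_dict` and `Pair` are hypothetical names chosen for illustration. As in `register_pytree_parametric`, the children are the instance attributes in `__dict__` insertion order, there is no auxiliary data, and unflattening hands the children tuple back to the constructor.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten


def register_pytree_from_dict(cls):
    """Hypothetical helper mirroring register_pytree_parametric, minus plum."""
    register_pytree_node(
        cls,
        lambda obj: (tuple(obj.__dict__.values()), None),  # flatten: children, no aux data
        lambda _aux, children: cls(children),               # unflatten: rebuild from children
    )
    return cls


@register_pytree_from_dict
class Pair:
    # Hypothetical example class. The constructor accepts the children tuple,
    # which is exactly what the unflatten function above passes back.
    def __init__(self, children):
        self.a, self.b = children


p = Pair((jnp.ones(3), jnp.arange(3.0)))

# Round trip: flatten into leaves plus a tree definition, then rebuild.
leaves, treedef = tree_flatten(p)
rebuilt = tree_unflatten(treedef, leaves)


# Because Pair is a registered pytree, it can be passed into and returned from jit.
@jax.jit
def scale(pair):
    return Pair((2.0 * pair.a, pair.b + 1.0))


out = scale(p)
```

Because the flatten function reads `__dict__`, the leaf order follows the order in which `__init__` assigns attributes, and any attribute assigned after construction would become an extra leaf that the constructor cannot absorb on unflatten.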
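The two `__init__` overloads in `Test` presumably exist because unflattening calls `cls(xs)` with a tuple of children, while user code passes the array directly. Below is a stdlib-only sketch of the same dual-constructor idea, swapping plum's `dispatch` for `functools.singledispatchmethod` (Python 3.8+); `Box` is a hypothetical stand-in, not part of the gist.

```python
from functools import singledispatchmethod


class Box:
    """Hypothetical stand-in for Test, dispatching with functools instead of plum."""

    # Default overload: direct construction, Box(value).
    @singledispatchmethod
    def __init__(self, a):
        self.a = a

    # tuple overload: receives the children tuple, as an unflatten
    # function like `lambda _, xs: cls(xs)` would pass it.
    @__init__.register
    def _(self, args: tuple):
        (self.a,) = args


direct = Box(3.0)
# Mimic flatten (tuple of __dict__ values) followed by unflatten (cls(children)).
rebuilt = Box(tuple(direct.__dict__.values()))
```

One consequence of dispatching on the argument type: a tuple passed as the value itself is routed to the children overload, so this sketch assumes `a` is never a tuple; the `Tuple` overload in `Test` would likely behave the same way.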