# Created May 22, 2021 13:54
from jaxlib.xla_extension import DeviceArray as JaxArray
from jax.tree_util import register_pytree_node
from plum import dispatch, parametric
from typing import Tuple
import jax.numpy as jnp
def register_pytree_parametric(cls):
    """Register ``cls`` as a JAX pytree node and return it (class decorator).

    Flattening uses the values of the instance's ``__dict__`` as the pytree
    children, in insertion order, and carries no auxiliary data. Unflattening
    rebuilds an instance by calling ``cls`` with the children as one tuple
    argument. ``cls.__init__`` must therefore accept a single tuple of
    attribute values, in the same order the attributes were assigned.

    Args:
        cls: The class to register.

    Returns:
        ``cls`` itself, so the function can be used as a decorator.
    """

    def _flatten(obj):
        # Children are the instance's attribute values; no static aux data.
        return tuple(obj.__dict__.values()), None

    def _unflatten(_aux, children):
        # JAX may hand back any iterable of children; normalise to a tuple
        # because the constructor contract takes a single tuple argument.
        return cls(tuple(children))

    # The scraped original built these callbacks as bare lambdas but never
    # passed them to JAX, so the class was never actually registered.
    register_pytree_node(cls, _flatten, _unflatten)
    return cls
class Test:
a: JaxArray
def __init__(self, args: Tuple):
self.a, = args
def __init__(self, a):
