Created March 2, 2021 23:37
Gist: dfm/a2db466f46ab931947882b08b2f21558
Thanks for this! One thing I noticed while experimenting with this on aesara 2.0.12: `jax_funcify_JaxOp` seems to need to accept the extra keyword arguments `node` and `storage_map`, so this tweak makes the code above work for me:
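```python
# `jax_funcify` comes from aesara's JAX backend and `JaxOp` is the Op defined
# in the gist above. On aesara 2.0.12 the handler is also passed keyword
# arguments such as `node` and `storage_map`; **kwargs accepts and ignores them.
@jax_funcify.register(JaxOp)
def jax_funcify_JaxOp(op, *args, **kwargs):
    # Hand the wrapped JAX function straight to the JAX backend.
    func = op.jax_fn
    return func
```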
I hope that's sensible.
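Why the `**kwargs` tweak helps can be sketched with the standard library alone. As far as I know, aesara's `jax_funcify` is built on `functools.singledispatch`, so a registered handler receives whatever keyword arguments the caller passes along; the toy below assumes that. `ToyJaxOp`, `toy_funcify`, and `old_style_handler` are hypothetical names for illustration, not aesara's API.

```python
from functools import singledispatch


class ToyJaxOp:
    """Hypothetical stand-in for the gist's JaxOp: it just stores a callable."""

    def __init__(self, jax_fn):
        self.jax_fn = jax_fn


@singledispatch
def toy_funcify(op, **kwargs):
    """Hypothetical dispatcher mimicking jax_funcify: one handler per Op type."""
    raise NotImplementedError(f"No conversion registered for {type(op).__name__}")


@toy_funcify.register(ToyJaxOp)
def toy_funcify_JaxOp(op, *args, **kwargs):
    # Accept and ignore extra context (e.g. node=..., storage_map=...)
    # that the caller passes, then return the wrapped function.
    return op.jax_fn


def old_style_handler(op):
    # A handler written for the older call convention: it only accepts `op`.
    return op.jax_fn


op = ToyJaxOp(lambda x: 2 * x)

# Simulate a caller that passes extra keyword arguments to the dispatcher.
fn = toy_funcify(op, node=None, storage_map={})
print(fn(3))  # 6

try:
    old_style_handler(op, node=None, storage_map={})
except TypeError as err:
    print("old-style handler fails:", type(err).__name__)
```

The handler with a fixed `(op)` signature raises `TypeError` as soon as extra keywords appear, while the `*args, **kwargs` version keeps working regardless of what the caller adds.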
@bmorris3: Yeah, this interface has been a moving target and I haven't been following it too closely, so I'm not sure I know enough to comment, but it seems sensible enough :D
@peterroelants: both of these questions are moot if you're only using the jaxified version of the function. `perform` is only called when the op is evaluated by aesara itself. This means you could use this op as a deterministic with the original PyMC3 or with the JAX backend, and on the JAX backend it reduces directly to just the JAX function.
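To make the `perform` versus JAX-backend distinction concrete, here is a stdlib-only toy. `ToyOp`, `evaluate_default`, and `evaluate_jax` are hypothetical stand-ins, not aesara's linkers; `perform` only follows the shape of aesara's `Op.perform(node, inputs, output_storage)`. The call counter shows that the JAX path never touches `perform`.

```python
class ToyOp:
    """Hypothetical op with two code paths, loosely mirroring the gist's JaxOp."""

    def __init__(self, jax_fn):
        self.jax_fn = jax_fn
        self.perform_calls = 0

    def perform(self, node, inputs, output_storage):
        # Used only by the default (non-JAX) evaluation path.
        self.perform_calls += 1
        output_storage[0][0] = self.jax_fn(*inputs)


def evaluate_default(op, *inputs):
    """Toy stand-in for aesara's default evaluation: it calls op.perform."""
    output_storage = [[None]]  # one single-element cell per output
    op.perform(None, list(inputs), output_storage)
    return output_storage[0][0]


def evaluate_jax(op, *inputs):
    """Toy stand-in for the JAX backend: the op reduces to its jax_fn."""
    return op.jax_fn(*inputs)


op = ToyOp(lambda x: x ** 2)
print(evaluate_default(op, 3), op.perform_calls)  # 9 1
print(evaluate_jax(op, 4), op.perform_calls)      # 16 1
```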
But to answer them directly: