stochastic_volatility.ipynb (gist twiecki/a77104299535b64b58953de3c84df56f, last active April 13, 2021 20:27)
How do you mean?
On Mon, Apr 12, 2021, 21:01 Brandon T. Willard commented on this gist:

> I just noticed that this example isn't optimizing the `FunctionGraph`.
Doing something like the following will optimize the `FunctionGraph` in roughly the same way that `aesara.function` does:

```python
import aesara
from aesara.compile.mode import FAST_RUN

fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
_ = FAST_RUN.optimizer.optimize(fgraph)
```

Without that step, the JAX function will take the exact form of the log-likelihood graph determined by the `Distribution.logp` implementations (i.e. no CSE, fusions, in-place operations, etc.).
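To make the "no CSE" point concrete, here's a minimal, library-free sketch of what common-subexpression elimination buys you. This is not Aesara's actual rewriter; the tuple-based expression representation and the evaluator functions are invented purely for illustration:

```python
# Hypothetical mini-graph: nodes are (op, *children) tuples; leaves are ints.
# Illustrative sketch of CSE only, not Aesara's optimizer.

def eval_naive(node, counter):
    """Evaluate without CSE: shared subtrees are recomputed every time."""
    if isinstance(node, int):
        return node
    counter[0] += 1  # count one op application
    op, *args = node
    vals = [eval_naive(a, counter) for a in args]
    return sum(vals) if op == "add" else vals[0] * vals[1]

def eval_cse(node, counter, cache=None):
    """Evaluate with CSE: identical subtrees are computed once and reused."""
    if cache is None:
        cache = {}
    if isinstance(node, int):
        return node
    if node in cache:  # common subexpression already computed
        return cache[node]
    counter[0] += 1
    op, *args = node
    vals = [eval_cse(a, counter, cache) for a in args]
    result = sum(vals) if op == "add" else vals[0] * vals[1]
    cache[node] = result
    return result

# (x*y) + (x*y): the same product node appears twice in the graph.
shared = ("mul", 3, 4)
graph = ("add", shared, shared)

naive_ops, cse_ops = [0], [0]
print(eval_naive(graph, naive_ops), naive_ops[0])  # 24 3
print(eval_cse(graph, cse_ops), cse_ops[0])        # 24 2
```

Both evaluators produce the same value, but CSE performs one fewer multiplication; on a real log-likelihood graph with many repeated subexpressions, skipping this rewrite pass leaves all of that redundant work in the compiled JAX function.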
I suppose `pm.sample()` already does this?
This looks like something we need to update in PyMC3, as well.
Here's a quick comparison of the timing with and without graph optimizations (the example/model is taken from this notebook):

```python
fgraph = model.logp.f.maker.fgraph
...
%timeit fn2(*init_state).block_until_ready()
# 198 µs ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

```python
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
...
%timeit fn2(*init_state).block_until_ready()
# 236 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
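For reference, the `%timeit` measurements above can be approximated outside IPython with the standard-library `timeit` module. The `work` function below is a stand-in placeholder (the notebook's `fn2` and `init_state` are not reproduced here):

```python
import timeit

# Stand-in for the compiled log-likelihood function being benchmarked;
# substitute the real callable (e.g. the notebook's fn2) in practice.
def work():
    return sum(i * i for i in range(1000))

# 7 runs of 1000 loops each, roughly mirroring %timeit's
# "mean ± std. dev. of 7 runs" measurement protocol.
runs = timeit.repeat(work, repeat=7, number=1000)
per_loop = min(runs) / 1000  # best run, per-call time in seconds
print(f"best of 7 runs: {per_loop * 1e6:.1f} µs per loop")
```

Using the best (minimum) run is the `timeit` documentation's recommendation for a stable lower bound, whereas `%timeit` reports a mean and standard deviation across runs.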