@twiecki
Last active April 13, 2021 20:27
stochastic_volatility.ipynb
@brandonwillard

I just noticed that this example isn't optimizing the FunctionGraph before it's converted to a JAX function.

@twiecki
Author

twiecki commented Apr 12, 2021 via email

@brandonwillard

brandonwillard commented Apr 12, 2021

Doing something like the following will optimize the FunctionGraph in roughly the same way that aesara.function does:

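```python
import aesara.graph.fg
from aesara.compile.mode import FAST_RUN

# `model` is the PyMC3 model defined in the notebook.
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
# Apply the FAST_RUN mode's graph rewrites (roughly what aesara.function does
# by default); `optimize` rewrites `fgraph` in place.
_ = FAST_RUN.optimizer.optimize(fgraph)
```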

Without that step, the JAX function will take the exact form of the log-likelihood graph determined by the Distribution.logp implementations, i.e. none of the usual graph optimizations (common-subexpression elimination (CSE), elemwise fusion, in-place operations, etc.) are applied.
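To make the "no CSE" point concrete, here is a toy sketch in plain Python (not Aesara's API). The tuple-based IR and the helpers `lower` and `run` are hypothetical, invented for illustration: lowering the expression tree for `exp(x) * exp(x) + exp(x)` naively emits one instruction per tree node, while a simple value-numbering pass (one form of CSE) emits the repeated `exp(x)` only once. Aesara performs this kind of merging on its own graph objects when the graph is optimized.

```python
import math
from typing import Dict, List, Optional, Tuple, Union

# Hypothetical toy IR (not Aesara's): a leaf is a variable name (str) and an
# internal node is a tuple (op_name, child, child, ...).
Expr = Union[str, Tuple]
Instr = Tuple[str, Tuple]  # (temporary name, (op_name, *argument names))

OPS = {
    "exp": math.exp,
    "mul": lambda a, b: a * b,
    "add": lambda a, b: a + b,
}


def lower(expr: Expr, program: List[Instr], table: Optional[Dict] = None) -> str:
    """Append three-address code for `expr` to `program`; return the result name.

    With table=None every node gets its own temporary (no CSE). With a dict,
    structurally identical subexpressions are emitted once and then reused
    (local value numbering, a simple form of CSE).
    """
    if isinstance(expr, str):
        return expr
    op, *children = expr
    key = (op, *(lower(c, program, table) for c in children))
    if table is not None and key in table:
        return table[key]
    name = f"t{len(program)}"
    program.append((name, key))
    if table is not None:
        table[key] = name
    return name


def run(program: List[Instr], env: Dict[str, float], result: str) -> float:
    """Execute the three-address code and return the value named `result`."""
    values = dict(env)
    for name, (op, *args) in program:
        values[name] = OPS[op](*(values[a] for a in args))
    return values[result]


# exp(x) * exp(x) + exp(x): exp(x) appears three times in the tree.
expr = ("add", ("mul", ("exp", "x"), ("exp", "x")), ("exp", "x"))

naive: List[Instr] = []
naive_out = lower(expr, naive)
optimized: List[Instr] = []
opt_out = lower(expr, optimized, table={})

print(len(naive), len(optimized))  # 5 3
```

Both programs compute the same value; the optimized one just does less work, which is the same effect (on a much larger scale) that the graph rewrites have on the log-likelihood graph.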

@twiecki
Author

twiecki commented Apr 13, 2021

I suppose pm.sample() already does this?

@brandonwillard

This looks like something we need to update in PyMC3, as well.

Here's a quick comparison of the timing with and without graph optimizations (the example/model is taken from this notebook):

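```python
# With optimizations: reuse the graph that aesara.function already compiled
# (and optimized) for model.logp
fgraph = model.logp.f.maker.fgraph
...
# equivalent to `%timeit fn2(*init_state).block_until_ready()`
get_ipython().run_line_magic('timeit', 'fn2(*init_state).block_until_ready()')
# 198 µs ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# Without optimizations: a fresh, unoptimized FunctionGraph
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
...
get_ipython().run_line_magic('timeit', 'fn2(*init_state).block_until_ready()')
# 236 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```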
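To reproduce this kind of comparison outside IPython, a standard-library-only sketch could look like the following. `time_per_call` is a hypothetical helper, not part of any library mentioned here; it reports the mean and (population) standard deviation of the per-call time across runs, similar to the summary line `%timeit` prints. For a JAX function the callable should block on the result (e.g. `lambda: fn2(*init_state).block_until_ready()`), because JAX dispatches work asynchronously and would otherwise be timed before the computation finishes.

```python
import statistics
import timeit
from typing import Callable, Tuple


def time_per_call(fn: Callable[[], object], number: int = 10_000,
                  repeat: int = 7) -> Tuple[float, float]:
    """Time `fn` in `repeat` runs of `number` calls each.

    Returns (mean, standard deviation) of the per-call time in seconds across
    the runs, similar to the summary line IPython's %timeit prints.
    """
    totals = timeit.repeat(fn, number=number, repeat=repeat)
    per_call = [total / number for total in totals]
    # pstdev (population std. dev.) also works when repeat == 1.
    return statistics.mean(per_call), statistics.pstdev(per_call)


# Example with a plain-Python workload; for a JAX function, pass something like
# `lambda: fn2(*init_state).block_until_ready()` instead.
mean, std = time_per_call(lambda: sum(range(1000)), number=1000)
print(f"{mean * 1e6:.1f} µs ± {std * 1e6:.1f} µs per loop")
```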
