Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Created July 12, 2021 08:53
Show Gist options
  • Save ricardoV94/6089a8c46a0e19665f01c79ea04e1cb2 to your computer and use it in GitHub Desktop.
Save ricardoV94/6089a8c46a0e19665f01c79ea04e1cb2 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are running the v4 development version of PyMC3 which currently still lacks key features. You probably want to use the stable v3 instead which you can either install via conda or find on the v3 GitHub branch: https://github.com/pymc-devs/pymc3/tree/v3\n"
]
}
],
"source": [
"import time\n",
"\n",
"import aesara\n",
"import aesara.tensor as at\n",
"import numpy as np\n",
"\n",
"import pymc3 as pm\n",
"from pymc3.aesaraf import make_shared_replacements, join_nonshared_inputs\n",
"from pymc3.smc.smc import logp_forw"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Toy model: a 2x3 latent location `x`, one positive scale `y` per column,\n",
"# and Normal observations that are all equal to one.\n",
"with pm.Model() as m:\n",
"    x = pm.Normal('x', size=(2, 3))\n",
"    y = pm.HalfNormal('y', size=3)\n",
"    obs = pm.Normal('obs', mu=x, sigma=y, observed=np.ones((2, 3)))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"smc = pm.smc.SMC(model=m)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Baseline: one compiled logp function that is called once per sample below.\n",
"initial_values = smc.model.initial_point\n",
"shared = make_shared_replacements(initial_values, smc.variables, smc.model)\n",
"\n",
"f = logp_forw(initial_values, [smc.model.varlogpt], smc.variables, shared)\n",
"\n",
"# Skip per-call input validation; callers must pass well-formed arrays\n",
"# (see the Aesara docs for the exact trust_input contract).\n",
"f.trust_input = True"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'x': array([[ 0.46346823, 0.08374901, -0.56523375],\n",
" [ 1.8669097 , -0.88971022, -0.29969549]]),\n",
" 'y_log__': array([-1.03109386, -0.83287868, -0.29418776])}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"initial_values"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Fixed seed so the random test points are identical on every run.\n",
"SEED = 2021\n",
"rng = np.random.default_rng(SEED)\n",
"\n",
"nsamples = 500\n",
"# Length of the flat parameter vector: 6 entries for x (2x3) plus 3 for y_log__.\n",
"n_params = 9\n",
"samples = [rng.random(n_params) for _ in range(nsamples)]\n",
"# Overwrite the first point with all zeros (a fixed, hand-checkable input).\n",
"samples[0][:] = 0"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[array(-7.69100526),\n",
" array(-8.59940111),\n",
" array(-13.42063878),\n",
" array(-10.12152872),\n",
" array(-10.00448132),\n",
" array(-9.84893411),\n",
" array(-9.49123258),\n",
" array(-12.18226421),\n",
" array(-11.28890955),\n",
" array(-10.05562276)]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[f(s) for s in samples[:10]]"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Dumb graph replication"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Compilation time = 65.263s\n",
"Number of graph nodes 5000\n"
]
}
],
"source": [
"# Naive approach: express the logp once per sample input and compile all\n",
"# nsamples outputs into a single Aesara function.\n",
"[logp0], inarray0 = join_nonshared_inputs(\n",
"    initial_values,\n",
"    [smc.model.varlogpt],\n",
"    smc.variables,\n",
"    shared,\n",
")\n",
"\n",
"tensor_type = inarray0.type\n",
"inarrays = [tensor_type() for _ in range(nsamples)]\n",
"# Build the wrapper once; calling it with a new input yields logp0 expressed in\n",
"# terms of that input (presumably by cloning the graph - confirm in pymc3).\n",
"logp_of = pm.CallableTensor(logp0)\n",
"logps = [logp_of(inarray) for inarray in inarrays]\n",
"\n",
"# perf_counter is monotonic and high resolution; time.time() is wall-clock\n",
"# time and can jump, so it is a poor fit for measuring intervals.\n",
"start = time.perf_counter()\n",
"f_replication = aesara.function(inarrays, logps)\n",
"end = time.perf_counter()\n",
"f_replication.trust_input = True\n",
"print(f'Compilation time = {end - start:.3f}s')\n",
"print(f'Number of graph nodes {len(f_replication.maker.fgraph.toposort())}')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[array(-7.69100526),\n",
" array(-8.59940111),\n",
" array(-13.42063878),\n",
" array(-10.12152872),\n",
" array(-10.00448132),\n",
" array(-9.84893411),\n",
" array(-9.49123258),\n",
" array(-12.18226421),\n",
" array(-11.28890955),\n",
" array(-10.05562276)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f_replication(*samples)[:10]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12.8 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit f_replication(*samples)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.56 ms ± 89.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit [f(s) for s in samples]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# The per-sample function f is the reference (desired); the batched function\n",
"# is the value under test (actual). atol=1.5e-7 keeps the tolerance that\n",
"# assert_almost_equal applied with its default decimal=7.\n",
"np.testing.assert_allclose(\n",
"    f_replication(*samples),\n",
"    [f(s) for s in samples],\n",
"    rtol=0,\n",
"    atol=1.5e-7,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Create a logp_op\n",
"\n",
"### Dumb graph replication, one Op call per sample"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Compilation time = 4.366s\n",
"Number of graph nodes 501\n"
]
}
],
"source": [
"[logp0], inarray0 = join_nonshared_inputs(\n",
"    initial_values,\n",
"    [smc.model.varlogpt],\n",
"    smc.variables,\n",
"    shared,\n",
")\n",
"\n",
"# Wrap the whole logp graph in a single Op; it is then applied once per\n",
"# sample input and the results are stacked into one vector output.\n",
"logp_op = aesara.compile.builders.OpFromGraph([inarray0], [logp0])\n",
"\n",
"tensor_type = inarray0.type\n",
"inarrays = [tensor_type() for _ in range(nsamples)]\n",
"logps = pm.math.stack([logp_op(inarray) for inarray in inarrays])\n",
"\n",
"# perf_counter is monotonic and high resolution; time.time() is wall-clock\n",
"# time and can jump, so it is a poor fit for measuring intervals.\n",
"start = time.perf_counter()\n",
"f_replication_op = aesara.function(inarrays, logps)\n",
"end = time.perf_counter()\n",
"f_replication_op.trust_input = True\n",
"print(f'Compilation time = {end - start:.3f}s')\n",
"print(f'Number of graph nodes {len(f_replication_op.maker.fgraph.toposort())}')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([ -7.69100526, -8.59940111, -13.42063878, -10.12152872,\n",
" -10.00448132, -9.84893411, -9.49123258, -12.18226421,\n",
" -11.28890955, -10.05562276])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f_replication_op(*samples)[:10]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.92 ms ± 182 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit f_replication_op(*samples)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.55 ms ± 81.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit [f(s) for s in samples]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# The per-sample function f is the reference (desired); the Op-based function\n",
"# is the value under test (actual). atol=1.5e-7 keeps the tolerance that\n",
"# assert_almost_equal applied with its default decimal=7.\n",
"np.testing.assert_allclose(\n",
"    f_replication_op(*samples),\n",
"    [f(s) for s in samples],\n",
"    rtol=0,\n",
"    atol=1.5e-7,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Use a Scan"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Compilation time = 1.299s\n",
"Number of graph nodes 6\n"
]
}
],
"source": [
"[logp0], inarray0 = join_nonshared_inputs(\n",
"    initial_values,\n",
"    [smc.model.varlogpt],\n",
"    smc.variables,\n",
"    shared,\n",
")\n",
"\n",
"logp_op = aesara.compile.builders.OpFromGraph([inarray0], [logp0])\n",
"\n",
"# A single symbolic matrix with one row per sample. This replaces stacking\n",
"# nsamples throwaway placeholders and then using the stack output as the\n",
"# function input; scan iterates over the first axis either way.\n",
"inarrays = at.matrix('inarrays', dtype=inarray0.dtype)\n",
"\n",
"result, _ = aesara.scan(\n",
"    fn=lambda inarray: logp_op(inarray),\n",
"    outputs_info=None,\n",
"    sequences=[inarrays],\n",
"    strict=True,\n",
")\n",
"\n",
"# perf_counter is monotonic and high resolution; time.time() is wall-clock\n",
"# time and can jump, so it is a poor fit for measuring intervals.\n",
"start = time.perf_counter()\n",
"f_scan_op = aesara.function([inarrays], result)\n",
"end = time.perf_counter()\n",
"# trust_input is deliberately left disabled. The original note read\n",
"# 'Kernel dies without this', which contradicts the line being commented out;\n",
"# the crash occurs when it is enabled - presumably because `samples` is passed\n",
"# as a Python list and trust_input skips the conversion to an ndarray.\n",
"# TODO confirm before enabling:\n",
"# f_scan_op.trust_input = True\n",
"print(f'Compilation time = {end - start:.3f}s')\n",
"print(f'Number of graph nodes {len(f_scan_op.maker.fgraph.toposort())}')"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([ -7.69100526, -8.59940111, -13.42063878, -10.12152872,\n",
" -10.00448132, -9.84893411, -9.49123258, -12.18226421,\n",
" -11.28890955, -10.05562276])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f_scan_op(samples)[:10]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5.37 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit f_scan_op(samples)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.53 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit [f(s) for s in samples]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# The per-sample function f is the reference (desired); the scan-based\n",
"# function is the value under test (actual). atol=1.5e-7 keeps the tolerance\n",
"# that assert_almost_equal applied with its default decimal=7.\n",
"np.testing.assert_allclose(\n",
"    f_scan_op(samples),\n",
"    [f(s) for s in samples],\n",
"    rtol=0,\n",
"    atol=1.5e-7,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment