Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Created May 8, 2024 09:56
Show Gist options
  • Save ricardoV94/41e11ad4dded4d881e5edc443339cbb9 to your computer and use it in GitHub Desktop.
Save ricardoV94/41e11ad4dded4d881e5edc443339cbb9 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 167,
"id": "d1219243",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"import numpy as np\n",
"\n",
"from pytensor.graph.fg import FunctionGraph\n",
"from pytensor.link.jax.dispatch import jax_funcify\n",
"from pytensor.compile.mode import get_mode\n",
"\n",
"from pytensor.graph.rewriting.utils import rewrite_graph"
]
},
{
"cell_type": "code",
"execution_count": 168,
"id": "d1270d7e",
"metadata": {},
"outputs": [],
"source": [
"x = pt.vector(\"x\")\n",
"one_mx = 1 - x\n",
"out = pt.log(one_mx)"
]
},
{
"cell_type": "code",
"execution_count": 169,
"id": "175f2f62",
"metadata": {},
"outputs": [],
"source": [
"fg = FunctionGraph(inputs=None, outputs=[out])"
]
},
{
"cell_type": "code",
"execution_count": 170,
"id": "52d024a1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log [id A] 2\n",
" └─ Sub [id B] 1\n",
" ├─ ExpandDims{axis=0} [id C] 0\n",
" │ └─ 1 [id D]\n",
" └─ x [id E]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7efd117736d0>"
]
},
"execution_count": 170,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pytensor.dprint(fg)"
]
},
{
"cell_type": "code",
"execution_count": 171,
"id": "d4133f96",
"metadata": {},
"outputs": [],
"source": [
"opt_fg = rewrite_graph(fg, include=(\"canonicalize\", \"stabilize\", \"specialize\"))"
]
},
{
"cell_type": "code",
"execution_count": 172,
"id": "707fa5d8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log1p [id A] 1\n",
" └─ Neg [id B] 0\n",
" └─ x [id C]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7efd117736d0>"
]
},
"execution_count": 172,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pytensor.dprint(opt_fg)"
]
},
{
"cell_type": "code",
"execution_count": 173,
"id": "862815b7",
"metadata": {},
"outputs": [],
"source": [
"jax_fn = jax_funcify(opt_fg)"
]
},
{
"cell_type": "code",
"execution_count": 174,
"id": "177ce00c",
"metadata": {},
"outputs": [],
"source": [
"jax_jitted_fn = jax.jit(jax_fn)"
]
},
{
"cell_type": "code",
"execution_count": 175,
"id": "3a4a6660",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array([-2.30258509, -2.30258509], dtype=float64)"
]
},
"execution_count": 175,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Bind the numeric result to a new name so `out` (the symbolic PyTensor output\n",
"# used to build `fg`) is not overwritten; re-running earlier cells stays valid.\n",
"result = jax_jitted_fn(np.array([0.9, 0.9]))[0]\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60f6999c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "pytensor",
"language": "python",
"name": "pytensor"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment