jax_dp.ipynb
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/jstac/e0f5bde81cd59564442ee2406c863f09/jax_dp.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "40d2195a", | |
"metadata": { | |
"id": "40d2195a" | |
}, | |
"source": [ | |
"# Dynamic Programming on the GPU: A Tutorial for Economists\n", | |
"\n", | |
"Authors: [Thomas J Sargent](http://www.tomsargent.com/) and [John Stachurski](https://johnstachurski.net/)\n", | |
"\n", | |
"Parallelization on GPUs is one of the major trends of modern scientific computing.\n", | |
"\n", | |
"In this notebook we examine a simple implementation of dynamic programming on\n", | |
"the GPU using Python and the [Google JAX\n", | |
"library](https://jax.readthedocs.io/en/latest/index.html).\n", | |
"\n", | |
"This notebook is intended for readers who are familiar with the basics of dynamic programming and want to learn about the JAX library and working on the GPU.\n", | |
"\n", | |
"The notebook is part of the [QuantEcon](https://quantecon.org) project.\n", | |
"\n", | |
"From our timing on Google Colab with a Tesla P100 GPU, the JAX-based Bellman operator is \n", | |
"\n", | |
"* over 3000 times faster than a Numba-based JIT-compiled single-threaded version, and \n", | |
"* almost 1500 times faster than a NumPy-based multithreaded version.\n", | |
"\n", | |
"(The Numba-based version uses a similar compiler to Julia and, for this kind of problem, runs at a similar speed to compiled C and Fortran code.)\n", | |
"\n", | |
"## How to Use this Notebook\n", | |
"\n", | |
"You can run this notebook in Google Colab or some other environment that includes a GPU.\n", | |
"\n", | |
"Before you start, please note:\n", | |
"\n", | |
"* You might like to keep a copy of this notebook in its present form to see the timings that we obtain before you run it.\n", | |
"* We have a subscription to Google Colab Pro, which offers better GPUs. You might like to consider subscribing as well, if you don't have access to a GPU and want to obtain timings similar to ours.\n", | |
"\n", | |
"## A Savings Problem\n", | |
"\n", | |
"To focus on implementation rather than technical details, we will adopt a very\n", | |
"familiar optimal savings problem as our running example. \n", | |
"\n", | |
"* For more details on this problem see, for example, [this QuantEcon\n", | |
"lecture](https://python.quantecon.org/ifp.html).\n", | |
"* For a more elementary discussion of optimal savings see [this discussion of\n", | |
"cake eating](https://python.quantecon.org/cake_eating_problem.html).\n", | |
"\n", | |
"The problem is to maximize the expected discounted sum\n", | |
"\n", | |
"$$\n", | |
" \\mathbb{E} \\sum_{t \\geq 0} \\beta^t u(c_t)\n", | |
"$$\n", | |
"\n", | |
"subject to \n", | |
"\n", | |
"$$\n", | |
" c_t + a_{t+1} \\leq R a_t + y_t, \n", | |
" \\quad c_t \\geq 0,\n", | |
" \\quad a_t \\geq 0\n", | |
"$$\n", | |
"\n", | |
"for all $t \\geq 0$, with $a_0$ and $y_0$ given.\n", | |
"\n", | |
"Here \n", | |
"\n", | |
"* $c_t$ is consumption,\n", | |
"* $a_t$ is assets,\n", | |
"* $R$ is the gross risk-free rate of return, and\n", | |
"* $y_t$ is income.\n", | |
"\n", | |
"The income process follows a Markov chain with transition matrix $P$.\n", | |
"\n", | |
"The Bellman equation is\n", | |
"\n", | |
"$$\n", | |
" v(a, y) = \\max_{0 \\leq a' \\leq Ra + y} \\left\\{ u(Ra + y - a') + \\beta \\sum_{y'} v(a', y') P(y, y') \\right\\},\n", | |
"$$\n", | |
"\n", | |
"where $v$ is the value function.\n", | |
"\n", | |
"The corresponding Bellman operator is \n", | |
"\n", | |
"$$\n", | |
" Tv(a, y) = \\max_{0 \\leq a' \\leq Ra + y} \\left\\{ u(Ra + y - a') + \\beta \\sum_{y'} v(a', y') P(y, y') \\right\\}\n", | |
"$$\n", | |
"\n", | |
"We solve the dynamic program by value function iteration --- that is, by\n", | |
"iterating with $T$.\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1c4fdca8", | |
"metadata": { | |
"id": "1c4fdca8" | |
}, | |
"source": [ | |
"The next cell suppresses some unnecessary NumPy warnings." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "d9ea81d1", | |
"metadata": { | |
"id": "d9ea81d1" | |
}, | |
"outputs": [], | |
"source": [ | |
"import warnings\n", | |
"warnings.filterwarnings('ignore')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a07c177e", | |
"metadata": { | |
"id": "a07c177e" | |
}, | |
"source": [ | |
"First we import some libraries." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install -U quantecon # Install quantecon in case it's missing" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "qpO_8-Ha3F7z", | |
"outputId": "00245cf4-161d-4148-835f-973f99d72e63" | |
}, | |
"id": "qpO_8-Ha3F7z", | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Collecting quantecon\n", | |
" Downloading quantecon-0.5.2-py3-none-any.whl (269 kB)\n", | |
"\u001b[?25l\r\u001b[K |█▏ | 10 kB 29.3 MB/s eta 0:00:01\r\u001b[K |██▍ | 20 kB 18.4 MB/s eta 0:00:01\r\u001b[K |███▋ | 30 kB 14.2 MB/s eta 0:00:01\r\u001b[K |████▉ | 40 kB 13.1 MB/s eta 0:00:01\r\u001b[K |██████ | 51 kB 6.2 MB/s eta 0:00:01\r\u001b[K |███████▎ | 61 kB 7.3 MB/s eta 0:00:01\r\u001b[K |████████▌ | 71 kB 6.7 MB/s eta 0:00:01\r\u001b[K |█████████▊ | 81 kB 7.5 MB/s eta 0:00:01\r\u001b[K |███████████ | 92 kB 8.3 MB/s eta 0:00:01\r\u001b[K |████████████▏ | 102 kB 6.8 MB/s eta 0:00:01\r\u001b[K |█████████████▍ | 112 kB 6.8 MB/s eta 0:00:01\r\u001b[K |██████████████▋ | 122 kB 6.8 MB/s eta 0:00:01\r\u001b[K |███████████████▉ | 133 kB 6.8 MB/s eta 0:00:01\r\u001b[K |█████████████████ | 143 kB 6.8 MB/s eta 0:00:01\r\u001b[K |██████████████████▎ | 153 kB 6.8 MB/s eta 0:00:01\r\u001b[K |███████████████████▌ | 163 kB 6.8 MB/s eta 0:00:01\r\u001b[K |████████████████████▊ | 174 kB 6.8 MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 184 kB 6.8 MB/s eta 0:00:01\r\u001b[K |███████████████████████▏ | 194 kB 6.8 MB/s eta 0:00:01\r\u001b[K |████████████████████████▍ | 204 kB 6.8 MB/s eta 0:00:01\r\u001b[K |█████████████████████████▋ | 215 kB 6.8 MB/s eta 0:00:01\r\u001b[K |██████████████████████████▉ | 225 kB 6.8 MB/s eta 0:00:01\r\u001b[K |████████████████████████████ | 235 kB 6.8 MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▏ | 245 kB 6.8 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▍ | 256 kB 6.8 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▋| 266 kB 6.8 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 269 kB 6.8 MB/s \n", | |
"\u001b[?25hRequirement already satisfied: scipy>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from quantecon) (1.4.1)\n", | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from quantecon) (1.21.5)\n", | |
"Requirement already satisfied: numba>=0.38 in /usr/local/lib/python3.7/dist-packages (from quantecon) (0.51.2)\n", | |
"Requirement already satisfied: sympy in /usr/local/lib/python3.7/dist-packages (from quantecon) (1.7.1)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from quantecon) (2.23.0)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba>=0.38->quantecon) (57.4.0)\n", | |
"Requirement already satisfied: llvmlite<0.35,>=0.34.0.dev0 in /usr/local/lib/python3.7/dist-packages (from numba>=0.38->quantecon) (0.34.0)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->quantecon) (2.10)\n", | |
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->quantecon) (3.0.4)\n", | |
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->quantecon) (1.24.3)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->quantecon) (2021.10.8)\n", | |
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.7/dist-packages (from sympy->quantecon) (1.2.1)\n", | |
"Installing collected packages: quantecon\n", | |
"Successfully installed quantecon-0.5.2\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "da4873ac", | |
"metadata": { | |
"id": "da4873ac" | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"from numba import njit\n", | |
"import quantecon as qe \n", | |
"import matplotlib.pyplot as plt" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9975d552", | |
"metadata": { | |
"id": "9975d552" | |
}, | |
"source": [ | |
"Next we specify some primitives, including the utility function. We will be lazy and include them as global variables." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "29643599", | |
"metadata": { | |
"id": "29643599" | |
}, | |
"outputs": [], | |
"source": [ | |
"R = 1.1\n", | |
"β = 0.99\n", | |
"γ = 2.5\n", | |
"\n", | |
"def u(c):\n", | |
" return c**(1-γ) / (1-γ)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "884bac7b", | |
"metadata": { | |
"id": "884bac7b" | |
}, | |
"source": [ | |
"Now we define the asset grid. We will use a large one so as to generate a computationally demanding problem. (Economists usually use smaller grids, but large grids easily arise once we start introducing more features to the model. Since this complicates our code, we'll stick to a simple model with a large grid.)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "4e10b3ab", | |
"metadata": { | |
"id": "4e10b3ab" | |
}, | |
"outputs": [], | |
"source": [ | |
"a_min, a_max = 0.01, 2\n", | |
"a_size = ap_size = 1000\n", | |
"a_grid = np.linspace(a_min, a_max, a_size) # grid for a\n", | |
"ap_grid = np.copy(a_grid) # grid for a'" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "021394a8", | |
"metadata": { | |
"id": "021394a8" | |
}, | |
"source": [ | |
"Next we build the Markov chain for income. We will use QuantEcon's `tauchen()` function to construct a Markov chain via discretization of an AR(1) process. The details do not matter much. All we are doing is setting up a grid of possible values for $y_t$ and the matrix of transition probabilities." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "f22b1210", | |
"metadata": { | |
"id": "f22b1210" | |
}, | |
"outputs": [], | |
"source": [ | |
"ρ = 0.9\n", | |
"σ = 0.1\n", | |
"y_size = 100\n", | |
"mc = qe.tauchen(ρ, σ, n=y_size)\n", | |
"y_grid = np.exp(mc.state_values)\n", | |
"P = mc.P" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "33f21f47", | |
"metadata": { | |
"id": "33f21f47" | |
}, | |
"source": [ | |
"## A First Pass: Using Loops and Numba\n", | |
"\n", | |
"As our first implementation of the Bellman operator, we are going to use loops over the state and choice variables. The use of `njit` in the code below indicates that we are using Numba to just-in-time (JIT) compile the utility function and the Bellman operator. This makes the loops inside the Bellman operator run at the same speed as compiled C or Fortran code.\n", | |
"\n", | |
"We are applying Numba's JIT functionality so that we have a serious --- but not parallelized --- benchmark, running on the CPU.\n", | |
"\n", | |
"Below we will compare this benchmark to implementations on the GPU." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "2351774b", | |
"metadata": { | |
"id": "2351774b" | |
}, | |
"outputs": [], | |
"source": [ | |
"u_jit = njit(u) # Compile the utility function\n", | |
"\n", | |
"@njit\n", | |
"def T(v):\n", | |
"    \"The Bellman operator.\"\n", | |
"    # Allocate memory\n", | |
"    v_new = np.empty_like(v)\n", | |
"    # Step through all states\n", | |
"    for i, a in enumerate(a_grid):\n", | |
"        for j, y in enumerate(y_grid):\n", | |
"            # Choose a' optimally by stepping through all possible values\n", | |
"            v_max = - np.inf\n", | |
"            for k, ap in enumerate(ap_grid):\n", | |
"                c = R * a + y - ap\n", | |
"                if c > 0:\n", | |
"                    # Calculate the right hand side of the Bellman operator\n", | |
"                    val = u_jit(c) + β * np.dot(v[k, :], P[j, :])\n", | |
"                    if val > v_max:\n", | |
"                        v_max = val\n", | |
"            v_new[i, j] = v_max\n", | |
"    return v_new" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1c958745", | |
"metadata": { | |
"id": "1c958745" | |
}, | |
"source": [ | |
"Here's a vector to test our operator on." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "dda13b27", | |
"metadata": { | |
"id": "dda13b27" | |
}, | |
"outputs": [], | |
"source": [ | |
"vz = np.zeros((a_size, y_size))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8afe89b4", | |
"metadata": { | |
"id": "8afe89b4" | |
}, | |
"source": [ | |
"Now let's apply the operator and see how long it takes." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "8dfa62a8", | |
"metadata": { | |
"scrolled": true, | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "8dfa62a8", | |
"outputId": "21d8e579-19ac-41cb-c815-3394d8a45c84" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"CPU times: user 12.7 s, sys: 74.7 ms, total: 12.8 s\n", | |
"Wall time: 14.9 s\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[-1.86623555, -1.82779165, -1.79013867, ..., -0.24736292,\n", | |
" -0.24225994, -0.2372622 ],\n", | |
" [-1.85411787, -1.81608627, -1.77883158, ..., -0.2469437 ,\n", | |
" -0.24185503, -0.23687111],\n", | |
" [-1.84213077, -1.8045053 , -1.76764303, ..., -0.24652566,\n", | |
" -0.24145124, -0.23648109],\n", | |
" ...,\n", | |
" [-0.15126798, -0.15067609, -0.15007985, ..., -0.07968264,\n", | |
" -0.07890307, -0.07812548],\n", | |
" [-0.15108321, -0.15049252, -0.14989749, ..., -0.07961914,\n", | |
" -0.0788406 , -0.07806403],\n", | |
" [-0.15089881, -0.15030933, -0.1497155 , ..., -0.07955571,\n", | |
" -0.07877821, -0.07800266]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 9 | |
} | |
], | |
"source": [ | |
"%time T(vz)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Switching to the GPU via JAX" | |
], | |
"metadata": { | |
"id": "8TH7cD0l4izg" | |
}, | |
"id": "8TH7cD0l4izg" | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Next we look to switch to a GPU-based implementation. First, let's check that Google Colab has assigned us a nice GPU." | |
], | |
"metadata": { | |
"id": "Be5AJYm942XE" | |
}, | |
"id": "Be5AJYm942XE" | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!nvidia-smi" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "lHDzQN__3rMC", | |
"outputId": "b7fb05f5-907a-4342-82fd-a3384d237258" | |
}, | |
"id": "lHDzQN__3rMC", | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Mon Mar 14 05:31:49 2022 \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", | |
"|-------------------------------+----------------------+----------------------+\n", | |
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", | |
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", | |
"| | | MIG M. |\n", | |
"|===============================+======================+======================|\n", | |
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n", | |
"| N/A 31C P0 26W / 250W | 0MiB / 16280MiB | 0% Default |\n", | |
"| | | N/A |\n", | |
"+-------------------------------+----------------------+----------------------+\n", | |
" \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| Processes: |\n", | |
"| GPU GI CI PID Type Process name GPU Memory |\n", | |
"| ID ID Usage |\n", | |
"|=============================================================================|\n", | |
"| No running processes found |\n", | |
"+-----------------------------------------------------------------------------+\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"When we ran this on Colab, we obtained a Tesla P100, which can be seen in the output above (assuming you are reading this without running it, or that you have been assigned the same GPU).\n", | |
"\n", | |
"Next let's try to set up a Bellman operator that runs on this GPU." | |
], | |
"metadata": { | |
"id": "ynVq7qVg5KzU" | |
}, | |
"id": "ynVq7qVg5KzU" | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "844da4c6", | |
"metadata": { | |
"id": "844da4c6" | |
}, | |
"source": [ | |
"### Step One: Vectorization via NumPy\n", | |
"\n", | |
"JAX prefers vectorized operations, meaning that loops need to be replaced by operations on arrays. We use some NumPy [broadcasting](https://jakevdp.github.io/PythonDataScienceHandbook/02.05-computation-on-arrays-broadcasting.html) tricks to eliminate these loops.\n", | |
"\n", | |
"The basic idea is to add dimensions to arrays so that they will be stretched along the new dimensions when placed in arithmetic operations with other arrays that have more elements along those dimensions. This stretching is done by repeating values, which is what we use to replace loops.\n", | |
"\n", | |
"The next code cell reshapes all arrays to be three-dimensional." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "43fbb3d6", | |
"metadata": { | |
"id": "43fbb3d6" | |
}, | |
"outputs": [], | |
"source": [ | |
"P = np.reshape(P, (y_size, y_size, 1))\n", | |
"a = np.reshape(a_grid, (a_size, 1, 1))\n", | |
"y = np.reshape(y_grid, (1, y_size, 1))\n", | |
"ap = np.reshape(ap_grid, (1, 1, ap_size))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9a709eb2", | |
"metadata": { | |
"id": "9a709eb2" | |
}, | |
"source": [ | |
"Now we can implement a vectorized version of the Bellman operator, which calculates the same values." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "59455d06", | |
"metadata": { | |
"id": "59455d06" | |
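}, | |
"outputs": [], | |
"source": [ | |
"def T_vec(v):\n", | |
"    # Expected continuation value, indexed by (y, a') and broadcast over a,\n", | |
"    # so that it lines up with the a' axis of c as in the loop version\n", | |
"    vp = np.dot(P[:, :, 0], v.T)[None, :, :]\n", | |
"    c = R * a + y - ap\n", | |
"    m = np.where(c > 0, u(c) + β * vp, -np.inf)\n", | |
"    return np.max(m, axis=2)" | |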
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1d0fe4bf", | |
"metadata": { | |
"id": "1d0fe4bf" | |
}, | |
"source": [ | |
"At this point, everything is in NumPy, and runs **on the CPU** rather than the GPU.\n", | |
"\n", | |
"Let's check the output and see how fast it runs." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "4f5b0e11", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "4f5b0e11", | |
"outputId": "28eae9e0-180f-476f-8f80-7f9336a06a76" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"CPU times: user 6 s, sys: 447 ms, total: 6.44 s\n", | |
"Wall time: 6.42 s\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[-1.86623555, -1.82779165, -1.79013867, ..., -0.24736292,\n", | |
" -0.24225994, -0.2372622 ],\n", | |
" [-1.85411787, -1.81608627, -1.77883158, ..., -0.2469437 ,\n", | |
" -0.24185503, -0.23687111],\n", | |
" [-1.84213077, -1.8045053 , -1.76764303, ..., -0.24652566,\n", | |
" -0.24145124, -0.23648109],\n", | |
" ...,\n", | |
" [-0.15126798, -0.15067609, -0.15007985, ..., -0.07968264,\n", | |
" -0.07890307, -0.07812548],\n", | |
" [-0.15108321, -0.15049252, -0.14989749, ..., -0.07961914,\n", | |
" -0.0788406 , -0.07806403],\n", | |
" [-0.15089881, -0.15030933, -0.1497155 , ..., -0.07955571,\n", | |
" -0.07877821, -0.07800266]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 13 | |
} | |
], | |
"source": [ | |
"%time T_vec(vz)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c6e1efe0", | |
"metadata": { | |
"id": "c6e1efe0" | |
}, | |
"source": [ | |
"The output is the same as above, and wall-clock time falls from about 14.9 seconds to about 6.4 seconds, a speedup of roughly a factor of two, at least on this machine at the time of writing. \n", | |
"\n", | |
"Where does the speed gain come from, given that we had already compiled our loops in the previous version of $T$?\n", | |
"\n", | |
"The answer is that NumPy uses some degree of multithreading on the CPU for basic array operations. So we are running operations at a similar speed but making better use of multi-core CPU platforms via parallelization." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Step Two: Switching to JAX" | |
], | |
"metadata": { | |
"id": "g8SE0WVS5r38" | |
}, | |
"id": "g8SE0WVS5r38" | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "86087543", | |
"metadata": { | |
"id": "86087543" | |
}, | |
"source": [ | |
"Now we are ready for our JAX implementation, which runs on the GPU when available.\n", | |
"\n", | |
"Fortunately, the JAX operations are essentially identical to the NumPy ones, after shifting our arrays to the GPU (the \"device\") and replacing `numpy` calls with `jax.numpy` calls (aliased as `jnp`)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "1c13f56d", | |
"metadata": { | |
"id": "1c13f56d" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Shift all NumPy arrays onto the GPU\n", | |
"P = jax.device_put(P)\n", | |
"a = jax.device_put(a)\n", | |
"y = jax.device_put(y)\n", | |
"ap = jax.device_put(ap)\n", | |
"vz = jax.device_put(vz)\n", | |
"\n", | |
"# Define the Bellman operator as in the NumPy version, but replacing np with jnp\n", | |
"def T_jax(v):\n", | |
"    # Expected continuation value, indexed by (y, a') and broadcast over a\n", | |
"    vp = jnp.dot(P[:, :, 0], v.T)[None, :, :]\n", | |
"    c = R * a + y - ap\n", | |
"    m = jnp.where(c > 0, u(c) + β * vp, -jnp.inf)\n", | |
"    return jnp.max(m, axis=2)\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Let's look at the timing. (We use `block_until_ready()` only to force evaluation at the time of function call, so we can do proper benchmarking.)" | |
], | |
"metadata": { | |
"id": "JzNNIIIu5xjV" | |
}, | |
"id": "JzNNIIIu5xjV" | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "6cdba533", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "6cdba533", | |
"outputId": "033d9455-16b9-48ec-d668-fe2a255b81be" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"CPU times: user 476 ms, sys: 269 ms, total: 745 ms\n", | |
"Wall time: 2.8 s\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([[-1.8662355 , -1.8277917 , -1.7901386 , ..., -0.24736291,\n", | |
" -0.24225995, -0.23726219],\n", | |
" [-1.8541178 , -1.8160865 , -1.7788315 , ..., -0.24694368,\n", | |
" -0.24185503, -0.23687112],\n", | |
" [-1.8421307 , -1.8045052 , -1.7676427 , ..., -0.24652565,\n", | |
" -0.24145123, -0.2364811 ],\n", | |
" ...,\n", | |
" [-0.15126799, -0.15067609, -0.15007983, ..., -0.07968266,\n", | |
" -0.07890309, -0.07812549],\n", | |
" [-0.15108322, -0.15049253, -0.14989749, ..., -0.07961915,\n", | |
" -0.07884061, -0.07806404],\n", | |
" [-0.15089881, -0.15030932, -0.14971548, ..., -0.07955572,\n", | |
" -0.07877821, -0.07800266]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 15 | |
} | |
], | |
"source": [ | |
"%time T_jax(vz).block_until_ready()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4a9c44ab", | |
"metadata": { | |
"id": "4a9c44ab" | |
}, | |
"source": [ | |
"We already have some speed gain from shifting to the GPU. But we can do even better, using JAX's just-in-time compiler. First we target `T_jax` for compilation." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "336a5d23", | |
"metadata": { | |
"id": "336a5d23" | |
}, | |
"outputs": [], | |
"source": [ | |
"T_jax_jit = jax.jit(T_jax)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bdbacc5c", | |
"metadata": { | |
"id": "bdbacc5c" | |
}, | |
"source": [ | |
"When we first run the function there is only moderate speed gain because the function needs to be compiled before it is run:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "7b1c7538", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "7b1c7538", | |
"outputId": "3159126f-4a6c-4893-9324-ef5890ba7614" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"CPU times: user 159 ms, sys: 1.89 ms, total: 161 ms\n", | |
"Wall time: 280 ms\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([[-1.8662355 , -1.8277917 , -1.7901386 , ..., -0.24736291,\n", | |
" -0.24225995, -0.23726219],\n", | |
" [-1.8541178 , -1.8160865 , -1.7788315 , ..., -0.24694368,\n", | |
" -0.24185503, -0.23687112],\n", | |
" [-1.8421307 , -1.8045052 , -1.7676427 , ..., -0.24652565,\n", | |
" -0.24145123, -0.2364811 ],\n", | |
" ...,\n", | |
" [-0.15126799, -0.15067609, -0.15007983, ..., -0.07968266,\n", | |
" -0.07890309, -0.07812549],\n", | |
" [-0.15108322, -0.15049253, -0.14989749, ..., -0.07961915,\n", | |
" -0.07884061, -0.07806404],\n", | |
" [-0.15089881, -0.15030932, -0.14971548, ..., -0.07955572,\n", | |
" -0.07877821, -0.07800266]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 17 | |
} | |
], | |
"source": [ | |
"%time T_jax_jit(vz).block_until_ready()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "aae0ffc2", | |
"metadata": { | |
"id": "aae0ffc2" | |
}, | |
"source": [ | |
"But the next time we run it we get a large speed gain:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "361234e8", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "361234e8", | |
"outputId": "e37f75d8-e34d-43a7-d633-0167f6fb2e13" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"CPU times: user 785 µs, sys: 929 µs, total: 1.71 ms\n", | |
"Wall time: 4.54 ms\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([[-1.8662355 , -1.8277917 , -1.7901386 , ..., -0.24736291,\n", | |
" -0.24225995, -0.23726219],\n", | |
" [-1.8541178 , -1.8160865 , -1.7788315 , ..., -0.24694368,\n", | |
" -0.24185503, -0.23687112],\n", | |
" [-1.8421307 , -1.8045052 , -1.7676427 , ..., -0.24652565,\n", | |
" -0.24145123, -0.2364811 ],\n", | |
" ...,\n", | |
" [-0.15126799, -0.15067609, -0.15007983, ..., -0.07968266,\n", | |
" -0.07890309, -0.07812549],\n", | |
" [-0.15108322, -0.15049253, -0.14989749, ..., -0.07961915,\n", | |
" -0.07884061, -0.07806404],\n", | |
" [-0.15089881, -0.15030932, -0.14971548, ..., -0.07955572,\n", | |
" -0.07877821, -0.07800266]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
], | |
"source": [ | |
"%time T_jax_jit(vz).block_until_ready()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c2533099", | |
"metadata": { | |
"id": "c2533099" | |
}, | |
"source": [ | |
"That's seriously fast.\n", | |
"\n", | |
"This new speed gain is possible because JAX's JIT compiler \"fuses\" the array operations inside `T_jax`, which essentially means that it views them as a whole and optimizes accordingly. This allows generation of highly efficient code for the GPU." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "cf524486", | |
"metadata": { | |
"id": "cf524486" | |
}, | |
"source": [ | |
"To finish the exercise off, let's iterate until convergence and then plot the value function." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "1c7808e3", | |
"metadata": { | |
"id": "1c7808e3" | |
}, | |
"outputs": [], | |
"source": [ | |
"def vfi_iterator(v_init=vz, tol=1e-6, max_iter=50_000):\n", | |
"    error = tol + 1\n", | |
"    i = 0\n", | |
"    v = v_init\n", | |
"    while error > tol and i < max_iter:\n", | |
"        new_v = T_jax_jit(v)\n", | |
"        error = jnp.max(jnp.abs(new_v - v))\n", | |
"        v = new_v\n", | |
"\n", | |
"        if i % 100 == 0:\n", | |
"            print(f\"Iteration {i}\")\n", | |
"        i += 1\n", | |
"\n", | |
"    if i == max_iter:\n", | |
"        print(f\"Warning: iteration hit upper bound {max_iter}.\")\n", | |
"    else:\n", | |
"        print(f\"\\nConverged at iteration {i}.\")\n", | |
"    return v" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "5ee0d51d", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "5ee0d51d", | |
"outputId": "13f30192-f0fe-4dda-ea17-a3f8cecb3d46" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Iteration 0\n", | |
"Iteration 100\n", | |
"Iteration 200\n", | |
"Iteration 300\n", | |
"Iteration 400\n", | |
"Iteration 500\n", | |
"Iteration 600\n", | |
"Iteration 700\n", | |
"Iteration 800\n", | |
"Iteration 900\n", | |
"Iteration 1000\n", | |
"Iteration 1100\n", | |
"Iteration 1200\n", | |
"Iteration 1300\n", | |
"Iteration 1400\n", | |
"\n", | |
"Converged at iteration 1462.\n" | |
] | |
} | |
], | |
"source": [ | |
"v = vfi_iterator()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Here's a plot of the value function." | |
], | |
"metadata": { | |
"id": "lN9phwnJ-ZPo" | |
}, | |
"id": "lN9phwnJ-ZPo" | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "4bbfd9f4", | |
"metadata": { | |
"id": "4bbfd9f4" | |
}, | |
"outputs": [], | |
"source": [ | |
"from mpl_toolkits.mplot3d.axes3d import Axes3D\n", | |
"from matplotlib import cm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "8ebe4c5b", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 357 | |
}, | |
"id": "8ebe4c5b", | |
"outputId": "f0e02ea5-2dd7-408b-8ca7-2cc7886c1483" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 720x432 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
], | |
"source": [ | |
"\n", | |
"a, y = np.meshgrid(a_grid, y_grid)\n", | |
"\n", | |
"fig = plt.figure(figsize=(10, 6))\n", | |
"ax = fig.add_subplot(111, projection='3d')\n", | |
"ax.plot_surface(a,\n", | |
" y,\n", | |
" v.T,\n", | |
" rstride=2, cstride=2,\n", | |
" cmap=cm.jet,\n", | |
" alpha=0.7,\n", | |
" linewidth=0.25)\n", | |
"\n", | |
"ax.view_init(15, 120)\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "2ff4a41c", | |
"metadata": { | |
"id": "2ff4a41c" | |
}, | |
"outputs": [], | |
"source": [ | |
"" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
}, | |
"colab": { | |
"name": "jax_dp.ipynb", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"accelerator": "GPU" | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
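
The broadcasting idea described under "Step One: Vectorization via NumPy" can be checked on a problem small enough to run anywhere. The following is a minimal NumPy sketch, not the authors' code: the six-point asset grid, the three income states, the transition matrix, and the names `T_loop`, `T_broadcast`, and `EV` are illustrative assumptions. It compares a loop-based Bellman operator with a broadcast version on a nonzero test array, where the axes of the three-dimensional arrays are ordered as $(a, y, a')$.

```python
import numpy as np

# Toy primitives (assumed values, far smaller than the notebook's grids)
R, β, γ = 1.1, 0.99, 2.5
a_grid = np.linspace(0.01, 2.0, 6)       # grid for a and a'
y_grid = np.array([0.5, 1.0, 1.5])       # hypothetical income states
P = np.array([[0.80, 0.15, 0.05],
              [0.10, 0.80, 0.10],
              [0.05, 0.15, 0.80]])       # hypothetical transition matrix

def u(c):
    return c**(1 - γ) / (1 - γ)

def T_loop(v):
    "Bellman operator via explicit loops over (a, y, a')."
    v_new = np.empty_like(v)
    for i, a in enumerate(a_grid):
        for j, y in enumerate(y_grid):
            v_max = -np.inf
            for k, ap in enumerate(a_grid):
                c = R * a + y - ap
                if c > 0:
                    # v[k, :] is the value function at next-period assets a'_k
                    v_max = max(v_max, u(c) + β * v[k, :] @ P[j, :])
            v_new[i, j] = v_max
    return v_new

def T_broadcast(v):
    "The same operator via broadcasting, with axes ordered (a, y, a')."
    a = a_grid[:, None, None]
    y = y_grid[None, :, None]
    ap = a_grid[None, None, :]
    c = R * a + y - ap                    # shape (6, 3, 6)
    # EV[j, k] = sum over y' of P[j, y'] * v[k, y']; add a leading axis for a
    EV = (P @ v.T)[None, :, :]            # shape (1, 3, 6)
    # u(c) is undefined for c <= 0; those entries are masked out by np.where
    with np.errstate(invalid="ignore", divide="ignore"):
        m = np.where(c > 0, u(c) + β * EV, -np.inf)
    return m.max(axis=2)

rng = np.random.default_rng(0)
v_test = rng.normal(size=(a_grid.size, y_grid.size))   # nonzero test array
print(np.allclose(T_loop(v_test), T_broadcast(v_test)))
```

Testing on a nonzero array matters here: with an all-zero input the continuation term vanishes, so a mistake in how the expected value is aligned with the $a'$ axis would go unnoticed.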