Last active
July 7, 2019 05:45
-
-
Save fehiepsi/43e78b321d9f96ffd036244f1b727799 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from jax import lax, jit\n", | |
"# Selects the JAX backend for this session. Use 'cpu' instead of 'gpu' (in a fresh kernel) to reproduce the CPU section of this notebook.\n", | |
"from jax.config import config; config.update('jax_platform_name', 'gpu')\n", | |
"import jax.numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Problem size: N rows, D_X input features, D_H hidden units.\n", | |
"N, D_X, D_H = 50, 3, 5\n", | |
"X = np.ones((N, D_X))  # inputs, shape (N, D_X)\n", | |
"Y = np.zeros(N)  # targets, shape (N,)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def potential_fn(w1):\n", | |
"    # Squared-error loss of a three-layer linear map X -> (N, D_H) -> (N, D_H) -> (N, 1).\n", | |
"    # Only w1, expected to be (D_X, D_H), is an argument; w2 and w3 are fixed all-ones matrices.\n", | |
"    z1 = np.matmul(X, w1)\n", | |
"    w2 = np.ones((D_H, D_H))\n", | |
"    z2 = np.matmul(z1, w2)\n", | |
"    w3 = np.ones((D_H, 1))\n", | |
"    # matmul with w3 gives shape (N, 1); drop the trailing axis so it matches Y of shape (N,).\n", | |
"    # Without this, (N, 1) - (N,) broadcasts to (N, N) and the loss sums N * N terms.\n", | |
"    z3 = np.matmul(z2, w3)[:, 0]\n", | |
"    return np.sum((z3 - Y) ** 2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@jit\n", | |
"def loop(w):\n", | |
"    # Benchmark kernel: evaluate potential_fn(w) 1000 times inside one jitted\n", | |
"    # lax.fori_loop. w is carried through unchanged and the value computed in\n", | |
"    # the final iteration is returned.\n", | |
"    def body_fn(_step, carry):\n", | |
"        params, _previous = carry\n", | |
"        return params, potential_fn(params)\n", | |
"\n", | |
"    init_carry = (w, 0.)\n", | |
"    _, last_value = lax.fori_loop(0, 1000, body_fn, init_carry)\n", | |
"    return last_value" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
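"### CPU\n", | |
"\n", | |
"Timings with the JAX CPU backend. To reproduce them, change `'gpu'` to `'cpu'` in the `config.update('jax_platform_name', ...)` call of the first cell and run the notebook in a fresh kernel. The first timed call includes JIT compilation of `loop`; the second reuses the compiled function." | |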
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 118 ms, sys: 76.3 ms, total: 194 ms\n", | |
"Wall time: 212 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"DeviceArray(14062500., dtype=float32)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# First call: the measured time includes tracing and JIT compilation of loop.\n", | |
"loop(np.ones((D_X, D_H)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 550 µs, sys: 0 ns, total: 550 µs\n", | |
"Wall time: 450 µs\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"DeviceArray(14062500., dtype=float32)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# Second call with the same input shape: reuses the compiled loop.\n", | |
"# NOTE(review): JAX may dispatch work asynchronously; on newer JAX versions\n", | |
"# consider loop(...).block_until_ready() so the timing covers the computation.\n", | |
"loop(np.ones((D_X, D_H)))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### GPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 429 ms, sys: 199 ms, total: 628 ms\n", | |
"Wall time: 743 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"DeviceArray(14062500., dtype=float32)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# First call: the measured time includes tracing and JIT compilation of loop.\n", | |
"loop(np.ones((D_X, D_H)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 35.7 ms, sys: 0 ns, total: 35.7 ms\n", | |
"Wall time: 34.4 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"DeviceArray(14062500., dtype=float32)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# Second call with the same input shape: reuses the compiled loop.\n", | |
"# NOTE(review): JAX may dispatch work asynchronously; on newer JAX versions\n", | |
"# consider loop(...).block_until_ready() so the timing covers the computation.\n", | |
"loop(np.ones((D_X, D_H)))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment