-
-
Save maciejkorzepa/4c4bb6c445dc41449f90df14f04e67a9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Environment must be configured before jax is imported: pin the process to\n",
"# GPU 5 and point XLA at the local CUDA 10.1 install.\n",
"import os\n", | |
"os.environ['CUDA_VISIBLE_DEVICES'] = '5'\n", | |
"os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/opt/cuda/cuda-10.1'\n", | |
"import torch \n", | |
"import numpy as np\n", | |
"import jax\n", | |
"from time import time\n", | |
"from jax.experimental import stax\n", | |
"import neural_tangents as nt\n", | |
"\n", | |
"# Base channel width; deeper stages use 2x, 4x, 8x multiples of it.\n", | |
"num_base_out_chan = 32\n", | |
"\n", | |
"# VGG-style convolutional network: five conv stages separated by max-pooling,\n", | |
"# then parallel global max- and avg-pool heads concatenated into a Dense(1).\n", | |
"init_fn, apply_fn = stax.serial(\n", | |
"    stax.Conv(num_base_out_chan, filter_shape=(3, 3), strides=(2, 2), padding='SAME'), stax.Relu,\n", | |
"    stax.MaxPool(window_shape=(3, 3), strides=(2, 2), padding='SAME'),\n", | |
"    stax.Conv(num_base_out_chan, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.Conv(num_base_out_chan, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.MaxPool(window_shape=(3, 3), strides=(2, 2), padding='SAME'),\n", | |
"    stax.Conv(num_base_out_chan * 2, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.Conv(num_base_out_chan * 2, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.MaxPool(window_shape=(3, 3), strides=(2, 2), padding='SAME'),\n", | |
"    stax.Conv(num_base_out_chan * 4, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.Conv(num_base_out_chan * 4, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.Conv(num_base_out_chan * 4, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.Conv(num_base_out_chan * 4, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.MaxPool(window_shape=(3, 3), strides=(2, 2), padding='SAME'),\n", | |
"    stax.Conv(num_base_out_chan * 8, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.Conv(num_base_out_chan * 8, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.Conv(num_base_out_chan * 8, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.Conv(num_base_out_chan * 8, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n", | |
"    stax.FanOut(2),\n", | |
"    stax.parallel(\n", | |
"        stax.serial(stax.MaxPool(window_shape=(16, 16)), stax.Flatten), \n", | |
"        stax.serial(stax.AvgPool(window_shape=(16, 16)), stax.Flatten)\n", | |
"    ),\n", | |
"    stax.FanInConcat(),\n", | |
"    stax.Dense(1)\n", | |
")\n", | |
"\n", | |
"# Input images are d x d x 3; initialize params for that shape (batch dim is free).\n", | |
"d = 512\n", | |
"key = jax.random.PRNGKey(0)\n", | |
"_, params = init_fn(key, (-1, d, d, 3))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def compute_ntk(f, x1, x2=None):\n",
"    \"\"\"Time one evaluation of the NTK function `f` on (x1, x2) and return it.\n",
"\n",
"    jax dispatches work to the device asynchronously, so timing the bare call\n",
"    only measures how fast the work was *enqueued*, not the computation itself.\n",
"    That is what made the earlier timings look erratic (fast on fresh inputs,\n",
"    slow on reused ones): host-side work overlapped with in-flight device work.\n",
"    Blocking on the result before stopping the clock gives honest numbers.\n",
"    \"\"\"\n",
"    t0 = time()\n",
"    ntk = f(x1, x2, params)\n",
"    # Wait for the async device computation to actually finish.\n",
"    jax.tree_map(lambda a: a.block_until_ready(), ntk)\n",
"    print('Time:', time() - t0)\n",
"    return ntk\n",
"\n",
"def gen_data(n):\n",
"    \"\"\"Return n random float32 inputs of shape (n, d, d, 3).\"\"\"\n",
"    return np.random.randn(n, d, d, 3).astype(np.float32)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time: 16.312867403030396\n" | |
] | |
} | |
], | |
"source": [ | |
"# First run with jit compilation\n", | |
"# (XLA compilation happens on the first call for these input shapes; later\n",
"# calls with the same shapes reuse the cached executable)\n",
"ntk_jit = jax.jit(nt.empirical_ntk_fn(apply_fn))\n", | |
"x1_train = gen_data(10)\n", | |
"x2_train = gen_data(80)\n", | |
"compute_ntk(ntk_jit, x1_train, x2_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time: 0.041784048080444336\n", | |
"Time: 1.7398581504821777\n", | |
"Time: 1.8000473976135254\n", | |
"Time: 1.7995421886444092\n", | |
"Time: 1.8012866973876953\n", | |
"Time: 1.8084299564361572\n" | |
] | |
} | |
], | |
"source": [ | |
"# Subsequent runs, same input - gets very slow after first time - why?\n", | |
"# NOTE(review): jax dispatch is asynchronous -- without calling\n",
"# .block_until_ready() on the result, each timing here mostly measures how\n",
"# long the host had to wait on the previous still-running device call, not\n",
"# the call being timed. These numbers are not trustworthy as-is (TODO confirm).\n",
"for _ in range(6):\n", | |
"    compute_ntk(ntk_jit, x1_train, x2_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time: 0.070343017578125\n", | |
"Time: 0.06474590301513672\n", | |
"Time: 0.05652141571044922\n", | |
"Time: 0.05322003364562988\n", | |
"Time: 0.06491804122924805\n", | |
"Time: 0.053679466247558594\n" | |
] | |
} | |
], | |
"source": [ | |
"# Subsequent runs, different input every time - fast each time\n", | |
"# NOTE(review): presumably an async-dispatch artifact, not a real speedup --\n",
"# the expensive host-side np.random.randn in gen_data() runs while the\n",
"# previous device computation is still in flight, hiding its cost. Compare\n",
"# with a blocking timer before drawing conclusions (TODO confirm).\n",
"for _ in range(6):\n", | |
"    x1_train = gen_data(10)\n", | |
"    x2_train = gen_data(80)\n", | |
"    compute_ntk(ntk_jit, x1_train, x2_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "RuntimeError", | |
"evalue": "Resource exhausted: Out of memory while trying to allocate 22450943496 bytes.", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-6-4e5cb404c834>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mx1_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgen_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m25\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mx2_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgen_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m25\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mcompute_ntk\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mntk_jit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx1_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-2-9f750e74df50>\u001b[0m in \u001b[0;36mcompute_ntk\u001b[0;34m(f, x1, x2)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_ntk\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mt0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Time:'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mt0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/.conda/envs/env_torch/lib/python3.7/site-packages/jax/api.py\u001b[0m in \u001b[0;36mf_jitted\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,\n\u001b[0;32m--> 153\u001b[0;31m name=flat_fun.__name__)\n\u001b[0m\u001b[1;32m 154\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/.conda/envs/env_torch/lib/python3.7/site-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, f, *args, **params)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtop_trace\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 977\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mnew_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 978\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 979\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 980\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/.conda/envs/env_torch/lib/python3.7/site-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(fun, device, backend, name, *args)\u001b[0m\n\u001b[1;32m 463\u001b[0m \u001b[0mcompiled_fun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_xla_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg_spec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 464\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 465\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcompiled_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 466\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 467\u001b[0m print(\"Invalid value encountered in the output of a jit function. \"\n", | |
"\u001b[0;32m~/.conda/envs/env_torch/lib/python3.7/site-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_execute_compiled\u001b[0;34m(compiled, backend, handlers, *args)\u001b[0m\n\u001b[1;32m 569\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompiled\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlocal_devices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 570\u001b[0m \u001b[0minput_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mdevice_put\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0margs\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mtoken\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 571\u001b[0;31m \u001b[0mout_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompiled\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mExecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_bufs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 572\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mFLAGS\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjax_debug_nans\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcheck_nans\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxla_call_p\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_bufs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mhandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_buf\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhandler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_buf\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandlers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_bufs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mRuntimeError\u001b[0m: Resource exhausted: Out of memory while trying to allocate 22450943496 bytes." | |
] | |
} | |
], | |
"source": [ | |
"# memory error (on 12GB GPU) with |x1|=25 and |x2|=25 (kernel 25x25). It worked for a bigger 10x80 kernel though\n", | |
"# NOTE(review): peak memory appears to depend on the per-call batch shapes the\n",
"# jitted function sees, not just the |x1|*|x2| kernel size -- bounding the\n",
"# batch shape via nt.batch is the intended mitigation (TODO confirm).\n",
"x1_train = gen_data(25)\n", | |
"x2_train = gen_data(25)\n", | |
"compute_ntk(ntk_jit, x1_train, x2_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time: 15.460219383239746\n" | |
] | |
} | |
], | |
"source": [ | |
"# Batched, First run with jit compilation, very slow\n", | |
"# NOTE(review): the 10x80 kernel is computed in batch_size=10 tiles here;\n",
"# wrapping in jax.jit before nt.batch may be redundant if nt.batch already\n",
"# jit-compiles its batched kernel internally -- verify against nt docs.\n",
"ntk_jit_batch = nt.batch(jax.jit(nt.empirical_ntk_fn(apply_fn)), batch_size=10, device_count=1)\n", | |
"x1_train = gen_data(10)\n", | |
"x2_train = gen_data(80) \n", | |
"compute_ntk(ntk_jit_batch, x1_train, x2_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time: 7.792666435241699\n", | |
"Time: 7.78826117515564\n", | |
"Time: 7.770298957824707\n", | |
"Time: 7.78845739364624\n", | |
"Time: 7.806784391403198\n", | |
"Time: 7.857091426849365\n" | |
] | |
} | |
], | |
"source": [ | |
"# Subsequent runs with ntk function jit compiled, still very slow\n", | |
"# NOTE(review): assembling per-batch results forces the host to wait for each\n",
"# batch, so unlike the unbatched timings these ~7.8s readings likely reflect\n",
"# real compute rather than async enqueue time (TODO confirm).\n",
"for _ in range(6):\n", | |
"    x1_train = gen_data(10)\n", | |
"    x2_train = gen_data(80)\n", | |
"    compute_ntk(ntk_jit_batch, x1_train, x2_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment