Skip to content

Instantly share code, notes, and snippets.

@maciejkorzepa
Created April 21, 2020 12:33
Show Gist options
  • Save maciejkorzepa/4c4bb6c445dc41449f90df14f04e67a9 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '5'\n",
"os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/opt/cuda/cuda-10.1'\n",
"import torch \n",
"import numpy as np\n",
"import jax\n",
"from time import time\n",
"from jax.experimental import stax\n",
"import neural_tangents as nt\n",
"\n",
"num_base_out_chan = 32\n",
"\n",
"init_fn, apply_fn = stax.serial(\n",
" stax.Conv(num_base_out_chan, filter_shape=(3, 3), strides=(2, 2), padding='SAME'), stax.Relu,\n",
" stax.MaxPool(window_shape=(3, 3), strides=(2, 2), padding='SAME'),\n",
" stax.Conv(num_base_out_chan, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.Conv(num_base_out_chan, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.MaxPool(window_shape=(3, 3), strides=(2, 2), padding='SAME'),\n",
" stax.Conv(num_base_out_chan * 2, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.Conv(num_base_out_chan * 2, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.MaxPool(window_shape=(3, 3), strides=(2, 2), padding='SAME'),\n",
" stax.Conv(num_base_out_chan * 4, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.Conv(num_base_out_chan * 4, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.Conv(num_base_out_chan * 4, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.Conv(num_base_out_chan * 4, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.MaxPool(window_shape=(3, 3), strides=(2, 2), padding='SAME'),\n",
" stax.Conv(num_base_out_chan * 8, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.Conv(num_base_out_chan * 8, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.Conv(num_base_out_chan * 8, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.Conv(num_base_out_chan * 8, filter_shape=(3, 3), strides=(1, 1), padding='SAME'), stax.Relu,\n",
" stax.FanOut(2),\n",
" stax.parallel(\n",
" stax.serial(stax.MaxPool(window_shape=(16, 16)), stax.Flatten), \n",
" stax.serial(stax.AvgPool(window_shape=(16, 16)), stax.Flatten)\n",
" ),\n",
" stax.FanInConcat(),\n",
" stax.Dense(1)\n",
")\n",
"\n",
"d = 512\n",
"key = jax.random.PRNGKey(0)\n",
"_, params = init_fn(key, (-1, d, d, 3))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def compute_ntk(f, x1, x2=None):\n",
" t0 = time()\n",
" f(x1, x2, params)\n",
" print('Time:', time() - t0)\n",
" \n",
"def gen_data(n):\n",
" return np.random.randn(n, d, d, 3).astype(np.float32) "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 16.312867403030396\n"
]
}
],
"source": [
"# First run with jit compilation\n",
"ntk_jit = jax.jit(nt.empirical_ntk_fn(apply_fn))\n",
"x1_train = gen_data(10)\n",
"x2_train = gen_data(80)\n",
"compute_ntk(ntk_jit, x1_train, x2_train)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 0.041784048080444336\n",
"Time: 1.7398581504821777\n",
"Time: 1.8000473976135254\n",
"Time: 1.7995421886444092\n",
"Time: 1.8012866973876953\n",
"Time: 1.8084299564361572\n"
]
}
],
"source": [
"# Subsequent runs, same input - gets very slow after first time - why?\n",
"for _ in range(6):\n",
" compute_ntk(ntk_jit, x1_train, x2_train)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 0.070343017578125\n",
"Time: 0.06474590301513672\n",
"Time: 0.05652141571044922\n",
"Time: 0.05322003364562988\n",
"Time: 0.06491804122924805\n",
"Time: 0.053679466247558594\n"
]
}
],
"source": [
"# Subsequent runs, different input every time - fast each time\n",
"for _ in range(6):\n",
" x1_train = gen_data(10)\n",
" x2_train = gen_data(80)\n",
" compute_ntk(ntk_jit, x1_train, x2_train)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Resource exhausted: Out of memory while trying to allocate 22450943496 bytes.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-6-4e5cb404c834>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mx1_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgen_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m25\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mx2_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgen_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m25\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mcompute_ntk\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mntk_jit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx1_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-2-9f750e74df50>\u001b[0m in \u001b[0;36mcompute_ntk\u001b[0;34m(f, x1, x2)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_ntk\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mt0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Time:'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mt0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.conda/envs/env_torch/lib/python3.7/site-packages/jax/api.py\u001b[0m in \u001b[0;36mf_jitted\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,\n\u001b[0;32m--> 153\u001b[0;31m name=flat_fun.__name__)\n\u001b[0m\u001b[1;32m 154\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.conda/envs/env_torch/lib/python3.7/site-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, f, *args, **params)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtop_trace\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 977\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mnew_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 978\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 979\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 980\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.conda/envs/env_torch/lib/python3.7/site-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(fun, device, backend, name, *args)\u001b[0m\n\u001b[1;32m 463\u001b[0m \u001b[0mcompiled_fun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_xla_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg_spec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 464\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 465\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcompiled_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 466\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 467\u001b[0m print(\"Invalid value encountered in the output of a jit function. \"\n",
"\u001b[0;32m~/.conda/envs/env_torch/lib/python3.7/site-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_execute_compiled\u001b[0;34m(compiled, backend, handlers, *args)\u001b[0m\n\u001b[1;32m 569\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompiled\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlocal_devices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 570\u001b[0m \u001b[0minput_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mdevice_put\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0margs\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mtoken\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 571\u001b[0;31m \u001b[0mout_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompiled\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mExecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_bufs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 572\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mFLAGS\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjax_debug_nans\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcheck_nans\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxla_call_p\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_bufs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mhandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_buf\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhandler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_buf\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandlers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_bufs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Resource exhausted: Out of memory while trying to allocate 22450943496 bytes."
]
}
],
"source": [
"# memory error (on 12GB GPU) with |x1|=25 and |x2|=25 (kernel 25x25). It worked for a bigger 10x80 kernel though\n",
"x1_train = gen_data(25)\n",
"x2_train = gen_data(25)\n",
"compute_ntk(ntk_jit, x1_train, x2_train)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 15.460219383239746\n"
]
}
],
"source": [
"# Batched, First run with jit compilation, very slow\n",
"ntk_jit_batch = nt.batch(jax.jit(nt.empirical_ntk_fn(apply_fn)), batch_size=10, device_count=1)\n",
"x1_train = gen_data(10)\n",
"x2_train = gen_data(80) \n",
"compute_ntk(ntk_jit_batch, x1_train, x2_train)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 7.792666435241699\n",
"Time: 7.78826117515564\n",
"Time: 7.770298957824707\n",
"Time: 7.78845739364624\n",
"Time: 7.806784391403198\n",
"Time: 7.857091426849365\n"
]
}
],
"source": [
"# Subsequent runs with ntk function jit compiled, still very slow\n",
"for _ in range(6):\n",
" x1_train = gen_data(10)\n",
" x2_train = gen_data(80)\n",
" compute_ntk(ntk_jit_batch, x1_train, x2_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment