Skip to content

Instantly share code, notes, and snippets.

@rodrigobaron
Created January 21, 2020 02:57
Show Gist options
  • Save rodrigobaron/e0874af5e8e32b18411fa4bb30e49174 to your computer and use it in GitHub Desktop.
Save rodrigobaron/e0874af5e8e32b18411fa4bb30e49174 to your computer and use it in GitHub Desktop.
trax input layer issue
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "trax_custom_resnet.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "mJUEnOnrL6OU",
"colab_type": "code",
"colab": {}
},
"source": [
"! pip install -q -U trax\n",
"! pip install -q tensorflow"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QptjpUXtI3RK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 199
},
"outputId": "e6c3cd3e-6944-4c7e-d978-4d825748c9f7"
},
"source": [
"import os\n",
"import os.path as P\n",
"import urllib.request\n",
"import tarfile\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"import trax\n",
"import numpy.random as npr\n",
"import math\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
],
"execution_count": 2,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<p style=\"color: red;\">\n",
"The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.<br>\n",
"We recommend you <a href=\"https://www.tensorflow.org/guide/migrate\" target=\"_blank\">upgrade</a> now \n",
"or ensure your notebook will continue to use TensorFlow 1.x via the <code>%tensorflow_version 1.x</code> magic:\n",
"<a href=\"https://colab.research.google.com/notebooks/tensorflow_version.ipynb\" target=\"_blank\">more info</a>.</p>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:\n",
"The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
"For more information, please see:\n",
" * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
" * https://github.com/tensorflow/addons\n",
" * https://github.com/tensorflow/io (for I/O related ops)\n",
"If you depend on functionality not listed there, please file an issue.\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "sGwc7YM6OpBL",
"colab_type": "code",
"colab": {}
},
"source": [
"HEIGHT = 32\n",
"WIDTH = 32\n",
"NUM_CHANNELS = 3\n",
"NUM_CLASSES = 10\n",
"DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS\n",
"RECORD_BYTES = DEFAULT_IMAGE_BYTES + 1\n",
"\n",
"\n",
"def _one_hot(x, k, dtype=np.float32):\n",
" return np.array(x[:, None] == np.arange(k), dtype)\n",
"\n",
"def cifar10_dataset(data_directory, batch_files, one_hot): \n",
"\n",
" def download(directory):\n",
" url = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'\n",
" filename = 'cifar-10-binary.tar.gz'\n",
" if not P.exists(directory):\n",
" os.makedirs(directory)\n",
" filepath = P.join(directory, filename)\n",
" if P.exists(filepath):\n",
" return\n",
" \n",
" print('Downloading cifar10 data...')\n",
" filepath, _ = urllib.request.urlretrieve(url, filepath)\n",
" statinfo = os.stat(filepath)\n",
" print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')\n",
"\n",
" print('Extracting files...')\n",
" tarfile.open(filepath, 'r:gz').extractall(directory)\n",
" print('Done!')\n",
"\n",
" def parse_file(file):\n",
" raw_data = np.fromfile(file, dtype='uint8')\n",
" raw_data = raw_data.reshape(-1, RECORD_BYTES)\n",
" label = raw_data[:, 0]\n",
" \n",
" image = np.reshape(raw_data[:, 1:RECORD_BYTES], [-1, NUM_CHANNELS, HEIGHT, WIDTH])\n",
" image = np.moveaxis(image, 1, -1)\n",
" \n",
" return image, label\n",
" \n",
" assert type(batch_files) is list\n",
" download(data_directory)\n",
" files = [P.join(data_directory, 'cifar-10-batches-bin', f) for f in batch_files]\n",
" np_arrays = [parse_file(f) for f in files]\n",
" \n",
" images = []\n",
" labels = []\n",
" for na_x, na_y in np_arrays:\n",
" images.append(na_x)\n",
" labels.append(na_y)\n",
"\n",
" images = np.vstack(images)\n",
" labels = np.hstack(labels)\n",
"\n",
" if one_hot:\n",
" labels = _one_hot(labels, 10)\n",
" \n",
" \n",
" return images, labels\n",
"\n",
"def train(data_directory, one_hot=False):\n",
" batch = list(range(1, 6))\n",
" batch_files = ['data_batch_%d.bin' % i for i in batch]\n",
"\n",
" return cifar10_dataset(\n",
" data_directory, \n",
" batch_files,\n",
" one_hot\n",
" )\n",
"\n",
"def test(data_directory, one_hot=False):\n",
" batch_files = ['test_batch.bin']\n",
" return cifar10_dataset(\n",
" data_directory, \n",
" batch_files,\n",
" one_hot\n",
" )"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wDH3Jl3HOqgL",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"outputId": "56205975-d0a1-4912-cdab-ede0089d8d4b"
},
"source": [
"train_x, train_y = train('.data/', one_hot=True)\n",
"train_images, train_labels = train('.data/', one_hot=True)\n",
"\n",
"image, label = train_images[0], train_labels[0]\n",
"LABELS = [\n",
" 'airplane',\n",
" 'automobile',\n",
" 'bird',\n",
" 'cat' ,\n",
" 'deer',\n",
" 'dog',\n",
" 'frog',\n",
" 'horse',\n",
" 'ship',\n",
" 'truck'\n",
"]\n",
"\n",
"plt.imshow(image)\n",
"plt.title(LABELS[np.argmax(label)])\n",
"plt.show()"
],
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO2de4yc53XenzO3ndn7LsldkkuKFKmL\nJcMSJTOCKruJZTeBrD8iGShcu4BrFEYUFFFroyla1S1qt2gAu6htGG3hgK6VyK0vUXyJhURorKhO\nVV+gmLJk6hqJokiTy+Vyyd3lzuzMzvX0jxnZK/V93l2SuzO0v+cHEDv8zrzfd+ad73zfzPvMOcfc\nHUKIX31SvXZACNEdFOxCJAQFuxAJQcEuREJQsAuREBTsQiQEBXsCMbPrzewZMyua2T/rtT+iO2R6\n7YDoCf8SwPfc/UCvHRHdQ3f2ZLIHwPMhg5mlu+yL6BIK9oRhZv8bwJ0A/quZlczsq2b2BTN71MyW\nAdxpZiNm9mUzmzOzE2b2b80s1RmfNrPPmNk5M3vNzO43MzczfUq8wlGwJwx3fzeA/wvgfncfBFAD\n8A8B/AGAIQDfB/BfAIwA2AfgNwD8IwD/uLOL3wHwXgAHANwK4N5u+i8uHQW7AIDvuPsP3L0FoA7g\nAwD+tbsX3f04gM8A+FDnue8H8Hl3P+XuCwA+1ROPxUWjYBcAcHLV460AsgBOrNp2AsBU5/HONz1/\n9WNxBaNgFwCwOvXxHNp39z2rtl0FYLrzeAbArlW23ZvrmtgoFOziDbh7E8DDAP7AzIbMbA+Afw7g\nf3ae8jCAj5rZlJmNAvhXPXJVXCQKdhHinwJYBnAM7QW7rwJ4sGP7IoDvAjgC4GkAjwJoAGh2301x\nMZiKV4jLwczeC+AP3X3Pmk8WPUV3dnFRmFnBzO42s4yZTQH4BIBv99ovsTa6s4uLwsz6AfwfAG8B\nUAHwFwA+6u5LPXVMrImCXYiEoI/xQiSErv6eOZvNel8+H7Q1m3wxN4Xwp4+08WPlMvw6lo3YMmme\nB2IWPmDnZ+NkEDc1Gvw1xz5vpWM+kk9qLW/xY7X40SwVeQERWq3wa4v5Ht1fxH+LTDKzpSJ+pFP8\n/WTnAAC0Ip+SPXYisDHR/YWZXyyiVF4JHuyygt3M7gLweQBpAP/d3aM/nezL53Hg1rcHbYuL83xc\nKvxGj+f4ZFy1pZ/ato0PUNvW0UFqy6Wzwe2ZvgIdgzSf4vmFRWqrNfhrGxsdobZUsx7cXq1W6ZiV\nlRVqyxfCF2cAaEbUtnKlFNw+MjpMx8D5/mrVGrWlEX5fAH5xGRrk7/PAAD8/slk+H5WIjx67IaTC\n50jsNTc8fPH49Je+yQ/DPYjTSYX8b2gnRdwI4INmduOl7k8Isblcznf22wAcdfdj7l4D8HUA92yM\nW0KIjeZygn0Kb0yCOIVfJEv8HDO7z8wOm9nhRj38EVMIsfls+mq8ux9y94PufjCT5d+thBCby+UE\n+zTemPG0C7/IjBJCXGFczmr8jwFca2ZXox3kH0C74gllZWUFz78QLH2GxXPn6LhxsgBqW/jK6Nbm\nELVZYYLalltcFSg1wyvkbjk6przCV1TLFb5CXm9yqelcRHPMZ8I+Nhp8f2myGgwAfX191FZeWaa2\nRiv8um1lCx2Tiqhy9YiaUMjw86BEVrTnmw06pr+fr8Zbin86NaLWAAAicl55Jfz1Nva1N50Jvy/1\nlQodc8nB7u4NM7sfwF+iLb096O7hSBZC9JzL0tnd/VG0UxyFEFc4+rmsEAlBwS5EQlCwC5EQFOxC\nJISuZr2lABQyRDbiCg/2EIlt7yRPCJnYNk5thZi0EslqqlTDCSMrdS4LeWR/uUIkgSaSCOMtfryR\n8XACUKPO95fLcj8iyYhI5/ibVq2F56re4PPRH9lfZoD7mI+Ma1hYHkxFsugakQy1WKbl4ABPviot\nl6mt3ghLbLGEw+LSheD2VjR7VAiRCBTsQiQEBbsQCUHBLkRCULALkRC6uhpv5shbOAFhaIi7ct3U\nWHD7lgLPnMi2eKml0jxPTmm2+PWvUg77nuJ5MBiOlLnKRFaRFy8U+bjIuzY+FF4RLi7xpJVaJKGl\nQpI0gHhdtUFS2qle44kaqSZ/YdlIQk6TlOICgAxZPq9W+Zhclr+hqRZPoKmWFqgNJIkKAPrIadxo\nccXgwnJYkWlG6gnqzi5EQlCwC5EQFOxCJAQFuxAJQcEuREJQsAuRELoqvWXMMNYXPmQhIq2MkCSI\nbcO85leTtB8CEOljAqQzkUJopI5YtRWRfiI6WSaSjNGsconK0/waffZsuMtMs85fdbHMkzTKTS5T\nDhYi3V2qpP0T+GtOGZeN0n2RTizLXGbtz4Z9zERaK61E6gZW6lx6a0Wadi2WuI+L5fD5UyJSLwCs\n1MPnQC1Sa1B3diESgoJdiISgYBciISjYhUgICnYhEoKCXYiE0F3pLW3YNhqWUIayXPLK58O2VJpL\nHYVIfbd6g8tQrUgmV7sz9f9PLVIvrlnjslzLIxllEcnLMzwrq1gLZ7A1m3x+y5FWU42IrbjM/Z+e\nD/uRTfH9DZf43NfP8PZglQtcOrxq6zXB7RMTu+gYGwrXdwOA6sJ5aiuVePbghSKX3s5dCMusx09y\nP5rpcOhWa1yuu6xgN7PjAIpoS9cNdz94OfsTQmweG3Fnv9Pd+WVXCHFFoO/sQiSEyw12B/BdM3vK\nzO4LPcHM7jOzw2Z2OPZTPiHE5nK5H+Pf6e7TZjYB4DEze8ndn1j9BHc/BOAQAIz05/hKlhBiU7ms\nO7u7T3f+ngXwbQC3bYRTQoiN55Lv7GY2ACDl7sXO498C8B9iY7KZNHZuCxciHM5xyWCwPyw1WUS6\nQiQDySLZZtUKl3FSRJbbMsTbUA0M8GytpQt8XXNkmGeUFSNFIE9Mh/dZqnLpLRf5djXVH8nay/LM\nvOPnw9l3VY8UCY1kvY0MD1HbHTdyEWhpJiyzejlyrK08m7Ja5vNRKvF7Z1+W73P39vBrm5iYpGNm\nl8JS3vmXz9Axl/MxfhLAtzu90TIAvuru/+sy9ieE2EQuOdjd/RiAmzfQFyHEJiLpTYiEoGAXIiEo\n2IVICAp2IRJC17PexofC2WiZWliqAYC+bNjN/r5wXzMAqFa4PFWP9OsaHQ33lQMAJ0UKa01+zazX\nI8UQB3kfuNNz4V5eAPDqCZ4NNVcMv7ZI7ULsifTMu/fvHqC2XTu4/9946lhw+4+Ocmmo0eKZfpkU\nl8qKi3PUVi6F53FoiEthaPLsu3yej8uR7EwA6Dc+rtEMvzlX7d5JxwzNh3sBHnmNz4Xu7EIkBAW7\nEAlBwS5EQlCwC5EQFOxCJITursZnMpgY3xK0Veb5qnXKwm6WSNscAKjEanFZpB5bpE0SuzJW6nwV\neXSMJ7TUmnyF+dip09Q2v8R9ZPXp0pGWUcN5vr+JTHjVFwDy81wxuHZ4e3D7zDj3Y3bxLLVVy3yO\nn375ZWpLkRoK9YFI66oRnoCCFA+ZkRGuDg21Iu2mSJ1Cry3RMXtJQllfls+v7uxCJAQFuxAJQcEu\nREJQsAuREBTsQiQEBbsQCaHL0lsWY1u3BW1jg7xdUyoVTiJYXFqgY+rLJb6/Zqz9Ey/I5iQhZ3CQ\n15mrg9tePMYlo+UqbyWUz/dxWy7sY2GAy0JjaS5TPnV0ltoaNX76VEfC0tu2MT4fBi6H1Rtcmi3X\neC28ZVJrrtbgr9kiUmqkOxiyqUjrsFSk9l4mPI+NKpc2nci2JFcLgO7sQiQGBbsQCUHBLkRCULAL\nkRAU7EIkBAW7EAmhq9IbYACR0SzSHofRF6kH1o9wVhAAZCLXuFQqUk+OyHJ9Bd7+6dwZnjVWPsel\nw33jXKKqchUKeSKxXb9/io5JRXbYSPM5XopIn5l0uE7eUI6/L1vG9lPb/muvorbXfvZjanvp5eng\n9lwmIms5l20bDR4yKZJxCADZHJ/HVit8XrUiOp9Z+DyNKINr39nN7EEzO2tmz63aNm5mj5nZK52/\nvEqjEOKKYD0f4/8YwF1v2vYAgMfd/VoAj3f+L4S4glkz2Dv91ufftPkeAA91Hj8E4N4N9ksIscFc\n6gLdpLvPdB6fQbujaxAzu8/MDpvZ4WI58mVTCLGpXPZqvLc7J9Bf5Lr7IXc/6O4Hh/r5opMQYnO5\n1GCfNbMdAND5y4uHCSGuCC5VensEwIcBfKrz9zvrGdRyR2UlXFzP6jxzCQhnKC0v84J8tTq/jjVS\n/BNGqcylsiVim9rNp9EbfH97tnKhZP9OLtWUV/i4qetuDm7POf8KtXCBF+4sjIYLhAIAzvNMrt3b\ndwS3Ly7zbL59b7mW2obHeNbe8NgN1LYwF57/hQu8hVY2Ig+mnGcc1luRbEqeTIlmPXx+R5LoaCuy\nSNLbuqS3rwH4EYDrzeyUmX0E7SD/TTN7BcDf6/xfCHEFs+ad3d0/SEzv2WBfhBCbiH4uK0RCULAL\nkRAU7EIkBAW7EAmhq1lvDkfTwvKEN3kBQCYzFPK8SOXgEJdqTs9xme+1U3PUlsmG/cjN8r5sK7N8\nf9dOcHntPe/iMtSr02/+9fIvGJoKF/TcuiVcABIAzs7xopKjoxEZqsX9z5ECi2fnwlloAJDJL1Lb\n3OIMtU3P8Cy1bDZ8HowOcy2sUuEClmf4/dEiWlkrIsulLDzOIhmYkTaB/DgXP0QI8cuIgl2IhKBg\nFyIhKNiFSAgKdiESgoJdiITQVektnU5hdHQwaGtkuPRWKoUztrzO5YwLRZ7VdOJnXGoqlbiMU8iH\nr40zr/Hsu8k8L0I4NbWH2kZ3Xk1t2WIkhYoU4dx18218yBkuhxUaXDpsgmfSLS+HbTv6w9IgANSa\n/HXZQPi8AYBdAzupbWg0LDkWz5+hY87Onqe2unG5caXGi1gixbWygb5wFmatEpEUSQFLIzIeoDu7\nEIlBwS5EQlCwC5EQFOxCJAQFuxAJoaur8a1mA8XF8EpnpsZrtWVJqxvwEmjIpLmxXOIr9WNDPPFj\ndCC8alpZ4KvxEzt5Dbepm36D2p47VaO2l49y2x07xoPbFxf5mMn94bp1AJBCmdpqVb5SP+rhlfWl\ns3ylu1DjtfB2jIdfFwAsNnlduOxN4WZFlUhizQ8efYTaTp3krzkdafEUa8zE8m7qsTZl9fBcsaQx\nQHd2IRKDgl2IhKBgFyIhKNiFSAgKdiESgoJdiITQVekNANJEgWhGfvTvRLZIkbZQANA0Lr0tcIUH\nS0uR+mPVsHy1Y4TLdb92553Utuv626ntW3/0ILVtjySFpGvh+nrTx17l+9t3I7Xlt1xDbQPO5dLy\nfLj9X6EVlsIAoFbhMt+5IreNbuNJQ1u27w1ur5SG6ZgUN6GZ48k/sRp09TqXPq0RTugy54lejUY4\ndC9LejOzB83srJk9t2rbJ81s2sye6fy7e639CCF6y3o+xv8xgLsC2z/n7gc6/x7dWLeEEBvNmsHu\n7k8A4LWLhRC/FFzOAt39Znak8zGffhEzs/vM7LCZHS6V+fcWIcTmcqnB/gUA+wEcADAD4DPsie5+\nyN0PuvvBwX5etUUIsblcUrC7+6y7N929BeCLAHjNIyHEFcElSW9mtsPdX08beh+A52LP//k4AEaU\ngSbJ4gF4G5xIJx54JbK/SAm38S28bdT2/rDUd+vB6+iYG+7g8trCWS439jV4Zt6+XbuorUVe3PYJ\nXvutscIlzHIkW67W4OPqlfCp1QSXDV+dPkVtzz53mNruuJ37uGV7OOtwqRiWBgGAdIwCAGzdy2XW\nVqxdUy0ioxFJ98Icb4dVLYadbJFsQ2AdwW5mXwPwLgBbzewUgE8AeJeZHQDgAI4D+N219iOE6C1r\nBru7fzCw+Uub4IsQYhPRz2WFSAgKdiESgoJdiISgYBciIXQ1680daJEMn0qVSwY5kuWVyfACf+kU\nl2Ou2c4zr/IFfv3bu2d3cPvN7+SZbTuuv4nanvnRH1HbVbu5j9vf+jZqy23bH9ye6R+hY8orXAKs\nLPHMttnTJ6ltYTYsozXrPHutMBQu6AkAW7fy9/rk6aepbXLHVHB7oxzJsqzwNk62vEBtTQ9nHAKA\nM80ZQKEv/Npy2/lrXuojmaCRiNadXYiEoGAXIiEo2IVICAp2IRKCgl2IhKBgFyIhdFV6MzNk0+FD\nLkQKCjZXwjJDob9Ax6RTXOqYiGS2nZzhmUb7bw1V5wJ2vS28vQ2X0OrFZWobGeJS2bbrDlDbcibc\nE+35p39Mx1Qr3I+lJT4f56Z/Rm3pZlj6zOf5KTd1dVgmA4CbruOFLxtpnomWTY+Gt+d4VmRmhReV\nLJ+YpjYmKwNAI3JbLZG+hP1b+OuaJD0Es9lIfzjughDiVwkFuxAJQcEuREJQsAuREBTsQiSE7ibC\ntFqoVsIrnf193BXLh1crsyleA82b3FYY5K2hfvsf/Da13fHe9wS3D2+dpGNmj71IbemI/4tFXoNu\n7vjfUtvpYnhF+K//7M/omMECT7hYqfKEke2TXDEYHgqvJL92iifP1CLzMb5zL7Vd97a3UxuafcHN\n84u83l2ZqD8AsFDhPprzc3ilwhO9SqRlk5e4KnBDWGRAi4tQurMLkRQU7EIkBAW7EAlBwS5EQlCw\nC5EQFOxCJIT1dITZDeDLACbR7gBzyN0/b2bjAP4EwF60u8K83915gS4ADkfLSW24Fk8isEZYtmh4\npMVTpOZXvm+Y2g68ncs4fdmwRPXCM7wG2sLpV6mtWuXSSnGBd8k+efQFait5ODko2+THGsxwKXI4\nz5Mxto1x6W1m9kxweyPS5qtc5DLfydd40g3wPLWUSuEaevkMPz8afRPUdr7Bz51CgdfQ6x/iSVuF\nTFgeLJaX6JhGKywBRpS3dd3ZGwB+391vBHA7gN8zsxsBPADgcXe/FsDjnf8LIa5Q1gx2d59x9590\nHhcBvAhgCsA9AB7qPO0hAPdulpNCiMvnor6zm9leALcAeBLA5KpOrmfQ/pgvhLhCWXewm9kggG8C\n+Ji7v+HLhLs7yNcFM7vPzA6b2eHlCq/lLoTYXNYV7GaWRTvQv+Lu3+psnjWzHR37DgDBhtfufsjd\nD7r7wYFCbiN8FkJcAmsGu5kZ2i2aX3T3z64yPQLgw53HHwbwnY13TwixUawn6+0dAD4E4Fkze6az\n7eMAPgXgYTP7CIATAN6/9q4cQFhGazX4R/xMNlwzrhmp+VUDz06aHOF14f7ykT+ntvHJsMQzsSPc\nFgoAamWevZbNhiUXABgc4BJPJsWlsgEiD26fCNcsA4BKkSumhTT38fzcOWqr18LvzVCeS1C1Epfe\nXnn6MLXNvPQytVUbpCVTls9hMza/u7gUiQF+Dqf6uPSZJzLaGPhc3fDWq4PbC/ljdMyawe7u3wfA\ncv7COZ9CiCsO/YJOiISgYBciISjYhUgICnYhEoKCXYiE0NWCk3BDqxVe2M9FMq/yGVKsL8ULA3qk\nJVCrxjOvzp0LZ2sBQGkubCvUeXZSC/x1jY9xOWx05zZqazSr1DZ9OuyjR/KhUil+GtQaXMJMGy9U\nOZAPy6UkgbG9v5gxksXYrHF5M0XOt6UylxtrfUSuAzC0k8/9coG3yiq2uCy3shy+524Z3kfHbCVS\naibL30vd2YVICAp2IRKCgl2IhKBgFyIhKNiFSAgKdiESQnelNxhSFs6iyvfxDB8nGWwDhbC8AwAD\nQ1uprVznGUhbhnjOfYb4UbswS8e0Unx/5SyXmiYnw1lNANCqcRnn+pt2Bbf/8HuP0zE1L1Nb1ri8\nWSnxccND4ay9XIafcmmL9ENb4e/ZazNcRltcDL9nVVumY7Zdx++BU6ORrD3n7/XCOT5XuZWwhDkw\nFclULIezClsR9VJ3diESgoJdiISgYBciISjYhUgICnYhEkJXV+NTBuQy4etLucoTDNKkBVErUh+t\nXOfJDOksT6roy/HV1mw27Eeun7dBGhnmCTln5vgqfnkqvKoOABO7r6G26bPhunBv/bV30DGludPU\nduxl3lppucQTPzLp8PyPjPDaekbqEwLAzDT38WcnIokwfeH5H57kSs628YiPEVXA5vl7PbbAQ21q\nYjy4fdcoPweOvhBOeKpWeJKX7uxCJAQFuxAJQcEuREJQsAuREBTsQiQEBbsQCWFN6c3MdgP4Mtot\nmR3AIXf/vJl9EsDvAJjrPPXj7v5o9GAZw+S28PWlfv48HVdphiWZZZ7LAE/x1lCZSDLG8DBPPsiR\n1kqVZV6DrhCpCYYatx3+4Q+pbd/1XLI7dSosyaQi9fr6+3gtuXRE3iwUuNS0XApLb5UKl0QbkRZg\ngwXuxx23XEdteZKQ00jz2nrNOk9aqZzk0luqmKe2if4harvlureGx4zyLuhPzbwW3N6o89e1Hp29\nAeD33f0nZjYE4Ckze6xj+5y7/+d17EMI0WPW0+ttBsBM53HRzF4EMLXZjgkhNpaL+s5uZnsB3ALg\nyc6m+83siJk9aGa8NaoQouesO9jNbBDANwF8zN2XAHwBwH4AB9C+83+GjLvPzA6b2eGlMv9OJoTY\nXNYV7GaWRTvQv+Lu3wIAd59196a7twB8EcBtobHufsjdD7r7weF+XslDCLG5rBnsZmYAvgTgRXf/\n7KrtO1Y97X0Antt494QQG8V6VuPfAeBDAJ41s2c62z4O4INmdgBtOe44gN9da0e5nOGq3eG7+4hx\n2eLoybAUMjvHs9dqTS7VDA7yl71c5hlUzVYpuD0duWbOz3FJsVjiMslKnfuRdm4bGgwvncyemadj\nTi1zOanlXLKb3MZlSmuFs68WFnm9uL4B/p6NjnDpKpfm81+tEQk2w+XG5SrfX60UaXnV4uOu2b2d\n2nZuD8/jyVNcYj0/F46JRqSF1npW478PIPSORzV1IcSVhX5BJ0RCULALkRAU7EIkBAW7EAlBwS5E\nQuhqwcl0xjA8RjLHiJQAAGMT6bBhgBcNPDfLC1iuRNonZXK82CAb1qrzDLt6k/txocJlqIFIltdK\nmUtllZVwwclaxMdmxOZO5h5AaSnS/mk4XLhzeJgX56xU+P7OnedzNTjIs+8sFb6fWYPLtrkMLzra\nxxVi5HJ8rvZes5faKuWwL0888QIdc+Tls+F9rXA5V3d2IRKCgl2IhKBgFyIhKNiFSAgKdiESgoJd\niITQVenNzJDJhw+ZH+a57uOD4WtSpsJlrWyBZ/8sRfpuocmvf4X8RHhIlh+rWeX90HL93I9shs9H\nOs0lx6qHfanVudzokcw24woVvMYlwCYxZSPZZshxuXFxgUtvlRrvbzYyGpZSM0SSA4BUZO7L4NLW\n7LkitS1EMhyLy+Esxr/665f4sYhKuVKT9CZE4lGwC5EQFOxCJAQFuxAJQcEuREJQsAuRELoqvbVa\nhhIr2JcepOMGB8I6TrbAdaGBSHrSyAiXykpLvBdZaSlcALBUjmS9rXDbUI4XbMyTvnIA0KhyyTGT\nCV+/c5HLeraPZ2uZ8YH9kcKdKWJqNLk0lCtEevCNcrlxfp5LXkUiRQ6P87kvR3rOvXKcFxB96dmT\n1DY5zrMpJ3eR15bi5+lWUoBztshlSN3ZhUgICnYhEoKCXYiEoGAXIiEo2IVICGuuxptZHsATAPo6\nz/+Gu3/CzK4G8HUAWwA8BeBD7h5t01qrAadOhG3VRb56PrQtvIKbL0QSIPjiPsbH+csuLfM6aIuL\nYdvCeZ44scAXb5Fu8VXwlnOlodnkK/xohW2xq7qleCJMOsPnqhJJGnKy6J4lbaEAoFHmLaqakfp0\nzUhyzWIpPI51hQKA+Ygic/wof0MXzy9TW22ZH3D7SLg11A17pugY5uIrZ5bomPXc2asA3u3uN6Pd\nnvkuM7sdwKcBfM7drwGwAOAj69iXEKJHrBns3ub1jobZzj8H8G4A3+hsfwjAvZvioRBiQ1hvf/Z0\np4PrWQCPAXgVwKL7zz+snQLAP3MIIXrOuoLd3ZvufgDALgC3AXjLeg9gZveZ2WEzO3yhxIsdCCE2\nl4tajXf3RQDfA/B3AIya2eurN7sATJMxh9z9oLsfHBmMVNgXQmwqawa7mW0zs9HO4wKA3wTwItpB\n//c7T/swgO9slpNCiMtnPYkwOwA8ZGZptC8OD7v7n5vZCwC+bmb/EcDTAL601o7cMmhmtwZt9dxB\nOq7aCid+pBrhVkcAkB/hctLoNv4JYyzFEzXGy+HEhMV53i5o8RyX1yrLfPqbDS7nwfk1utUI+7hS\n4V+hcrlIvbsM97+4whM1KuQrWzaizg6lwskdANBKcUmpXufz2DcQljDzWV7vbjTHfdyHUWp72828\nDdX1N91MbXuvuSa4/bbbudx46nQpuP0Hr/KYWDPY3f0IgFsC24+h/f1dCPFLgH5BJ0RCULALkRAU\n7EIkBAW7EAlBwS5EQjCPZFdt+MHM5gC8nve2FQDXCbqH/Hgj8uON/LL5scfdt4UMXQ32NxzY7LC7\nc3FdfsgP+bGhfuhjvBAJQcEuRELoZbAf6uGxVyM/3oj8eCO/Mn707Du7EKK76GO8EAlBwS5EQuhJ\nsJvZXWb2t2Z21Mwe6IUPHT+Om9mzZvaMmR3u4nEfNLOzZvbcqm3jZvaYmb3S+TvWIz8+aWbTnTl5\nxszu7oIfu83se2b2gpk9b2Yf7Wzv6pxE/OjqnJhZ3sz+xsx+2vHj33e2X21mT3bi5k/MLJIHHcDd\nu/oPQBrtGnb7AOQA/BTAjd32o+PLcQBbe3DcXwdwK4DnVm37TwAe6Dx+AMCne+THJwH8iy7Pxw4A\nt3YeDwF4GcCN3Z6TiB9dnRMABmCw8zgL4EkAtwN4GMAHOtv/EMA/uZj99uLOfhuAo+5+zNt15r8O\n4J4e+NEz3P0JAG8ukn4P2lV6gS5V6yV+dB13n3H3n3QeF9GuhDSFLs9JxI+u4m02vKJzL4J9CsDq\n3ra9rEzrAL5rZk+Z2X098uF1Jt19pvP4DIDJHvpyv5kd6XzM3/SvE6sxs71oF0t5Ej2ckzf5AXR5\nTjajonPSF+je6e63AngvgN8zs1/vtUNA+8qO9oWoF3wBwH60G4LMAPhMtw5sZoMAvgngY+7+hjpU\n3ZyTgB9dnxO/jIrOjF4E+zSA3av+TyvTbjbuPt35exbAt9HbMluzZrYDADp/z/bCCXef7ZxoLQBf\nRJfmxMyyaAfYV9z9W53NXaZzBz8AAAD2SURBVJ+TkB+9mpPOsS+6ojOjF8H+YwDXdlYWcwA+AOCR\nbjthZgNmNvT6YwC/BeC5+KhN5RG0q/QCPazW+3pwdXgfujAnZmZoFyx90d0/u8rU1TlhfnR7Tjat\nonO3VhjftNp4N9orna8C+Dc98mEf2krATwE8300/AHwN7Y+DdbS/e30E7QaZjwN4BcBfARjvkR//\nA8CzAI6gHWw7uuDHO9H+iH4EwDOdf3d3e04ifnR1TgDchHbF5iNoX1j+3apz9m8AHAXwpwD6Lma/\n+rmsEAkh6Qt0QiQGBbsQCUHBLkRCULALkRAU7EIkBAW7EAlBwS5EQvh/5M/Ej8eN80MAAAAASUVO\nRK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QJkTEQfjY5P2",
"colab_type": "code",
"colab": {}
},
"source": [
"batch_size = 32\n",
"num_train = train_labels.shape[0]\n",
"num_batches = math.ceil(num_train / batch_size)\n",
"def my_inputs(n_devices):\n",
" rng = npr.RandomState(0)\n",
" while True:\n",
" perm = rng.permutation(num_train)\n",
" for i in range(num_batches):\n",
" batch_idx = perm[i * batch_size:(i + 1) * batch_size]\n",
" batch_images, batch_labels = train_images[batch_idx], train_labels[batch_idx]\n",
" \n",
" batch_images = (batch_images / 255.0).astype('float')\n",
" yield batch_images, batch_labels"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "OsXjmnPx_8sY",
"colab_type": "code",
"colab": {}
},
"source": [
"from trax import layers as tl\n",
"\n",
"\n",
"def ConvBlock(kernel_size, filters, strides, norm, non_linearity,\n",
" mode='train'):\n",
" \"\"\"ResNet convolutional striding block.\"\"\"\n",
" # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.\n",
" ks = kernel_size\n",
" filters1, filters2, filters3 = filters\n",
" main = [\n",
" tl.Conv(filters1, (1, 1), strides),\n",
" norm(mode=mode),\n",
" non_linearity(),\n",
" tl.Conv(filters2, (ks, ks), padding='SAME'),\n",
" norm(mode=mode),\n",
" non_linearity(),\n",
" tl.Conv(filters3, (1, 1)),\n",
" norm(mode=mode),\n",
" ]\n",
" shortcut = [\n",
" tl.Conv(filters3, (1, 1), strides),\n",
" norm(mode=mode),\n",
" ]\n",
" return [\n",
" tl.Residual(main, shortcut=shortcut),\n",
" non_linearity()\n",
" ]\n",
"\n",
"\n",
"def IdentityBlock(kernel_size, filters, norm, non_linearity,\n",
" mode='train'):\n",
" \"\"\"ResNet identical size block.\"\"\"\n",
" # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.\n",
" ks = kernel_size\n",
" filters1, filters2, filters3 = filters\n",
" main = [\n",
" tl.Conv(filters1, (1, 1)),\n",
" norm(mode=mode),\n",
" non_linearity(),\n",
" tl.Conv(filters2, (ks, ks), padding='SAME'),\n",
" norm(mode=mode),\n",
" non_linearity(),\n",
" tl.Conv(filters3, (1, 1)),\n",
" norm(mode=mode),\n",
" ]\n",
" return [\n",
" tl.Residual(main),\n",
" non_linearity(),\n",
" ]\n",
"\n",
"\n",
"def Resnet50(d_hidden=64, n_output_classes=10, mode='train',\n",
" norm=tl.BatchNorm,\n",
" non_linearity=tl.Relu):\n",
" \"\"\"ResNet.\n",
" Args:\n",
" d_hidden: Dimensionality of the first hidden layer (multiplied later).\n",
" n_output_classes: Number of distinct output classes.\n",
" mode: Whether we are training or evaluating or doing inference.\n",
" norm: `Layer` used for normalization, Ex: BatchNorm or\n",
" FilterResponseNorm.\n",
" non_linearity: `Layer` used as a non-linearity, Ex: If norm is\n",
" BatchNorm then this is a Relu, otherwise for FilterResponseNorm this\n",
" should be ThresholdedLinearUnit.\n",
" Returns:\n",
" The list of layers comprising a ResNet model with the given parameters.\n",
" \"\"\"\n",
"\n",
" # A ConvBlock configured with the given norm, non-linearity and mode.\n",
" def Resnet50ConvBlock(filter_multiplier=1, strides=(2, 2)):\n",
" filters = (\n",
" [filter_multiplier * dim for dim in [d_hidden, d_hidden, 4 * d_hidden]])\n",
" return ConvBlock(3, filters, strides, norm, non_linearity, mode)\n",
"\n",
" # Same as above for IdentityBlock.\n",
" def Resnet50IdentityBlock(filter_multiplier=1):\n",
" filters = (\n",
" [filter_multiplier * dim for dim in [d_hidden, d_hidden, 4 * d_hidden]])\n",
" return IdentityBlock(3, filters, norm, non_linearity, mode)\n",
"\n",
" return tl.Serial(\n",
" tl.ToFloat(),\n",
" tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),\n",
" norm(mode=mode),\n",
" non_linearity(),\n",
" tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),\n",
" Resnet50ConvBlock(strides=(1, 1)),\n",
" [Resnet50IdentityBlock() for _ in range(2)],\n",
" Resnet50ConvBlock(2),\n",
" [Resnet50IdentityBlock(2) for _ in range(3)],\n",
" Resnet50ConvBlock(4),\n",
" [Resnet50IdentityBlock(4) for _ in range(5)],\n",
" Resnet50ConvBlock(8),\n",
" [Resnet50IdentityBlock(8) for _ in range(2)],\n",
" tl.AvgPool(pool_size=(1, 1)),\n",
" tl.Flatten(),\n",
" tl.Dense(n_output_classes),\n",
" tl.LogSoftmax(),\n",
" )\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "DJzaj9RcahJq",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 71
},
"outputId": "decd8bb8-6675-461a-893c-b17d666ecbb2"
},
"source": [
"output_dir = os.path.expanduser('~/train_dir/')\n",
"!rm -f ~/train_dir/model.pkl # Remove old model.\n",
"trainer = trax.supervised.Trainer(\n",
" model=Resnet50,\n",
" loss_fn=trax.layers.CrossEntropyLoss,\n",
" optimizer=trax.optimizers.Adam,\n",
" lr_schedule=trax.lr.EvalAdjustingSchedule,\n",
" inputs=trax.supervised.inputs.Inputs(my_inputs),\n",
" output_dir=output_dir,\n",
" has_weights=True)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py:4578: UserWarning: Explicitly requested dtype <class 'numpy.int64'> requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
" warnings.warn(msg.format(dtype, fun_name , truncated_dtype))\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "qBW2cKNEgjdL",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 612
},
"outputId": "f5dceb5b-ef08-4d5a-e0de-3a78a8b93f58"
},
"source": [
"n_epochs = 1\n",
"train_steps = num_batches\n",
"eval_steps = 1\n",
"for _ in range(n_epochs):\n",
" trainer.train_epoch(train_steps, eval_steps)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "LayerError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/trax/layers/base.py\u001b[0m in \u001b[0;36m_forward_internal\u001b[0;34m(self, x, weights, state, rng)\u001b[0m\n\u001b[1;32m 454\u001b[0m outputs, s = self.forward_with_state(\n\u001b[0;32m--> 455\u001b[0;31m x, weights=weights, state=state, rng=rng)\n\u001b[0m\u001b[1;32m 456\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/trax/layers/combinators.py\u001b[0m in \u001b[0;36mforward_with_state\u001b[0;34m(self, xs, weights, state, **kwargs)\u001b[0m\n\u001b[1;32m 58\u001b[0m state=base.EMPTY_STATE, **kwargs):\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_forward_inputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0mrngs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_pop_rng_and_split\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_layers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/trax/layers/combinators.py\u001b[0m in \u001b[0;36m_validate_forward_inputs\u001b[0;34m(self, xs)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;34m'number of inputs ({}) to Serial.forward less than n_in'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 137\u001b[0;31m ' ({})'.format(len(xs), self.n_in))\n\u001b[0m\u001b[1;32m 138\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: number of inputs (2) to Serial.forward less than n_in (3)",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mLayerError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-8-613c5ccb5042>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0meval_steps\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_epochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/trax/supervised/trainer_lib.py\u001b[0m in \u001b[0;36mtrain_epoch\u001b[0;34m(self, n_steps, n_eval_steps)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_devices\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# TODO(lukaszkaiser): use everywhere if possible.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_reshape_by_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_devices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 308\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 309\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_should_save_now\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 310\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkeep\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/trax/supervised/trainer_lib.py\u001b[0m in \u001b[0;36mtrain_step\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;31m# Run the update.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 339\u001b[0m (weights, slots), self._model_state, self._rngs = self._jit_update_fn(\n\u001b[0;32m--> 340\u001b[0;31m self._step, opt_state, batch, self._model_state, self._rngs)\n\u001b[0m\u001b[1;32m 341\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_model_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_map_to_state_dicts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state_dicts_update\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_opt_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopt_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_replace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mslots\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mslots\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/api.py\u001b[0m in \u001b[0;36mf_jitted\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0m_check_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs_flat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 150\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxla\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mxla_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 151\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, f, *args, **params)\u001b[0m\n\u001b[1;32m 592\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtop_trace\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 593\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mnew_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 594\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 595\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 596\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(fun, *args, **params)\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'device'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 360\u001b[0m \u001b[0mbackend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'backend'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 361\u001b[0;31m \u001b[0mcompiled_fun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_xla_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mabstractify\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 362\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 363\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcompiled_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpopulate_stores\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 209\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 210\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_callable\u001b[0;34m(fun, device, backend, *abstract_args)\u001b[0m\n\u001b[1;32m 373\u001b[0m \u001b[0mpvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPartialVal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0maval\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mabstract_args\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 374\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_master\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mJaxprTrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 375\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_subjaxpr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 376\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0menv\u001b[0m \u001b[0;31m# no subtraces here\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 377\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mgen\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 153\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 154\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/trax/supervised/trainer_lib.py\u001b[0m in \u001b[0;36msingle_update\u001b[0;34m(i, opt_state, batch, state, rng)\u001b[0m\n\u001b[1;32m 682\u001b[0m \u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msubrng\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax_random\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 683\u001b[0m \u001b[0mgrad_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_and_loss_call\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 684\u001b[0;31m \u001b[0mgrads\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgrad_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrng\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 685\u001b[0m return optimizer.tree_update(\n\u001b[1;32m 686\u001b[0m i, grads, weights, slots, opt_params), state, [subrng]\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/api.py\u001b[0m in \u001b[0;36mgrad_f_aux\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdocstr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdocstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margnums\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margnums\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f_aux\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 352\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 353\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 402\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 403\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 404\u001b[0;31m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 405\u001b[0m \u001b[0m_check_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 406\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/api.py\u001b[0m in \u001b[0;36mvjp\u001b[0;34m(fun, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 1259\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1260\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_aux_trees\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun_nokwargs2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1261\u001b[0;31m \u001b[0mout_primal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_vjp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimals_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1262\u001b[0m \u001b[0mout_tree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_aux_trees\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1263\u001b[0m \u001b[0mout_primal_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_primal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mvjp\u001b[0;34m(traceable, primals, has_aux)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 109\u001b[0;31m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 110\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvjp_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mcts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0mcts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mignore_consts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mlinearize\u001b[0;34m(traceable, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0mpval_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpval_tangents\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0maval_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconst_primals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munzip2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpval_primals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr\u001b[0;34m(fun, pvals, **kwargs)\u001b[0m\n\u001b[1;32m 341\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mnew_master\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mJaxprTrace\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0mfun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstantiate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 343\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 344\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 345\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mgen\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 153\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 154\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/trax/supervised/trainer_lib.py\u001b[0m in \u001b[0;36mmodel_and_loss_call\u001b[0;34m(weights, batch, state, rng)\u001b[0m\n\u001b[1;32m 675\u001b[0m \u001b[0;31m# Gradients are always wrt. the first argument, so putting weights first.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 676\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmodel_and_loss_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrng\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 677\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_and_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweights\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrng\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 678\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_and_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 679\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn_devices\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# TODO(lukaszkaiser): remove branch when not needed.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/trax/layers/base.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, x, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn_accelerators\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0mforward\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjit_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_accelerators\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrng\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn_accelerators\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mreplicate\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# Unreplicate state if needed.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0mnew_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnested_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/trax/layers/base.py\u001b[0m in \u001b[0;36m_forward_internal\u001b[0;34m(self, x, weights, state, rng)\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_short_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 463\u001b[0m raise LayerError(name, '_forward_internal',\n\u001b[0;32m--> 464\u001b[0;31m self._caller, signature(x), trace)\n\u001b[0m\u001b[1;32m 465\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 466\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_forward_abstract\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_signature\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mLayerError\u001b[0m: Exception passing through layer Serial (in _forward_internal):\n layer created in file [...]/trax/supervised/trainer_lib.py, line 674\n layer input shapes: (ShapeDtype{shape:(32, 32, 32, 3), dtype:float32}, ShapeDtype{shape:(32, 10), dtype:float32})\n\n File [...]/trax/layers/combinators.py, line 59, in forward_with_state\n self._validate_forward_inputs(xs)\n\n File [...]/trax/layers/combinators.py, line 137, in _validate_forward_inputs\n ' ({})'.format(len(xs), self.n_in))\n\nValueError: number of inputs (2) to Serial.forward less than n_in (3)"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment