Skip to content

Instantly share code, notes, and snippets.

@justheuristic
Last active August 1, 2021 08:52
Show Gist options
  • Save justheuristic/1118a14a798b2b6d47789f7e6f511abd to your computer and use it in GitHub Desktop.
Save justheuristic/1118a14a798b2b6d47789f7e6f511abd to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jheuristic/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import tfnn\n",
"from tfnn.layers import Dense\n",
"from concrete_gate import ConcreteGate\n",
"\n",
"class MyAutoencoder:\n",
" def __init__(self, name, inp_size, hid_size, **kwargs):\n",
" with tf.variable_scope(name):\n",
" self.first = Dense('first', inp_size, hid_size, activ=tf.tanh)\n",
" self.gate = ConcreteGate('gate', shape=[1, hid_size], **kwargs)\n",
" self.second = Dense('second', hid_size, inp_size, activ=lambda x:x)\n",
" \n",
" def __call__(self, x):\n",
" return self.second(self.gate(self.first(x)))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_VISIBLE_DEVICES=\n"
]
}
],
"source": [
"%env CUDA_VISIBLE_DEVICES=\n",
"tf.reset_default_graph()\n",
"sess = tf.InteractiveSession()\n",
"\n",
"my_net = MyAutoencoder('network', 64, 100, l0_penalty=1e-3, hard=False)\n",
"x = tf.random_normal([5, 64])\n",
"x_rec = my_net(x)\n",
"loss = tf.reduce_mean(tf.squared_difference(x, x_rec))\n",
"\n",
"reg = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))\n",
"# or manually: reg = my_net.gate.get_penalty()\n",
"\n",
"step = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(loss + reg)\n",
"\n",
"sess.run(tf.global_variables_initializer())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rate_of_nonzero_activations = 0.66\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEACAYAAABfxaZOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XmYFNXV+PHvGTYXkFVA9mVUFAVRgyhGx4iIcYtJTMBfEE1UIvK6ZYFoCEPiG7e4JBEwC3ELBpeYNyquUUY0EQUVUcKurIPsyDoMDOf3x61mepqe6eqe6q5ezud5+pnuqupbp3tun66+deteUVWMMcYUhqKwAzDGGJM5lvSNMaaAWNI3xpgCYknfGGMKiCV9Y4wpIJb0jTGmgCRM+iIyRUTWici8BNt9RUT2isg3gwvPmPoTkSEislBEFovImDjrR4rIPBH5SERmikgvb3lDEXnUWzdfRMZmPnpjguXnSP8R4Py6NhCRIuAu4NUggjImKF7dfAhXh3sDwyJJPcpUVe2jqv2Ae4EHvOWXA41VtQ9wKjBSRLpkKHRj0iJh0lfVd4AtCTb7H+BZYH0QQRkToP7AElVdoap7gWnApdEbqOqOqIdNgf2RVcDhItIAOAzYA2xLf8jGpE+92/RFpAPwDVWdDEj9QzImUB2BVVGPV3vLahCRUSKyFPeL9UZv8bPALmAtsBz4japuTWu0xqRZECdyHwSi20kt8Zuco6qTVLUYV5fHeYv7A/uA9kAP4Mci0i2UAI0JSMMAyjgVmCYiArQBLhCRvar6fOyGImID/Zi0UtXYg441QHQ7fCdvWW2eAiZ7968AXlHV/cAGEfk3rr4vj32S1W2TbnHqdkr8HukLtRzBq2oP79Yd93N4VLyEH7V9vW7jx4+3MvKwjCBiqMVsoFhEuopIY2AoUKN+ikhx1MOLgCXe/ZXA17xtDgcGAAutbmd/DPlWRpASHumLyJNACdBaRFYC44HGro7rH2PrfaDRGVNPqlolIqOB13AHOVNUdYGITABmq+qLwGgRGQRU4jotjPCePhF4REQ+9R5PUdVPMSaHJUz6qnqF38JU9fv1C8eY4KnqK8CxMcvGR92/uZbn7QS+k97ojMmsnLsit6SkxMrIwzKCiCHXZcP/IYgysiGGfCsjSBJ0e1GdOxPRTO7PFBYRQQM62ZXCvq1um7QJsm4H0XvHZEi3bt1YsWJF2GGErmvXrixfvjzsMEyaFHI9z0TdtiP9HOJ924cdRuhqex/sSD8/FHI9z0Tdzrk2fWOMMamzpG+MMQXEkr4xxhQQS/omEN27d+fNN98MOwxjTAKW9I0xpoBY0jfGmBRUVVWFHUJKLOmbQFVWVnLzzTfTsWNHOnXqxC233MLevXsB2LRpExdffDEtW7akdevWnH322Qeed/fdd9OpUyeOOOIIjjvuOGbMmBHWSzCmVt27d+eee+6hb9++NG3alFWrVvGtb32Ltm3b0rNnT37/+98f2LaiooIRI0bQqlUrevfuzb333kvnzp1DjN6xi7NMoO644w7ef/995s1zUypfcskl3HHHHUyYMIH77ruPzp07s2nTJlSVWbNmAbB48WImTpzIBx98QLt27Vi5cmXOHkWZ/Ddt2jRefvllWrZsyVe/+lUuu+wynn76aVatWsWgQYPo1asX5513HqWlpaxcuZLly5ezY8cOLrjgAtwI9OGyI/08IhLMrT6efPJJxo8fT+vWrWndujXjx4/niSeeAKBRo0asXbuWzz//nAYNGjBw4EAAGjRoQGVlJZ9++in79u2jS5cudO/evb5vh8lTYdfzm266iQ4dOjBv3jw2btzI7bffToMGDejWrRvXXHMN06ZNA+CZZ57h9ttv54gjjqBDhw7ceOONCUrODEv6eUQ1mFsqIlcSlpeX06VL9ZwlXbt2pby8HICf/OQn9OzZk8GDB1NcXMzdd98NQM+ePXnwwQcpLS2lXbt2XHHFFaxdu7be74fJT2HWc4BOnToBsHLlStasWUOrVq1o1aoVLVu25M4772T9ejdVeHl5+YFtgaxo2gFL+iZAIkLHjh
1rjJuyYsUKOnToAEDTpk35zW9+w7Jly3j++ee5//77D7TdDx06lLfffvvAc8eOHZv5F2CMD5Emms6dO9OjRw82b97M5s2b2bJlC19++SUvvPACAB06dGD16tUHnrdy5cpQ4o1lSd8EIjJeyNChQ7njjjvYuHEjGzdu5Fe/+hXDhw8HYPr06SxbtgyAZs2a0bBhQ4qKili8eDEzZsygsrKSxo0bc+ihh1JUZFXTZLf+/fvTrFkz7rnnHioqKqiqqmL+/PnMmTMHgMsvv5w777yTrVu3smbNGiZOnBhyxI59skwgIkc/48aN45RTTqFPnz707duXU089ldtvvx2AJUuWMGjQIJo1a8bAgQO54YYbOPvss9mzZw9jx47lyCOPpEOHDmzYsIE777wzzJdjTFzRJ2KLiop48cUXmTt3Lt27d6dt27Zce+21bNu2DYBf/OIXdOzYke7duzN48GAuv/xymjRpElboB9gomzmkkEcfjGajbOa3fK3nDz/8ME899VSd3ZFtlE1jjMlRX3zxBf/5z39QVRYtWsR9993HN7/5zbDDsn76xhiTDpWVlYwcOZLly5fTokULhg0bxvXXXx92WNa8k0vy9Wdvsqx5J78Vcj235h1jjDGBsqRvjDEFJGHSF5EpIrJORObVsv4KEfnYu70jIicGH6Yxxpgg+DmR+wjwe+DxWtZ/Bpylql+KyBDgT8CAgOIzUbp27ZoVAzaFrWvXrmGHYNKokOt5Juq2rxO5ItIVeEFV+yTYrgXwiarGHWTCTnaZdLITuSZfZfOJ3GuAlwMu05i8du+9cM01YUdhCkVg/fRF5BzgauDMurYrLS09cL+kpISSkpKgQjAFpqysjLKysrDDqLdHHoEFC+D+++GII8KOxuS7QJp3RKQP8HdgiKouq6Mc+wls0qa2n8DeuaYHcb9sp6jq3THrRwI3AFXAduA6VV0oIlcAPwEUEKAP0E9VD+rUkGrd3rMHWrWCzp3hL3+BM85IughTAMJo3hHvFi+YLriEP7yuhG9MGESkCHgIOB/oDQwTkV4xm01V1T6q2g+4F3gAQFWfVNV+qnoyMBz4LF7Cr4/33oNjj4XTT4fzzw+yZGPi89Nl80ngP8AxIrJSRK4WkZEicp23yTigFTBJRD4SkffTGK8xyeoPLFHVFaq6F5gGXBq9garuiHrYFNgfp5xh3nMDUV7ujvJfegmGDIH77oMdO2DNmqD2YEx8Cdv0VfWKBOuvBa4NLCJjgtURWBX1eDXui6AGERkF3Ao0Ar4Wp5zvApcEFlRHd/J28WL4+c9dE89ll8Gzz8JNNwW1F2MOZlfkGgOo6iRVLQbG4H69HiAi/YGdqvrfIPa1b5/7O3UqzJwJ/b2voK9+FW6+uX5T+RmTiI2yafLdGqBL1ONO3rLaPAU8HLNsKPC3RDvy2zNt6FD3d/dud795c/f45pth4kSYNw/69k20N5PP0tkzzUbZNHkjXg8HEWkALALOBdYC7wPDVHVB1DbFqrrUu38xME5V+3uPBdc8dKaqLq9j377r9tlnu6P5Cy+EW26Bxo2r111zDZx8Mowa5asoUyCC7L1jR/omr6lqlYiMBl6jusvmAhGZAMxW1ReB0SIyCKgEtgAjooo4C1hZV8JP1urVMH069IrtQwSceirMnh3Unow5mB3pm7yRC8MwVFZCs2aup06jRgevLyuDcePg7beDj9HkrmwehsEYU4eVK13PnXgJH+Doo2HJkszGZAqLJX1jMmjFCqhrIMWjjoLt293NmHSwpG9MBpWXQ4cOta8vKoKePWHp0szFZAqLJX1jMmjt2rqTPkBxsbtoy5h0sKRvTAaVl7smnLoMHAivv56ZeEzhsaRvTAatWQOdOtW9zeDB8M47mYnHFB5L+sZk0OrViZN+9+6ul4/1bjbpYEnfmAzyk/SbNoVDD4UNGzITkykslvSNyZB9+2DdusRt+uC6da5Ykf6YTOGxpG9MhnzxBRx5ZO0XZkWzpG/SxZK+MR
myapW7GtcPS/omXSzpG5Mh8+fDccf527Z7d/jss/TGYwqTJX1jMuSTT6BPH3/b9u7tviSMCZolfWMyZPly6NHD37a9esHChWkNxxQoS/rGZMjy5dCtm79tO3aEnTth69Z0RmQKkSV9YzJAFT7/vO4RNqOJwDHHwKJF6Y3LFB5L+sZkwAcfQPv20KqV/+f06mVJ3wTPkr4xGfD++1DLPOm1OvZYS/omeAmTvohMEZF1IjKvjm1+JyJLRGSuiJwUbIjG5L4NG6Bdu+SeY0nfpIOfI/1HgPNrWykiFwA9VfVoYCTwcECxGZM3NmxwV+Mm47jj4L//TU88pnAlTPqq+g6wpY5NLgUe97Z9D2guIkke0xiT31JJ+r16uZO/FRXpickUpiDa9DsCq6Ier/GWGWM8qST9xo3dROl2kZYJUsNM77C0tPTA/ZKSEkqSPbtljKesrIyysrKww/AllaQPrl1/yRI45ZTgYzKFSdTHTA0i0hV4QVUPuohcRB4GZqjqU97jhcDZqrouzrbqZ3/GpEJEUFUJad911u327eHDDxPPjxtr7Fho1gxuv72eAZqcFmTd9tu8I94tnueBK73ABgBb4yX8iP37k4rPmJy3fz9s2gRt2iT/3J49Ydmy4GMyhSth846IPAmUAK1FZCUwHmgMqKr+UVVfEpGvi8hSYCdwdd3l1T9oY3LJpk1wxBGujT5ZPXrA1KnBx2QKV8Kkr6pX+NhmdDDhGJN/li9PPEVibWxcfRM0uyLXmDSbORMGDkztuZ07Q3k5VFUFG5MpXJb0jUmzOXNgwIDUntukiTsXUF4ebEymcGU86e/Zk+k9mkInIkNEZKGILBaRMXHWjxSReSLykYjMFJFeUev6iMh/RORTEflYRJJumd+40fXeSVXXrq6JyJggZDzpv/depvdoCpmIFAEP4YYS6Q0Mi07qnqmq2kdV+wH3Ag94z20APAFcp6on4Do07E02ho0boXXr1F+DteubIGU86Vs3fZNh/YElqrpCVfcC03BDhxygqjuiHjYFIh2LBwMfq+qn3nZbUrnQZOPG1LprRnTrZkf6JjiW9E2+ix0mZDVxhgkRkVFet+O7gBu9xcd4614RkTki8pNUAti0qX5H+t26uTF4jAlCxodh+PxzOOecTO/VmLqp6iRgkogMBcYBV+E+HwOBU4EK4A0RmaOqM+KVEW+IkZ073YHO4YenHlv37vD006k/3+SedA4x4msYhsB2JqInnaR89FHGdmkKSLxL1b2rxEtVdYj3eCzuwsK7aylDgC2q2kJEvgsMUdWrvXU/B3ar6n1xnhe35WfVKjj9dFi9OvXXtWQJnH8+fPZZ6mWY3BbGMAyBseYdk2GzgWIR6er1vBmKGzrkABEpjnp4EbDYu/8qcKKIHCIiDYGzgaRGuN+6FVq2TDl2wJ3IXbMG9u2rXznGQAjNO5b0TSapapWIjAZewx3kTFHVBSIyAZitqi8Co0VkEFCJmztihPfcrSJyPzAHd3J3uqq+nMz+t2yBFi3q9xoaN3ZdPletck09xtSHJX2T91T1FeDYmGXjo+7fXMdznwSeTHXfW7bU/0gf3Bg8y5ZZ0jf1l/HmHRtl0xSSrVvrf6QPUFzs2vaNqS9r0zcmjYI60j/+eJsv1wQj40nfBo4yhSSopN+7t02baIKR8aS/aFGm92hMeFKdJjHW8cfDggX1L8cYG2XTmDRavz6YpN+hgzs/sHNn/csyhc2SvjFptGEDtG1b/3KKimw4BhMMS/rGpFFQzTvg5su1q3JNfVnSNyaN6jvYWrRIX31j6sOSvjFptG2bmxQ9CD162JG+qT9L+sakyb59sHcvHHpoMOVZ844JgiV9Y9Jk+3Zo2hQkkLERrXnHBMOSvjFpsn07NGsWXHldu8LKlXZVu6kfX0nfx8TSnUXkTRH5UETmisgFwYdqTG4JOuk3bQqHHeb6/huTqoRJ3+fE0j8HnlLVk4FhwKSgAzUm1wSd9MGNsmlNPKY+/BzpJ5
xYGjfWeKSPQgtgTXAhGpOb0pH0e/e2gddM/fgZTz/exNL9Y7aZALwmIjcChwGDggnPmNyVjqR/wgk28Jqpn6AmURkGPKKqD3hzkv4V1xQURymR+aMjk0cbk4p0Th4dhHQd6b/2WrBlmsKScGJ0PxNLi8inwPmqusZ7vAw4TVU3xpSloNb7wKRFkJNHp7DvgyZGf+ghNzLmxInB7eezz+Ccc2DFiuDKNNkv0xOjJ5xYGliB16QjIscBTWITvjGFJh1H+l26wLp1sHt3sOWawpEw6atqFRCZWHo+MC0ysbSIXORt9mPgWhGZC0zFm1jamEKWjqTfsKEbbdN68JhU+WrT9zGx9ALgzGBDMya3bd8O7dsHX+4xx7j5ck84IfiyTf6zK3KNSZN0HOkD9OoFn34afLmmMFjSNyZN0pX0Tz8d/vOf4Ms1hcGSvjFpkq6kP3AgvPsuVFUFX7bJf5b0jUmTdCX9tm3dzS7SMqmwpG9MmqQr6QOceSa88056yjb5zZK+MWmSzqR/2mnwwQfpKdvkN0v6xqRJOpP+iSfCJ5+kp2yT3yzpm7znYz6IkSIyT0Q+EpGZkaHDvavQd3nzRHwoIr6HDFdNb9I/4QQ32ub+/ekp3+SvhGPvBLozG3vHpFG88Um8+SAWA+cC5bhhRYaq6sKobZqq6g7v/sXAKFW9QES6Ai+oah8f+64x9k5FBTRvDnv2BPHK4uvaFd54A4qL07cPkx0yPfaOMbks4XwQkYTvaYqbHyIipQ9aOo/yI6yJx6TCkr7Jd/Hmg+gYu5GIjBKRpcBdwI1Rq7qJyAciMkNEfA81komkf9JJ8OGH6d2HyT9BjaeflKuugkcfDWPPxsSnqpOASSIyFBgHXAWsBbqo6hYRORn4PxE5PuaXwQGlkYkigE6dSmjWrCStMffvD5NsYtK8lM65IkJp0wesXd8ErpY2/YTzQcRsL8AWVW0RZ90M4EeqetDxdWyb/jvvwJgx8O9/1+sl1Wn1ajj5ZDfUsoQyi4DJFGvTN8a/hPNBiEj0qdCLcCd+EZE23olgRKQHUAx85menmWje6djRHTytXZve/Zj8EkrzjjGZoqpVIhKZD6IImBKZDwKYraovAqNFZBBQCWyhej6Is4Bfikgl7uTuSFXd6me/mUj6Iq5df+5c6NAhvfsy+cOSvsl7PuaDuLmW5z0HPJfKPjOR9AH69XNJ/+tfT/++TH4IrXln06a617/zDnzve5mJxZigZSrpn3QSfPRR+vdj8kdoSb9NG/d31Kj4kzxPnepuxuSiTCb9uXPTvx+TP0I/kTt5MpxxRthRGBOsTCX9Y4+F8nK3P2P8CDXpP/WU+1teHmYUxgQvU0m/QQM3Ds/HH6d/XyY/hJr0hw4Nc+/GpE+mkj5Un8w1xo/Qm3eMyUeZTPrWrm+SYUnfmDTIdNK3HjzGL19JP9F45N423xGR+SLyiYj8tb6B2TANJpdlMun36QMLFsDevZnZn8ltCS/O8i5Df4io8chF5J8x45EXA2OA01V1m4i0SVfAxuSCbdsyl/QPO8yNrb9ggfsCMKYufo70E45HDlwLTFTVbQCqujGVYObPT+VZxmSfTB7pgzuZaz14jB9+kr6f8ciPAY4VkXdE5D8icn6ygVRUuK5nxuSDTCd9G1vf+BXU2DsNcSMQngV0AWaKyAmRI/+aSqPul3g3eP31gCIxBSOdY47Xx/79sHMnNG2auX2edhqMHZu5/ZnclXA8fT/jkYvIZGCWqj7mPf4XMEZVP4gp68B4+rVRhR07YPRoeOwxO6Fr/AtyzPEU9n1gPP3t26F9e5f4M2XXLmjb1o2tf/jhmduvyYxMj6efcDxy4P+Ac7zg2gBH43Pc8Xi6dHEJ35hclOmmHXAnc08+Gd5+O7P7NbknYdJX1SogMh75fGBaZDxyEbnI2+ZVYJOIzAfeAH6sqltSDWpLys80JnxhJH2Ac8+FN97I/H5NbgltusTaqNac+s2ad4xf2dK8M2cOXHdd5k
+svvMO3HijndDNR3k9XeKOmCmnLembXBPWkX7//rB0aeK5Kkxhy7qk/+Mf13z8y1+GE4cxqQor6TduDGeeCTNmZH7fJndkXdLfGjMDqf1UNbkmrKQP1q5vEsu6pL9/f83H1rxjck3YSf/NN8PZt8kNWZf0n3mm5mNL+ibXhJn0+/SBzZth1arE25rClHVJ35hcF2bSLyqCIUPghRfC2b/Jflmf9O1I3+SaMJM+wGWXwXPPhbd/k92yPukbk2vCTvpDhrhrBdavDy8Gk72yPunbkb7JNWEn/cMOgwsvhGefDS8Gk72yPukbk2vCTvoAQ4fCtGnhxmCyU9Yn/X//O+wIjElONiT9wYPdTFoLFoQbh8k+WZ/0v/wy7AiMSU42JP0mTeCnP4Wf/SzcOEz2yfqkb0x9icgQEVkoIotFZEyc9SNFZJ6IfCQiM0WkV8z6LiKyXURu9bO/bEj6AP/zP24K0hdfDDsSk00s6Zu8JiJFwEPA+UBvYFhsUgemqmofVe0H3As8ELP+PuAlv/vMlqR/yCEwaZJL/rt2hR2NyRY5kfT37Ak7ApPD+gNLVHWFqu4FpgGXRm+gqtFjuzYFDgwGIiKX4iYEmu93h9mS9AHOOw8GDLCBC021nEj669fDI4/AK6+EHYnJQR2B6EEJVnvLahCRUSKyFLgLuNFbdjjwU2AC4Gss8337oLLSdZvMFg88AE88YVfpGieoidHTqkuX6vvWb9+kg6pOAiaJyFBgHHAVUAo8oKq7xM3sU2fiLy0tZfduaNgQ3nqrhJKSkvQG7VP79u4K3YsuciNw9ukTdkQmkbKyMsrKytJSdtbNnJXIiy+6C0+MiRVvdiERGQCUquoQ7/FYQFX17lrKEGCzqrYUkZlAJ29VS6AK+IX3BRH7PFVVVq2C00+H1asDfGEB+dvfXG+e996Ddu3CjsYkI69nzkrkvffCjsDkmNlAsYh0FZHGwFDg+egNRKQ46uFFwBIAVT1LVXuoag/gQeDX8RJ+tF27sqtpJ9qwYTBihBubZ9u2sKMxYcm5pG9MMlS1ChgNvIY7GTtNVReIyAQRucjbbLSIfCoiHwI3AyNS3V82J32A8eOhXz/4yldg4cKwozFhyIk2/WgbN7q/s2a5C7fOPz/ceEz2U9VXgGNjlo2Pun+zjzIm+NlXtif9oiKYOBH+/Gc3teKYMXDrrdCgQdiRmUzJuSP9v//d/b3sMjeaoDHZZPduOPTQsKNI7Jpr4P334aWX4IwzYO7csCMymZJzST/iiy/CjsCYg2X7kX60Hj1cb57rrnNj9UyebL3jCoGvpJ/oMvao7b4lIvtF5OTgQqxp9264447qx9OmQUVFuvZmTHJyKemDa+75wQ/g3Xdd0v/2t91nzOSvhEnf52XsiEhT3EUts+oqr2/f1AKN2L4dxo2rfjxsGLz6av3KNCYoudK8E6tnT5g9211jcPnlUFUVdkQmXfwc6Se8jN3zK9zVjHUOmtCjR9IxGpMzcu1IP1qTJvDXv8LOnW68np07w47IpIOfpJ/wMnYR6Qd0UtWXExWWjjbDSPv+v/9tswWZcO3enbtJH6BRI3j6adiyBY4+Gj79NOyITNDq3WXTu4Lxfmr2ba71yrG9e0ujHpV4t/r54Q+hc2e46SZYutRORhWKdF6qnqpdu3KzeSfakUe6q3effBIuvtj17GnePOyoTFASDsOQ6DJ2ETkCWArswCX79sAm4BJV/TCmLFVVJJCLiQ9WXGxJv5AFeal6CvtWVeVnP3MjbN52WxhRBO+669wvlwcfDDuSwpbpYRjqvIxdVbepalvvcvXuuBO5F8cm/ExI15eJMX7l6onc2vzyl/DYY9UXRZrclzDp+7yMvcZT8DkMbdDsCN+EraLCTV6SL9q3h0sucUObm/yQ8VE2o5t3fvc7uPHG4Pdjyb8wZUPzzogRcM45cNVVYUSRHrNnw3e+45pObbiGcOTFKJs//anrFmZMPsm3I31wg7O1bg1vvhl2JCYIoS
X9Xt7lXe3bhxWBMcHLtzb9iBEj4PHHw47CBCGUpL9vX/XP35Ejw4jAmPSoqMjPpD90qJtucfv2sCMx9RVK0m/QoLqnTWlp8OW//TZ8/nnw5RqTyO7d+de8A67v/tlnV49ya3JXVo2yOavOUXv8O+ssuPLKYMoyJhn52rwD7jP1xBNhR2HqK2uS/gUXwGmnBXeUtGULHHMMrFgRTHnG+LFjBzRtGnYU6XHRRe7q3JUrw47E1EdWJP233nIDPQEcdVQwZc6fD0uWQLduwZRnjB/5nPSbNHEjcE6dGnYkpj6yIumfdRa0alVz2cyZ4cRiTH3s2OGGYchXV17pevHYtTC5KyuSfjxf/WrYERiTHFXXu+Xww8OOJH1OPx327oUPPgg7EpOqrE36Qbr1VnjvPZthy6RXRYUbmrhRo7AjSR8RGD7c+uznsqxL+uno+fCnP8GAATBpUvBlGxOxY0d+H+VHDB/upimtrAw7EpOKrEv6ffqkr2yrpCaddu7M35O40Xr0gGOPhVdeCTsSk4qsS/rRfvObYMrZsSOYcoypy86dhXGkD9UndE3uyeqk/6MfBVtedI+D55+34WJNsAop6V9+Obz+OqxfH3YkJllZnfSDdvvtUFXl7l97LXz/++HGY/JLISX9Fi3c0f7dd4cdiUlW1iX9X/3Kzc2ZDqowfTp873vpKd8Utp07c3tS9GTddhs8+iisWRN2JCYZWZf0i4th2LDqx8OHw6efBlf+3/7mrijcvLnm8uXL3b6NSVUhHemDu3r+Bz+A//3fsCMxyci6pB/r8cehd293P8jZiPbtc39vuMH1PZ49G5Ytgz17YMECt2zTJjecg8ltIjJERBaKyGIRGRNn/UgRmSciH4nITBHp5S3/ircscvtGXfvZtauwkj64yZCefho+zPiM2CZVWZ/0gzZtWs3Hkb773/mO+3viiXD88e7+iy/CCSckLvOmm2DduuBiNMERkSLgIeB8oDcwLJLUo0xV1T6q2g+4F3jAW/4JcIq3/ALgD155cRXakT5AmzYweTJcdpnS84+wAAAUoElEQVQ18+SKnEr6mZhla8mS6vuRk77Llrm/s2bBhRce/Jzf/c71ZMgGV17pPoDmgP7AElVdoap7gWnApdEbqGp0p96mwH5veYWq7veWHxpZXptCTPrgevKMGuWGTikrCzsak0jDsAPwq7zczdPZtStcf31m9hnp4jl/PvTsCc88Ay+9BPv3Q1FR/G2DsncvbNvmXnMynn7aNVHlq8pK9143aeL7KR2BVVGPV+O+CGoQkVHArUAj4GtRy/sDfwG6AMOjvgQOUmgncqONGeOmQL36avfrePx4OPXUsKMy8eTMkf5RR0Hjxpkd3S+yr0svdb2K7r/fPW7Q4OBt90elgscfh40bU9vnggXw3/+6nhFt2rhlP/pRMANc7dqV+nNXr4ZBg+ofQ32dd156komqTlLVYmAMMC5q+fuqegLwFeA2EWlcWxmFeqQfcemlsHCh+x9985vQt6/rJv3KK+6gzUbmzA45c6QfEak43btndkrEX/yi5uOtW11f5di4wE0iPXy4Ozm8YEH88n74QzekdMeOUFLiehO1bOnOKVRVVZ9jAPdls2sXnHKK/3hnznRzCXTp4h6//joMHhz/g6fq9tmwjtowaxa88Yb//afL++8nPXDeGtxRekQnb1ltngIejl2oqotEZAdwAhD3tOVbb5XSrp1L/iUlJZSUlCQVaD5o0gRuvNF1kJg1y/0yvuce+OQTN6tYcXH1rVMnNwx106bub7z6V1XlOl00aFC9XrX6tn9/3Y8jy/bvd2XF3vbt87csmW2DeP7OnWXs2lV2IP4g+Ur6IjIEeBD3y2CKqt4ds/4W4BpgL7AB+L6qrjqooABE/vH9+qU/6f/2t7Wva9nSdVf785/d49hk+u67sHRp9ePIP2/bNjcsxB/+4G4R27a5MiP/4Njyov/xkV4ikQ9DPGef7Y64XnvNPS4vr/21PPMMfPe7dR
+Jxa77+9/duYPYZq50S+FocTZQLCJdgbXAUGBY9AYiUqyqkf/WRcBib3k3YJWqVnnPPxZYXtuOjj++lHPPdV/6ha5BAxg40N0itm5158eWLnW3BQvcZ2H7dneLnEOLVlTkRi2tqnJNnuB61om4dZH78R5HLysqcjFF3xo2PHhZvOWNGrmBIP0832+ZiZeVeDf3+KijJgT2v0mY9KN6P5wLlAOzReSfqrowarMPcb0cKkTkh7geEEMDizLKiBFw3HF1J+SgJLo+YMoUd7QCLim/+251z5/oJH3RRe5o5qmnXDJ+662DyxoyxE0XGUlqzzxTc310eZGuq/v31570oWaCrGu7//639nX/+Ad84xsHJ9tvfxsWL4ajj679uX5NmQLXXOMvoSeb9L2EPRp4jeqDlgUiMgGYraovAqNFZBBQCWwBImn7TGCsiFTiTuJer6qbD96LU+jNO4m0aOF+rSbzi9WkgarWeQMGAC9HPR4LjKlj+5OAt2tZp0H51rdif8SFexsxIv5yVf9liBy8bP9+9/f733d/N2+uXvfb36p27VrzfWnSpHr9uedWL586tTqeWNGxxls3darqtGk1twHVRYtq//+88ILqv/6V6L/oXH+9K6+ysnrZ1q2qW7aozp1bc9uGDd22y5erbtsWGyuqCepzum6Ann++6vTp/l6zMckIsm77+XEer/dDxzq2/wHwclLfPCn42tcSb5NJjz0Wf7mI/zI0zlHsUO/3UuRIP3payTfecBO/q1bPMRzdcye6vLqO9KNjnTrVXZ0c/cvi//2/6maiaLE/x/fvr77S+eKL3cm8WLfdVv0zPSJyoVzjxvDZZ+5+ixauueukk+Jv262ba17LJnakb3JBoC2yIvI94BRc805cpaWlB25l9ejUO2pU9f18nhzl6afd33hNMNHt6cOHH7w+kvRVYeTImuu2bXPtqJMn11z+1lvuJHns+Ed/+cvB5e/effA20V1Mt22ruf6RR+DOO+Hhh6uTN7hJbiJ69jx4P7V55pkyxo4tpUePUsaNK/X/xDTZvh2OOCLsKIxJINFPAVzzzitRj+M27wCDgPlA6zrKCvgnj7tNnhxcM00u3Y48Ums0AT3zTM31J57o3qfp06uXzZih+thjqh06qPbqpdqpU83nXHON+ztxourevQfvU1V1376ajyPuuMMte+KJ6vXDhrl1sc1xL7ygumdPzf9j5NavX83HGzce/D+P3BYscH/Xr9fQm3e6d1ddurR+ddqYeIKs234qcwNgKdAVaAzMBY6L2aaft03PBGUF+ka0bKnaqpXqpEmZSbLZenvppfjLDztMdfx41QsuiL++qKj2Mk8+Of7yDRtqPi4rc4nu17+ubpuPvX32Wfzlp58eqdCJb9WVP/7tn//U0JN+mzaq69YFWsWNUdVg67a48urmddn8LdW9H+6K7v0gIq/j+i+vBQRYoaoHDU4lIupnf35FuiGuWwcnnxxYsaYO554bbH/9Dh3q7k4asX69O29x6611bSWoahJnUYIjItq4sfLll3DIIWFEYPKZSHB121fSD0rQST/aRx9Z4jfhJv2GDZXKyuRO3hvjhyX9WlRUuIsoTKEKN+k3b65s3RrG3k2+CzLp58zYO34ccggMGBB2FKZQ1TWMhTHZIq+O9CNUMz88gMkG4R7pH3mk2kThJi3sSD+BSJtqdF9+Y9LNDjRMLsjbatq3L4we7SZ1GJqWUYCMqclO4JpckLetkHPnur/HHecu5a+shOeeCzcmk9+CHgLXmHTI2yP9aM2bu6GAFy2quby2oZmjhwgwxq94QwMbk20KIulHxJ5D7tYNPoyZDmPOHDc42T/+Ub3s9tv97yN68hNTWOINmW1MtsnL3ju1qapyg35dcolrfz3yyOjY3N/o8CLL3njDXYn6/PPuJ/w3DrrW2Hn0Uff8q69OS/gmoXB774RZt01+s947KWrQwE3W0bZtzYQPbnKTqVNrLrvuOvc30txz8cXQ35tSu7j44P
JHjEjtSH/JkrrXjxlTcyTKTBo3Dt58M/376dAh/fswxhRY0q9LWRlccUXNZe3aub/RXxBNm7q/M2fCvHlu+OA5c6pn8jrsMPj444PL/93voH1790vg44/dr4d9+9zj2OF4740amFoV7roLLrwwfty33FJ9/ze/cXORQu3dVceNi7884uabq++vWgW//CWcc07tc/3WJXrs/0QiTSPvvlv3dqtWwbXXJh+LMcYT1Mhtfm4EPMpmum3frvr+++7+7t3JPXfTJtW1a1WnTPG3/axZqn/+sxsx8te/rjmyZDRQPfVU93fmTLds376as05FZpQC1RNOcH9vvrm6zN27Ve+/v+YolYsXu79796p+8EH8fb/7ruqjj7rt1q1zQyOfe27NoZRLS1XLy1Xfflt1507V995zyyMjoa5Y4e4/9VT8UTRVVc84Q/XVV6uX33ij6mmnuRhVVUeO1AOjiNYcbTPcUTaNSZcg67Z9MLLQokWqN9wQfx2ojh6t+pe/JC5nxQo3zvz06W7M/WgPPFCdpF9/3X9sa9fG/0KKJN6FC+Ov27kzfnl//avqhAkHL58xo/YvvrfeUj3nHNU1a9w2zz2n2qaNJX2Tv4Ks2wV1IjcfVFRAo0b+pj+sy6pV8OCDcN99wcQl4pqB4rX/f/ml6zabjNWr3aTrsbNz1WbPHjjkEDuRa/KTjbJpTBxBfjBS2LfVbZM21nvHGGNMSizpG2NMAbGkb4wxBcSSvjHGFBBL+sYYU0As6RtjTAGxpG+MMQXEV9IXkSEislBEFovImDjrG4vINBFZIiLvikiX4EM1JjU+6u9IEZknIh+JyEwR6eUtHyQic0TkYxGZLSLnZD56Y4KVMOmLSBHwEHA+0BsYFvlQRPkBsFlVjwYeBO4JOtCIsrIyKyMPywgihnh81t+pqtpHVfsB9wIPeMs3ABepal/gKuCJtATpyYb/QxBlZEMM+VZGkPwc6fcHlqjqClXdC0wDLo3Z5lLgMe/+s8C5wYVYU7b8E6yMYMtI4wcjYf1V1R1RD5sC+73lH6vqF979+cAhItIoXYFmw/8hiDKyIYZ8KyNIfubI7Qisinq8GvdBiruNqlaJyFYRaaWqm4MJ05iU+am/iMgo4FagEfC1OOu/DXzofXEYk7PSdSI3lPFPjEmVqk5S1WJgDFBj1gER6Q3cCVwXRmzGBCrRMJzAAOCVqMdjgTEx27wMnObdbwCsr6UstZvd0nlLpf7GbC/A1qjHnYBFwIAEn5PQX7vd8vsW1NDKfpp3ZgPFItIVWAsMBYbFbPMCMAJ4D7gciDvBXlgjIJqClrD+ikixqi71Hl4ELPaWtwBexH1JzKprJ1a3Ta5ImPS9NvrRwGu45qApqrpARCYAs1X1RWAK8ISILAE24T5YxoTOZ/0dLSKDgEpgC+4ABuAGoCfwCxEZjzviGqyqGzP+QowJSEbH0zfGGBOyoNqJfJwbGAIsxP10jj0nMAVYB8yLWtYSd3S2CHgVaB617nfAEmAucBKu3fVNXM+MPV5ZVyZZRhNc89QKr4yNXhndgFle3H8DGnrPb4zr/rcEeBfoElX2bUAFsB0YnGwZwHJcU0QFsNsrw/dr8ZY1B973XsseoDTJ9+MY4CMvlgqgCtdPPdk4bon6v2wDrk7m/QBuAj4BvgDWAwuAbyYTg7d8hLe/RcCVuVCvo84p/Dfq//hECp+PQOo27pfSGmCH938YnszzvXVbvBgqgIUpvJbmwDNAuVfOCuCHSTw/K+q1t/wmXL3e4/1N+nOeSt3OVMIvApYCXXFd4uYCvaLWn+n9Q6I/HHcDP/XujwHu8u5fAEz37p/mvcntga8Cy4AO3huzEneRja8yvPsdvDJa4tqCVwPPAZd76ycDI7371wOTvPvfBaZ594/3njcVeN173U8lWcYqYB6u+a2bV4bv98O7/yQuSTYHWgOfp/B+tPTejxa4L6HlyZThvZ/LvTKae+/luiTe05e896Ev7oP6OnAWLnGk8jqae6
9lGVEfpmyt1979Y7360BzX/bQSODWFcoKo23cBm3Hn8LrhDmqS/WzsAdpQXa8lmdcCPIprdlsGtMIdGGwBfpFD9Xoa7kLBRbg60xiY6ZWZ7P816bqdqaQ/AHg56nG8HkBdqfnhWAi08+63BxZ49x8Gvhu13QKgHe48wmRv2f8Bz+OOSpIuAzgMmAP8HfgSKIp9HcArxOmxBPwa96VT4sXwMq5SJlPGZmBCVHwvexXC72vp6VXCyVHLJ9fj/RgMvJ1CGX1wV7VOwX2BPY9LGH7f0y+BP+HVF+DnwE9wR5oXpFI3ot6LA9vlSr32ln8O/CrVckixbuN+cSwB/gA8762rxOvVlOj5Ue/PZqB1VL0+LYnXsgj3OYh9T7YA1+VQvV4PfBv3S3yMt/zn3vvweSr/12TqdqYGXIt3gUzHBM9pq6rrANRdFdkuQVkdgVUi0g13dDUXaJlkGZ1wV2t+gTuqXAjsUdX9ceKucUEa8KWItML1XvoL7qQfuJ/Su5MsowEw3Bvv5RpchWyTxGs5FffTtb+IfCgif8RVtGTfj8jy7+J+OaxJsoyGwDvA97znfok7cvf7nm4BzgZ64L7Evg509so9NIXXEbGGxPXPj4zVawCvbrcAduISQ8bqNm4oizm4/wMi0hrX9NjBz/O9et0R2Au8KiKzcVc/d0zitWzG/bq4FbhERP4oIocBh3hl5Uq9/tJ7XAxs9l7D13FNTcl8zlOq27k0yqbWsjy6q1xj3DAQN+GOQpItQ3GVuxPuqs0jk4hPgPNwH4QVpHaBWuQ5z+K++b+O+ynbjoNjr+u1NMS9hlmqejIuSZwVZ1s/72kRcAmuHTWeuspoCpyAG4upA3A4cHQt28ezD9eOeSnwU1wTT1WSMWQ7X7GLSFNcvZiOq9vJ1IfI+lTr9uG4I9uNsXH5FP2cCap6Kq5uH+fd/L6WIqAX8G9gIq5ej03i+dlSrwX3q+lD4Ge4ZsyP8Ib/SCKOlGQq6a/Btb1FdPKW1WWdiLQDEJH2uCPVSFmd45S1Fvg+7kTXP73lm5MsYw3uJMs2oAw4EWjiDdoVG/eBMkSkAXAEru25C/B73Imcr+HGITrEbxnqhq5YAnRW1Q24pqqjgY1JvJa5uA9oA2/Z33FHy6m8HwOBD9R1U0z2PT0ad27lSO8I5x+4dl3f76mqPoRrb30M2Ir7ib8P9+Wa1P81zvL6ykS9XoNrInoWd8Jxh7cslXJSrdvNcQPWXYE7oPoa8FvcEXa5j+dH6vUaoBmAV7e3AG2TeC2tvP29h3vf/w6cjDtPsCPJ9yLser0Z+BfwB1UtwdXtIpL7nKdUtzOV9A9cICMijXHtUM/HbCPU/PZ6HjeyId7ff0YtvxJARAbgrp5ch/vWPgJ4VERa4o66n/VbBu4I8l3gPO9NvwB3sukNXJMNuLPk0WWM8O5fDrypqrcBZ+AqxPdwJ3124X5O+yrD+6n3L2Cod3HQxbgPytNJvB/zcSfJvi4izYELvffG9/vhvaev4k5G/l+K7+knuHbJwV4cF+B+tfh+T0XkSG/5cOAy3FHeXly7ajKv4zwRaR71Ol6l/jJRr1/FJdulwONRsfsuh/rX7RdUtQuus8Q6YAZuqIpK3BdSoudHLtZ8DbjCG4r9eFyCmp7Ea9mEa9Nf6r0PX/fuAxzl573Ilnrt3X8b9znvCXwH9yWazOc8tbqdygmsVG64rm2LcEexY2PWPUl1F6yVuO5PLXHJbxGusrSI2v4h3D/7Y9w3/UBcxV6Ja8uuwA2R2yqJMk7E/dxaSc1ubd1xRxaLcb1wGnnPb+L9g5bgknu3qLJ/hmtzi3TZ9F2Gt+3cqPdjvVeG79fiLesLfOaVsR0YmUIZh+G6oy31Yr8yhTLGR72WL73/bTLvx0zgU6+MNbgTWN9KJgZv+VVeuYsJvstmWuq1t2wg7md/pJvjcm+fYdXtP+KaVSJdNn0/39tf5P2owF0oR5KvpS/uy3Yl7uh+Ka53TE7Va2/5TKq7bK
4ghc95KnXbLs4yxpgCkksnco0xxtSTJX1jjCkglvSNMaaAWNI3xpgCYknfGGMKiCV9Y4wpIJb0jTGmgFjSN8aYAvL/AcIkM583vnieAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc3bc7d7a20>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10000/10000 [00:10<00:00, 913.38it/s]\n"
]
}
],
"source": [
"from tqdm import trange\n",
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"loss_history, reg_history = [], []\n",
"\n",
"for t in trange(10000):\n",
" loss_t, reg_t, _ = sess.run([loss, reg, step])\n",
" loss_history.append(loss_t)\n",
" reg_history.append(reg_t)\n",
" if t % 1000 == 0:\n",
" clear_output(True)\n",
" num_nonzero = sess.run(my_net.gate.get_sparsity_rate())\n",
" print(\"rate_of_nonzero_activations =\", num_nonzero)\n",
" \n",
" plt.subplot(1,2,1)\n",
" plt.plot(loss_history, label='loss')\n",
" plt.legend()\n",
" plt.subplot(1,2,2)\n",
" plt.plot(reg_history, label='reg')\n",
" plt.legend()\n",
" plt.show()\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1., 1., 0., 1.,\n",
" 1., 0., 1., 1., 0., 1., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0.,\n",
" 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 0., 0.,\n",
" 1., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1.,\n",
" 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 0., 0.,\n",
" 0., 1., 1., 0.]], dtype=float32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# gate values\n",
"sess.run(my_net.gate.get_gates(is_train=False))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.02172022, -0. , 0. , 0.05970415, -0.03229865,\n",
" -0.01130804, 0. , -0.02124784, -0. , 0.02389894,\n",
" -0. , 0.01233097, 0.00665306, -0.00404538, -0. ,\n",
" 0.00455578, 0.00603381, 0. , 0.02199194, -0.03769897,\n",
" -0. , -0.00714258, 0.05062252, -0. , -0.02904983,\n",
" 0. , 0.049015 , -0.00445654, 0. , -0. ,\n",
" -0.01899791, -0. , 0.01498988, -0.00361649, -0.01428505,\n",
" 0.0200627 , 0.00454252, -0.02961647, 0. , -0.04975929,\n",
" -0. , 0.00315028, -0.03230491, 0. , -0.03524412,\n",
" 0.01989217, 0. , 0. , 0.00102865, -0.01898451,\n",
" -0. , -0.04948218, -0.03765366, 0.00472582, -0. ,\n",
" -0. , -0. , 0. , 0.01454964, -0.03137352,\n",
" -0.00582995, -0.04133054, -0. , -0.01280909, 0.01734644,\n",
" -0.02355555, 0.04760959, 0.02267313, 0.00704623, -0.05425708,\n",
" 0. , 0.01986646, -0.02759497, -0.05245164, -0.03126373,\n",
" -0.00284979, 0.02161182, 0.02729553, -0.04157891, -0.04360487,\n",
" -0. , -0. , 0.0135207 , -0.00187147, 0.0609222 ,\n",
" -0.07456608, 0.02430784, 0.00149097, 0.02840535, 0. ,\n",
" 0.07865206, -0.04211108, -0. , 0. , 0. ,\n",
" -0. , -0. , -0.05287053, 0.0263629 , -0. ],\n",
" [ 0.02172022, -0. , 0. , 0.05970415, -0.03229865,\n",
" -0.01130804, 0. , -0.02124784, -0. , 0.02389894,\n",
" -0. , 0.01233097, 0.00665306, -0.00404538, -0. ,\n",
" 0.00455578, 0.00603381, 0. , 0.02199194, -0.03769897,\n",
" -0. , -0.00714258, 0.05062252, -0. , -0.02904983,\n",
" 0. , 0.049015 , -0.00445654, 0. , -0. ,\n",
" -0.01899791, -0. , 0.01498988, -0.00361649, -0.01428505,\n",
" 0.0200627 , 0.00454252, -0.02961647, 0. , -0.04975929,\n",
" -0. , 0.00315028, -0.03230491, 0. , -0.03524412,\n",
" 0.01989217, 0. , 0. , 0.00102865, -0.01898451,\n",
" -0. , -0.04948218, -0.03765366, 0.00472582, -0. ,\n",
" -0. , -0. , 0. , 0.01454964, -0.03137352,\n",
" -0.00582995, -0.04133054, -0. , -0.01280909, 0.01734644,\n",
" -0.02355555, 0.04760959, 0.02267313, 0.00704623, -0.05425708,\n",
" 0. , 0.01986646, -0.02759497, -0.05245164, -0.03126373,\n",
" -0.00284979, 0.02161182, 0.02729553, -0.04157891, -0.04360487,\n",
" -0. , -0. , 0.0135207 , -0.00187147, 0.0609222 ,\n",
" -0.07456608, 0.02430784, 0.00149097, 0.02840535, 0. ,\n",
" 0.07865206, -0.04211108, -0. , 0. , 0. ,\n",
" -0. , -0. , -0.05287053, 0.0263629 , -0. ],\n",
" [ 0.02172022, -0. , 0. , 0.05970415, -0.03229865,\n",
" -0.01130804, 0. , -0.02124784, -0. , 0.02389894,\n",
" -0. , 0.01233097, 0.00665306, -0.00404538, -0. ,\n",
" 0.00455578, 0.00603381, 0. , 0.02199194, -0.03769897,\n",
" -0. , -0.00714258, 0.05062252, -0. , -0.02904983,\n",
" 0. , 0.049015 , -0.00445654, 0. , -0. ,\n",
" -0.01899791, -0. , 0.01498988, -0.00361649, -0.01428505,\n",
" 0.0200627 , 0.00454252, -0.02961647, 0. , -0.04975929,\n",
" -0. , 0.00315028, -0.03230491, 0. , -0.03524412,\n",
" 0.01989217, 0. , 0. , 0.00102865, -0.01898451,\n",
" -0. , -0.04948218, -0.03765366, 0.00472582, -0. ,\n",
" -0. , -0. , 0. , 0.01454964, -0.03137352,\n",
" -0.00582995, -0.04133054, -0. , -0.01280909, 0.01734644,\n",
" -0.02355555, 0.04760959, 0.02267313, 0.00704623, -0.05425708,\n",
" 0. , 0.01986646, -0.02759497, -0.05245164, -0.03126373,\n",
" -0.00284979, 0.02161182, 0.02729553, -0.04157891, -0.04360487,\n",
" -0. , -0. , 0.0135207 , -0.00187147, 0.0609222 ,\n",
" -0.07456608, 0.02430784, 0.00149097, 0.02840535, 0. ,\n",
" 0.07865206, -0.04211108, -0. , 0. , 0. ,\n",
" -0. , -0. , -0.05287053, 0.0263629 , -0. ],\n",
" [ 0.02172022, -0. , 0. , 0.05970415, -0.03229865,\n",
" -0.01130804, 0. , -0.02124784, -0. , 0.02389894,\n",
" -0. , 0.01233097, 0.00665306, -0.00404538, -0. ,\n",
" 0.00455578, 0.00603381, 0. , 0.02199194, -0.03769897,\n",
" -0. , -0.00714258, 0.05062252, -0. , -0.02904983,\n",
" 0. , 0.049015 , -0.00445654, 0. , -0. ,\n",
" -0.01899791, -0. , 0.01498988, -0.00361649, -0.01428505,\n",
" 0.0200627 , 0.00454252, -0.02961647, 0. , -0.04975929,\n",
" -0. , 0.00315028, -0.03230491, 0. , -0.03524412,\n",
" 0.01989217, 0. , 0. , 0.00102865, -0.01898451,\n",
" -0. , -0.04948218, -0.03765366, 0.00472582, -0. ,\n",
" -0. , -0. , 0. , 0.01454964, -0.03137352,\n",
" -0.00582995, -0.04133054, -0. , -0.01280909, 0.01734644,\n",
" -0.02355555, 0.04760959, 0.02267313, 0.00704623, -0.05425708,\n",
" 0. , 0.01986646, -0.02759497, -0.05245164, -0.03126373,\n",
" -0.00284979, 0.02161182, 0.02729553, -0.04157891, -0.04360487,\n",
" -0. , -0. , 0.0135207 , -0.00187147, 0.0609222 ,\n",
" -0.07456608, 0.02430784, 0.00149097, 0.02840535, 0. ,\n",
" 0.07865206, -0.04211108, -0. , 0. , 0. ,\n",
" -0. , -0. , -0.05287053, 0.0263629 , -0. ],\n",
" [ 0.02172022, -0. , 0. , 0.05970415, -0.03229865,\n",
" -0.01130804, 0. , -0.02124784, -0. , 0.02389894,\n",
" -0. , 0.01233097, 0.00665306, -0.00404538, -0. ,\n",
" 0.00455578, 0.00603381, 0. , 0.02199194, -0.03769897,\n",
" -0. , -0.00714258, 0.05062252, -0. , -0.02904983,\n",
" 0. , 0.049015 , -0.00445654, 0. , -0. ,\n",
" -0.01899791, -0. , 0.01498988, -0.00361649, -0.01428505,\n",
" 0.0200627 , 0.00454252, -0.02961647, 0. , -0.04975929,\n",
" -0. , 0.00315028, -0.03230491, 0. , -0.03524412,\n",
" 0.01989217, 0. , 0. , 0.00102865, -0.01898451,\n",
" -0. , -0.04948218, -0.03765366, 0.00472582, -0. ,\n",
" -0. , -0. , 0. , 0.01454964, -0.03137352,\n",
" -0.00582995, -0.04133054, -0. , -0.01280909, 0.01734644,\n",
" -0.02355555, 0.04760959, 0.02267313, 0.00704623, -0.05425708,\n",
" 0. , 0.01986646, -0.02759497, -0.05245164, -0.03126373,\n",
" -0.00284979, 0.02161182, 0.02729553, -0.04157891, -0.04360487,\n",
" -0. , -0. , 0.0135207 , -0.00187147, 0.0609222 ,\n",
" -0.07456608, 0.02430784, 0.00149097, 0.02840535, 0. ,\n",
" 0.07865206, -0.04211108, -0. , 0. , 0. ,\n",
" -0. , -0. , -0.05287053, 0.0263629 , -0. ]],\n",
" dtype=float32)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# gated activations\n",
"sess.run(my_net.gate(my_net.first(x), is_train=False), {x: tf.zeros(x.shape).eval()})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import tensorflow as tf
from warnings import warn
import tfnn
class ConcreteGate:
    """
    A gate made of stretched concrete distribution (using experimental Stretchable Concrete™)
    Can be applied to sparsify neural network activations or weights.
    Usage example: https://gist.github.com/justheuristic/1118a14a798b2b6d47789f7e6f511abd
    :param name: variable/name scope under which the gate parameter log_a is created
    :param shape: shape of gate variable. can be broadcast.
        e.g. if you want to apply gate to tensor [batch, length, units] over units axis,
        your shape should be [1, 1, units]
    :param temperature: concrete sigmoid temperature, should be in (0, 1] range
        lower values yield better approximation to actual discrete gate but train longer
    :param stretch_limits: min and max value of gate before it is clipped to [0, 1]
        min value should be negative in order to compute l0 penalty as in https://arxiv.org/pdf/1712.01312.pdf
        however, you can also use tf.nn.sigmoid(log_a) as regularizer if min, max = 0, 1
    :param l0_penalty: coefficient on the regularizer that minimizes l0 norm of gated value
    :param l2_penalty: coefficient on the regularizer that minimizes l2 norm of gated value
    :param eps: a small additive value used to avoid NaNs
    :param hard: if True, gates are binarized to {0, 1} but backprop is still performed as if they were concrete
    :param local_rep: if True, samples a different noise tensor for each sample in batch
        (noise shape follows the gated values); by default, noise is sampled using shape param as size.
    """

    def __init__(self, name, shape, temperature=0.33, stretch_limits=(-0.1, 1.1),
                 l0_penalty=0.0, l2_penalty=0.0, eps=1e-6, hard=False, local_rep=False):
        self.name = name
        self.temperature, self.stretch_limits, self.eps = temperature, stretch_limits, eps
        self.l0_penalty, self.l2_penalty = l0_penalty, l2_penalty
        self.hard, self.local_rep = hard, local_rep
        with tf.variable_scope(name):
            # learnable logit shift; get_gates adds it to the logistic noise before the sigmoid
            self.log_a = tfnn.ops.get_model_variable("log_a", shape=shape)

    def __call__(self, values, is_train=None, axis=None, reg_collection=tf.GraphKeys.REGULARIZATION_LOSSES):
        """
        Applies gate to values; if is_train, also adds the regularizer to reg_collection.
        :param values: tensor to be gated; must broadcast against the gate shape
        :param is_train: if True, sample noisy gates and register the penalty;
            defaults to tfnn.ops.is_dropout_enabled()
        :param axis: axis (or axes) over which penalties are summed before averaging, see get_penalty
        :param reg_collection: graph collection that receives the penalty tensor during training
        :returns: values multiplied by the gates
        """
        is_train = tfnn.ops.is_dropout_enabled() if is_train is None else is_train
        gates = self.get_gates(is_train, shape=tf.shape(values) if self.local_rep else None)

        # only build the penalty subgraph when it is actually going to be collected;
        # otherwise it would be dead graph ops (e.g. at inference time)
        if (self.l0_penalty != 0 or self.l2_penalty != 0) and is_train:
            reg = self.get_penalty(values=values, axis=axis)
            tf.add_to_collection(reg_collection, tf.identity(reg, name='concrete_gate_reg'))
        return values * gates

    def get_gates(self, is_train, shape=None):
        """
        Samples gate activations in [0, 1] interval.
        :param is_train: if True, sample noisy (stochastic) gates; otherwise use the deterministic sigmoid(log_a)
        :param shape: noise shape for training mode; defaults to the shape of log_a
        :returns: gate tensor with values in [0, 1] (exactly {0, 1} in forward pass if self.hard)
        """
        low, high = self.stretch_limits
        with tf.name_scope(self.name):
            if is_train:
                shape = tf.shape(self.log_a) if shape is None else shape
                # eps keeps log(noise) and log(1 - noise) finite
                noise = tf.random_uniform(shape, self.eps, 1.0 - self.eps)
                # log(u) - log(1 - u) is logistic noise; shift by log_a and sharpen by temperature
                concrete = tf.nn.sigmoid((tf.log(noise) - tf.log(1 - noise) + self.log_a) / self.temperature)
            else:
                concrete = tf.nn.sigmoid(self.log_a)

            # stretch to (low, high) and clip so that exact 0 and 1 have non-zero probability
            stretched_concrete = concrete * (high - low) + low
            clipped_concrete = tf.clip_by_value(stretched_concrete, 0, 1)
            if self.hard:
                # straight-through: forward uses the binarized gate, backward uses the concrete one
                hard_concrete = tf.to_float(tf.greater(clipped_concrete, 0.5))
                clipped_concrete = clipped_concrete + tf.stop_gradient(hard_concrete - clipped_concrete)
        return clipped_concrete

    def get_penalty(self, values=None, axis=None):
        """
        Computes l0 and l2 penalties. For l2 penalty one must also provide the sparsified values
        (usually activations or weights) before they are multiplied by the gate
        Returns the regularizer value that should be MINIMIZED (negative logprior)
        :param values: un-gated values; required for the l2 penalty. For the l0 penalty, if given,
            the gate-open probabilities are broadcast to their shape before summation
        :param axis: axis (or axes) to sum the penalties over; the result is then averaged
            over the remaining elements
        """
        if self.l0_penalty == self.l2_penalty == 0:
            warn("get_penalty() is called with both penalties set to 0")
        low, high = self.stretch_limits
        assert low < 0.0, "p_gate_closed can be computed only if lower stretch limit is negative"
        with tf.name_scope(self.name):
            # compute p(gate_is_open) = P(stretched concrete sample > 0), the closed form from
            # https://arxiv.org/pdf/1712.01312.pdf; requires low < 0 for the log to be defined
            p_open = tf.nn.sigmoid(self.log_a - self.temperature * tf.log(-low / high))
            p_open = tf.clip_by_value(p_open, self.eps, 1.0 - self.eps)

            total_reg = 0.0
            if self.l0_penalty != 0:
                if values is not None:
                    p_open += tf.zeros_like(values)  # broadcast shape to account for values
                l0_reg = self.l0_penalty * tf.reduce_sum(p_open, axis=axis)
                total_reg += tf.reduce_mean(l0_reg)

            if self.l2_penalty != 0:
                assert values is not None
                # expected squared l2 norm of the gated values: each element contributes
                # values ** 2 weighted by the probability that its gate is open.
                # multiplying before the reduction keeps shapes aligned for any axis.
                l2_reg = 0.5 * self.l2_penalty * tf.reduce_sum(p_open * values ** 2, axis=axis)
                total_reg += tf.reduce_mean(l2_reg)

            return total_reg

    def get_sparsity_rate(self, is_train=False):
        """ Computes the fraction of gates which are now active (non-zero) """
        is_nonzero = tf.not_equal(self.get_gates(is_train), 0.0)
        return tf.reduce_mean(tf.to_float(is_nonzero))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment