Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Simple Keras implementation of Virtual Adversarial Training (VAT).
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Couldn't import dot_parser, loading of dot files will not be possible.\n",
"/usr/local/lib/python2.7/dist-packages/cryptography/hazmat/primitives/constant_time.py:26: CryptographyDeprecationWarning: Support for your Python version is deprecated. The next version of cryptography will remove support. Please upgrade to a 2.7.x release that supports hmac.compare_digest as soon as possible.\n",
" utils.DeprecatedIn23,\n"
]
}
],
"source": [
"# Seed NumPy's global RNG and TensorFlow's graph-level RNG so re-runs are reproducible.\n",
"from numpy.random import seed\n",
"from tensorflow import set_random_seed\n",
"\n",
"seed(0)\n",
"set_random_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import matplotlib\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import keras\n",
"from keras.models import * \n",
"from keras.layers import *\n",
"from sklearn.metrics import accuracy_score\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Make the datasets"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"\n",
"# Noisy circles to draw the labelled subset from, and a large noise-free set for evaluation.\n",
"circles = datasets.make_circles(n_samples=1000, noise=.05, factor=0.3, random_state=3)\n",
"circles_test = datasets.make_circles(n_samples=10000, noise=0, factor=0.3, random_state=1)\n",
"\n",
"# Keep only the first n_points examples of each class, giving a tiny labelled training set.\n",
"n_points = 8\n",
"inds = list(np.where(circles[1] == 0)[0][:n_points]) + list(np.where(circles[1] == 1)[0][:n_points])\n",
"\n",
"X_train = circles[0][inds]\n",
"Y_train = circles[1][inds]\n",
"n_classes = int(np.max(Y_train) + 1)\n",
"\n",
"X_test = circles_test[0]\n",
"Y_test = circles_test[1]\n",
"\n",
"# Pass num_classes explicitly so the train and test one-hot encodings always have the same width.\n",
"Y_train_cat = keras.utils.to_categorical(Y_train, num_classes=n_classes)\n",
"Y_test_cat = keras.utils.to_categorical(Y_test, num_classes=n_classes)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the dataset"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAEACAYAAABVtcpZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XmcXFWZ8PHfU9XV1VXdSS/Zd0ACiKCgQgKCNjBI2IyCIC4juAyMovM66iijMxJ1YHRgHBfEBRkGRAzigqzKZrMYQLawaAIBAkl39nR3eq21z/vHqdt1u1K9parrVtV9vvXpT9dyu+65XXXPc/YrxhiUUkr5U8DrBCillPKOBgGllPIxDQJKKeVjGgSUUsrHNAgopZSPaRBQSikfK0oQEJFrRWS7iDw3yuvvEpFuEXk68/NvxdivUkqpwtQU6X2uA34A3DDGNg8ZY95TpP0ppZQqgqLUBIwxjwBd42wmxdiXUkqp4illn8ByEXlGRO4UkUNLuF+llFKjKFZz0HieApYYYwZE5FTgVuCgEu1bKaXUKEoSBIwxfa77d4vI1SLSYozpzN1WRHQxI6WUmiRjzD41uRezOUgYpd1fROa47h8NSL4A4DDGVOXPpZde6nka9Pj0+PT4qu+nEEWpCYjITUArMENENgGXArWAMcb8FHi/iHwKSAKDwAeKsV+llFKFKUoQMMZ8aJzXfwj8sBj7UkopVTw6Y7iEWltbvU7ClNLjq2x6fP4khbYnFZuImHJLk1JKlTMRwZRBx7BSSqkKo0FAKaV8TIOAUkr5mAYBpZTyMQ0CSinlYxoElFLKxzQIKKWUj2kQUEopH9MgoJRSPqZBQCmlfEyDgFJK+ZgGAaWU8jENAkop5WMaBJRSysc0CCillI9pEFBKKR/TIKCUUj6mQUAppXxMg4BSSvmYBgGllPIxDQJKKeVjGgSUUsrHNAgopZSPaRBQSikf0yCglFI+pkFAKaV8rChBQESuFZHtIvLcGNt8X0Q2iMhaETmiGPtVSilVmGLVBK4DThntRRE5FXiDMWYpcBHw4yLtVymlVAGKEgSMMY8AXWNsshK4IbPt40CjiMwpxr6VUkrtu5oS7WcBsNn1uCPz3PYS7V/52NAQxGKQStn7Q0NgTPbH2Qayj91/CyAy+vsHAiO3cT8OBOzvYND+RCJjv5dSpVaqIJDva2/yPKfUPunvt5l8Op3N6J3M3snInd+wd2af+zj3tX3JuAOuenZugAgE9g4QNTVQXz/5/ShViFIFgXZgkevxQmDLaBuvWrVq+H5rayutra1TlS5VYfr7IZkcmeE7Gb27ZJ8vU3c/ly9Td57L3c69rfNavudy95VOj/2+ue8fCMCePdkAEQjYwFBbC9Ho3vtQ/tXW1kZbW1tR3kvMWEWgybyRyH7A7caYw/O8dhpwsTHmdBFZDnzXGLN8lPcxxUqTqmyDg9lmnNwSPozMVHO/Mu5MOl8zjfM733P5fufj3qcThNzPu2sh+V4b62vuBAcnjc6PU2MIh23TklIAIoIxZp8aGotSExCRm4BWYIaIbAIuBWoBY4z5qTHmLhE5TUReBvqBjxVjv6q69PZCIrF3pg8j2/DzZfDBoP3tNLG4f+fLROvqSnNMg4P2WNLpbO0gt6nK/Tv3WCFbo3CONxCwNSLnmEIhW1toaCjNManqUrSaQLFoTcA/+vshHrfNO/lK+jAy08/tcHUyded3MGhLyOFw6Y+lEPH4yOA3WqCA/DUe58f5H4RCNshpE5J/FFIT0CCgSqqnxzbxpNM208st9TqZfm7mJmJL8E4mF41mS//VKpm0NYlUKjuyKbcfBEYGhtwakFNLaGz05hhUaWgQUGXLGOjutiXd0TJ+d0nfXbp3OkWnTfMu/eVmaAj6+vbuHM8NDLn9Hs7/MxyGpibv0q+mhgYBVXa6u7Ml/nwduu4Sq3uIZF2dDpOcrL4+G2SdwJCvGQlG/r+d/7XWEKqDBgFVFnp6ss0X7maL3CYeJxMKBm1GpCX94urtzd/XMlpACIXsSCP9HCqXBgHlmXg8m+k4bdbuph73hCinea
e+vvI6bytVLAYDA9mO59xaWW4ne20tTJ9uf6vKoUFAlZxT6k8k9i5p5jY7hMN2+GIo5G2a/S6RsE1H8Xi2huAEbdj7c4tGbUBQ5U+DgCqZzs6RE7hyR6bkZvxaoixP8Xi2L8H9WbprcE5ncl0dtLR4m141Ng0Cakql09DVZTN/9zBFGDmaJxy2pUft2K0s/f0jl+PIV6tzmvJmzfI2rSo/DQJqSiQSdi0bd+bvdPK6S/2RiA47rBZOsE8ms5857B0MWlrsfVUeNAiookombWbg7kx0Mn93qb++Xkv91aqvz9YOnPkd7u+AUwCoq4MZM0aulqq8oUFAFc3Onba92Gkndrjb+hsbta3fL+JxWxvMHV2U22cwc6bXKfU3DQKqYLt329E+TjOAwz2xSE90f9u1a+8lP9zDfyMRWzNQpadBQO2z7u5sp6C7zV8zfzUaJxg4BQb3d8aZB6IzkUtLg4CatMHB7Jo+TjXfveiYZv5qPDt3jhwuDCM7j5ubdVJgqWgQUJOyc2d2eQf3UE+n5D97trfpU5UjlbJNiU4/kvv75CxHoYWJqadBQE1IT8/IGaOQrcY7Hb65a9D/9a9/5dZbf08kUsdHPvIRZmuEUHn092c7kJ3RRE7N0lmKQtcmmjoaBNS4duwYWfp3Mv9QyM7szTfO/+GHH2bFirOIxz9KTU0n06c/wLPPPsa8efNKfwCqInR1ZYeWJpNpfvzjK3nggfuZM2cW3/zmKpYtW+p1EquSBgE1qr4+W0JzFnhzlnF22v3HKti/7W0n8PTTFwHnAVBT8//43OeiXHHFf5Ym8apibd8OX/jC5/jNb54mFruEQOB5Ghq+R1vbUxx00DydX1Jknl9jWJWnXbtsqcxd+nc67Robx78mbVdXN3DA8ONU6g3s3r1+ahOtqsLs2YZbbvkZicSrwGyGhk4jkXieO++8nRkzLiQW0+Gk5ULn+lWprVuzF253AoDT9LNgwcQuSn722acTjV4CvAY8RTT6P5x11mlTnHJVLUQCgGvGISlAiMdt/9S2bR4lTI2gQaDK9PZCe7tdQz6Vss8FAtmJPJNZAOzyyy/l/POPoLHxWGbNOofvfOcrnHHGGVOTcFVVRISLL76YaHQl8CuCwX8nEvkzp5zyXoyx383+fvtd7e/3OrX+pn0CVWT37uz1Z93NP5EIzJ3rdeqU3wwNDXHVVT/i9tsfYP78WfzHf3yVUGjR8Mx0sN9R5zrSulz1vtOOYcX27dnSv7vzV08uVW5GK6zU1+sclX2lQcDntm61AcA9a9O5EEjuuH+lykFvb3bGuvO9da5mprXWydMg4GMdHdlFvYzJnkg6lF9Vgi1b7PwVdwEmErGDF9TEaRDwqc2bs7N/nSp1IDDIM8/8gcHBQU488UTmarFKlblt20Y2ZTrXKli40OuUVQ4NAj4Ti9kZwO4AUFsLIr2sWNFKR0cDMJNgcA2PPHIvhx12mNdJVmpMO3faJiJ3n1ZdHSxa5HXKKkMhQUCHiFaYwcHs6o3O+izOuj/XX/89XnvtEPr62ujr+w09PZdy4YVf8DrJSo1r1iy7vlAoZB+n0/Y7vnmzt+nyAw0CFSQWy64A6qzhXltr1/1paYHXX99CPL4csAUCY5bT0bHF20QrNUEzZ9rCTDhsv9vptP2uayCYWkUJAiKyQkTWi8hLIvLlPK+fLyI7ROTpzM/Hi7Ffv9mxI1sDELEnS0tLdvG3k056B9Hoz4AdQIJw+EpaW9/hZZKVmhTn++wEgqEhDQRTreAgIHZu+FXAKcCbgA+KyCF5Nl1tjHlr5ud/C92v32zebE8GdwBobrZVaMeHPvQhLr74TGpqFhMMTuf44/v44Q+v9C7RSu2Dpqb8gaC93euUVaeCO4ZFZDlwqTHm1MzjSwBjjPm2a5vzgbcbYz47gffTjuEcHR3Z0RNOAJgxY2QAcEsmk6RSKSKRSGkTqlQRdXXZH2dSWTBoJ5Tp8NG9ed0xvABwV9baM8/lOktE1orIr0
REB39N0JYtIwOA0wcwWgAACIVCGgBUxWtutt/1UCjbRzAwYCdHquIpxlLS+aJPblH+NuAmY0xSRC4CrgdOGu0NV61aNXy/tbWV1tbWwlNZgbZvzy4FDfZkmD7dnhxK+UFLi838e3qys4v7+23/mJ+XmGhra6Otra0o71Ws5qBVxpgVmcd7NQflbB8AOo0xea5lpc1Bjs5OWxVOJOzjmhobAObM8TZdSnlh+3Y7jyCZzNaIm5u1QOTwujnoCeBAEVkiIrXYy1DdlpNA97TVlcDfirDfqpVI2KuBJZPZGZT19RoAlH/NmWOXkwgEbP+Ac46owhUcBIwxaeAzwD3AX7GjgNaJyNdFxFl8/p9E5AUReSaz7QWF7reabd+evRxkMGi//PPne50qpby1YMHIQBCP64ihYtBlI8rMtm22/TOVys4G3n9/r1OlVHmIx+1giVjM1pKdfjK/15K9bg5SRbJnT7Yj2FkQrrHR61QpVT7CYZvp12SGtKRS9toEvb3epquSaRAoI8766mBrAfX1ekEYpXLNmGHPjWDQNgslk3YQhdo3GgTKxLZt2X4AEbuCol4TQKn85s+3tQKnfyAWs31pavI0CJSB/n5bpXXPB9Chb0qNrbExu+poMmnPoXjc2zRVIg0CZaCzc+Rw0Gh07BnBSilbUHJGCxljz6EdO7xOVeXRIOCxrq7s5SGdSTA6HFSpiVmwwJ4zgYA9h+JxnT8wWRoEPOZMCgM74mHaNG/To1SlaWiwNWhj7MCK7m6vU1RZNAh4aNcu+6V1XyFs1iyvU6VUZZk9e+SFaOJx28SqJqYYC8ipfeTuDNY5Ad7awAZ66GGAARIkSJIknbmZzA1AMrcAAWqoIUSIWmqJEqWRRg7kQI+PxJ+mT7cFqnjcnlM9PTq8eqI0CHikszO7KmIwaEsyTXmX1FPFtpvdbGQjXXTRTz9x4qRIkSTJEEPDGf+mtZt46rdPU1sX4pgLjqFpfhPiWjTXCQjBzK2GGp7macKEqaeeJppYylIa0eg+1VpabMbvXHvAaRbSc2p8umyER157za6Nnk7bjq05c/QLOxWGhoZ49tln2ZzYTOSICP3h/hEZv7uk787gN/xpAz889xqSF76VQFeMyG0v8W+PXkLLopYRNQPjusHIwBAiRJgwUaLMYAYHciDz0MkfU6Wz016DO5HILrq4ZInXqSqNQpaN0JqAB9xrozt9ARoAii8Wi7HsPcfx0sbXIRqiYQg+f+9naZjbMJzpO806NdQMZ9whQlz59R+Q/OGpcO5hDAGD0RCP/+BJLvqvTw43Ezm1hzRpkiRJZW5p0iRIECPGnqE93PdfD/DYL5+itq6Wj375A5x91tkcxVFe/3uqTkuLHWiRStlzK5GwBa1o1OuUlTcNAh5wFohz5gU0NHidourzJ/7E1f/9I9Y1DJBc/2kICIlL7mP1v/yGf/z5J6illjBhIkSop57pmdt85tNCC1/p/QYszjbjDC2eTtP6ZlawYq997WIXW9lKDz300ks//cSIMcggt115B/f++ikS15wGXTG+87GrSTYleeXEV5jHPN7Fu0r5b6l69fW2XyCdzi4noUFgbBoEPBCLZZeHCIVg5kyvU1Q9HuRB2mmnhx5eemkjyTOWQtAOghtaeTA7P/cAi1nMbGZzCIcwjfxjcs9beTb/88VfMHDNCugcJHrlE5z7k3/Ku+3MzM2tiy5e4iW+9qvLSXxvBRxtr6ia/Jdjefg3j7LoxEV00UU77SxmMcdzfBH/C/41e/bIvoFYzOsUlT8NAiW2c6etBQwN2RFBdXVep6g6PMdz/I2/0UkngwySIsXCw+aw7pa/kfzwm6E2SOimv/Kuw4/ndE4f9/2+/tWvEY/HuOGMm6gN1/KNb17JqaeeOuH0NNPMMpYxOzqLrTv7h5+Xnf1EInWkSBEnzgAD9NDDFrZwOIdzKIfu0/GrrHDYNgWlUvZn92676JzKTzuGS2zTJrvs7dCQ7RBeuNBWYdW+u5u72c
xmeuklQYIAASJEiCaj/Pi8a/nrmvUE62rZb/YCHrzrXmaUMEe4++67OfuCDzH4z28j0BkjesPf+O6aKwgcEKCHHmLEGGKIWmqZznSWsIR38+6Spa8a9fTYaw4kErag1dAAixd7naqpVUjHsAaBEtuwIVtFra+HAw7wNj2VbjWr2cY2BhhgiCHChGmkkXnM4528k+lmOhs3biSRSLB06VKCwWDJ07hmzRpu+vXNRMJ1fPrCf2T//fdnJztZwxq2sY097CFJEkGop555zONczi15OqvJK6/YTmGw6wsdWOXTNzQIVIhdu7JD2Gpq7AJYuk7QvlnHOh7ncbaznRgxggSpp545zOFtvI2DOdjrJE7IC7zAszzLdrbTTz8GQx11zGUux3M8+6OXldsXHR12noB7CHY1r8yrQaBCtLdnh7DV1dnFr3Rk0OQ9z/M8yqPsYAdJ0tQQoJlmDuAATuM0r5O3T+7kTl7lVbrpJsUQtdQwl7kcx3EVE9DKSXd39hodzmz8hQu9TtXU0ctLVghnnSARHRq6r17mZdawhq1sJUaaADCLWRzFURUbAABO53SO5EhmMIMAMEiKDjp4iIfYylavk1dxmpps5i+SnUGs8tMgUCLJZHZUUCCQvRiGmpw22tjKNmJAEJjLXE7iJJaxzOukFew4juN4jmc2swkCgxg62MIf+IPXSatIThAwJrtGl9qbBoES2bPHtk9C9vKRanJ+wS9op4MB7Bd3NjM5jdNYylKvk1Y0h3M4J3ACM2lBgAFgE5u4mZu9TlrFcS4/aYw993Rl0fw0CJSIc/1gyF49TE3cfdzH67xOX+ZxCw0cz/EsYpGn6ZoKh3AIy1lOM1EM0AdsZCMP8qDXSasokYg918Cee3rpyfw0CJSIM4MR7BdTLx4zOetYRxdJ0kADwmEcxpt5s9fJmjJHcRSHcij1QALoJM4LvOB1sipKU5OtdTv9As7Fm9RIGgRKJJ221VIRW0VVE/c7fsdOOokBIWABCziFU7xO1pQ7ndOZx1xqgUFgB7u4kzu9TlZFqcmsieA0Cam9aXZUIk4tALJfTDUxr/EaPYABphPkTM70OkklcxInMQ1hCOgBXuEVr5NUUdwFLg0C+WkQKAGnP8CpCXgwabVitdFGFz3EgVpgPvOZzWyvk1UyB3Igc5lLCFsb2E0Xj/GY18mqGO4RQu6CmMrSIFACg4PZL6AGgcnZyEb6gDQQAZaz3OMUld6RHEkE+z/ox86VAEgmk9x9993cfPPNdHR0eJnEshUI2HMO7DnY2+ttespRUYKAiKwQkfUi8pKIfDnP67UislpENojIoyJS5cs5jeRcOwDsF1Kbg8bX29vLjh076DS2L0CAaYR9ucrmURxFAzWZSWT28pjxeJzjTjmRc7/2af7hV9/ijW89nL/85S9eJ7XsuM81nS+QX8HZkYgEgKuAk4AtwBMi8ntjzHrXZp8AOo0xS0XkA8B/AecVuu9K4a6Gasfw2Iwx/PMlX+Tqq65GaoPUNIdIDiSRgHDyxSdivmIQ2afZ8RWtiSY2s4sk0Ecf1113Hc+HdzN430ftF2r185z/mQtZ95e1Xie1rAQC9scZmKH9AnsrRnZ0NLDBGPO6MSYJrAZW5myzErg+c//X2IDhG05/gEODwOhWr17Nz/54C8nN/0Ri1XEMNNSS/PMnSLRdwH23PMGPr/mJ10n0RBNNBIEhIE6Kp9ufYXD5nOyX6dhFbOvY4mUSy5L7XNN+gfyKkR0tADa7Hrdnnsu7jTEmDXSLSEsR9l0RnADg/NYgMLo/P/EY/R8+BFqi8IdX4LKT4A0tcNBM4l89jt/84Xavk+iJCBEEO0IqDRxyzMFEf7EeOnpgaIjQfz/OsmP8118yHqfSmHsOqqxitE7nq5vn/qtzt5E82wxbtWrV8P3W1lZaW1v3MWnlxZm44sPWjAlbuuQAIvfex+Dnh6C5Dl7aPfyabOhkZtM8D1PnnYCrvDYELD99OZc8/xm+ceA3QIQjlr
2NG2/5X+8SWOacc+6+++7issu+QX9/P+ed9z4uu+xr1FRgJ11bWxttbW1Fea+Cl5IWkeXAKmPMiszjSwBjjPm2a5u7M9s8LiJBYKsxJu84v2pcSnrbNnuJu1TKLhw3f76dzaj2Fo/HaT39ZF7Y+RrphiCDazuQv38zMmSou+1l1j78JEuXVs9aQRN1MzfzEOuJAfOAs3kPR3IkyWSSWCzGNJ2Cnldnpz3/kklYt+5xPv7x9zA4+DNgAdHo5/j0p4/niisu8zqZBStkKelihMAngANFZAmwFdvh+8GcbW4HzgceB84BHijCfitGbslf2yVHFw6HefgPD/Dwww/zcv/LPLnwKdbe9ypBEd7/5IdZutB/AQCghx6GsO23YQIcgL0kXSgUIqRL0o7K3R/3xz/+jsHBT0NmsuHAwNXceONZVREEClFwEDDGpEXkM8A92O/otcaYdSLydeAJY8wdwLXAz0VkA7AbH40MgpFjlbVzanw1NTWccMIJnMAJdNJJ4Ih5o7cd+kQ33QxhT9gIERpp9DpJFcF9rkWjUWpqtruGie4gEol4kayyUpTGMGPMH2Dk5Y+MMZe67sfBvxdNDQZH1gZ0rPLENdJILTEGgT34c6bPAAP0MoABwqABYBKcoaEA5577SX75y6Pp6akhnV5AJPJdvvWt73mbwDKg41RKIBweWRPQscoTN5e5RGF4bf3f8luPU1R6t3EbA9j/QRS7gJ6aGPe5NnfufJ566nG+9KXpfPrTW7nrrps499xzvEtcmai8bvEKVF+fHRaqQWByjuZo1rGOXhLEgFd51eskldwmNhHDrp3USF1FX0az1Nw1ARHYf/8FXH75N71NVJnRmkCJuJuENAhM3HzmM495NGDHFHfRz63c6nWySuYWbqGLQQzQgP1/qIlzL+Gua3blp0GgRJzOYa0JTN7beTtNhAlhm4Se732elR97PzOXzOOQo97MQw895HUSx5Texw98Axt4lVeJYfsCmolwLMcWNW3Vzul/0yAwOg0CJRIKjbzotQaCiTuMw9iP/XBGwv/fx1dzV+pldj9wLi9+5VBOPfs9bNiwwdM05tPe3s6Rxy+jNlxL09yZ3Hrr5Gow93M/XcQAmAbsz/68gTdMQUqr08BAdohoIKALN45Gg0CJhEIj+wW6urxNT6U5j/OYyyzqhobYeNuLpH58ul1O4n1vxKw8mPvvv9/rJO7ltHPfy/Mn1TMU+yp7fv8+Pnzhx1i3bt2E/vZGbmQL20ljO4PnM5dz0E7Myejrs0HAWbSxttbrFJUnDQIlEo1mg0A6DbGYt+mpREdyJLOkkZr6ELT32CeNIbC5l4aGBm8TlyMWi/HXJ58l/bXjoSYIyxYiKw7k0UcfHfdvV7OajWwkge0Mnk0TR3HUlKe52sRiI2sCOqk6Pw0CJTJtmq2OBgL2i5lIeJ2iynMMx3CoHMrpl51KzSk/h28+SM25tzB95wBnnXWW18kbIRwOE47WwV932ieSaeSFHcyePfZV0X7JLzP9AEOEgBlM4028ibfy1qlPdJVJJu25FgjY/gCdF5aftpKVkPtSd8mk16mpTCtYQepTKWa/YRYvtP2NhmMX8Y7rjuXX0V9zAiewiEVeJxGwa7n89Oofc9HJn8GceTDBtds5dsnhnHZa/uGdG9jAGtawhS3ESFIDNDGdgzmYv+PvSpv4KhCPZzuFtSlobAUvIFds1biAnGPbNtsXkEjYL+WMGTBOwVCN4i7u4mVeZg97SDNEHWFaaOFADuRETvQ6ecPWrl3Lo48+yrx58zjzzDMJ5hmich/38QqvsJvdJEgSJEATTSxlKStY4UGqK5/fzrVCFpDTIFBCAwPQ3m5LKYEANDTAkiVep6pyPcADbGADXXSRJEkNNdRTzxzm8BbewiEc4nUSx/QCL/A8z7OTnfTRR5o0IULMYAYHczDv5J1eJ7FivfZatmO4rg72398OzqhWXq8iqiYoGrVNQomE9gsUw4mcSAstPM/z7GIXMW
L0008HHexhD+tZz+EcXnbDKl/iJV7gBXaxi156iRNHEKYxjVnM4giO4E28yetkVrREItshXFtb3QGgUFoTKLEtW6C7O3ttgZkzYdYsr1NV+e7gDrawhT76SJAgSJBaammggRZaWMISzztXn+RJXud1uuiin37ixDEYQoSYxjQWsECXhCiCbdvsdQScc6y5GebO9TpVU0ubgyrIwAB0dMDgoC2l1Nfbqqoq3FrW8hIv0UUXAwyQIkWAALXUEiFCPfW00MJCFpaspP0CL9BOO1100UcfgwySIDGc+UeI0EILh3Kolv6LZONG6O+3NYG6Othvv+qvCWgQqDCvvWa/pENDtqq6YIHtH1DF8SiPsolN9NBDjBgpUhgMQYKECVNLLVGiNNBAM83MYEbR+g/WsY5d7KKbbvroI5a5JUiQJo0ghAgRJkwTTSxmMctYVpR9K+jpga1bbXOQn/rdNAhUiL6+PtavX08gMItodAmJhO0jmDYNFpXHyMaq8iRPspnNw8EgSZI0dr2OIEFqMrcQIWqpJZy51VFHbebmvC6Zy2QPMUQqc4sTJ0GCeM4tRYokSVKkGMJe1cRdI5nOdBaz2PPmqWq0aZPtEE6nbQFr1ixoafE6VVNPg0AFePbZZznxxNNJpWaSSLTzwQ9+ks9//lvDVVYfXja3ZDaykVd4hW666aefBInhDNpkrlkmCAHXzXns/u1sazDDfzuUc8t9vxpqqKOOCBGaaWYpS1nMYs/+F9UsHrdBIBaz83GiUTjgAK9TVRoaBCrAAQcczsaNXwL+HugkGj2G73//Ko4++mRqamD6dNsspKbWBjawhS300MMAAyRIDNcQnIx8NO5AkPu8k+k7HdJOqb+JJuYzv+xGKFWj9nbbHJRO2z6Apqbq7xB26BDRMmeM4fXX15G9wmYL6fTJvPba31i+/GTSadtRrKbe0szN8SIvDnfaOs07TnOPU7J31wDEdQtmbqHMLUx4eDTSQRzk1SH6UiJhzyFnmYhQyD8BoFAaBEpARFiy5I1s3HgL8BGgi5qae3lUcqxgAAARMklEQVTTm1YSDNqhbKmUHTWktYHSOnjkpbGH7WIXPfQMt/G72/ZrqCFMmEYamcGMUiZXjWL79uwyEcGgHXWnJkabg0pk7dq1nHTSGaRSs0gk2rnwwo/zve99m1desW2YTt/AggW60JVSk9HXZ0cExeO2LyAS8U9fgEP7BCqEMzpo5syZ7LfffoCd2LJnD8MjhRoadKSQUpPx+ut2/o27L2DOHK9TVVoaBCrcxo3Z9kxnsasZ2sqg1Lh27swuFBcI2BFBmfKVrxQSBPR6AmWgoSF7+clUCnp7vU6RUpWht9eeMyL2HNILx0yeBoEyMGsWhMPZC87E43aNIaXU6NrbRy4UV1enNeh9oUGgTMyYkQ0E6bRt4+zs9DpVSpWn3bttE2o6bWsB4XB1Xy9gKmkQKBP19dmlpp0rj+3Z43WqlCpPPT32HBGx50x9vQ0EavI0CJSRuXNtldYJBPG4rfIq5Vf9/f0kc67Funlz9locgYAdEuq30UDFVFAQEJFmEblHRF4UkT+KSOMo26VF5GkReUZEbi1kn9Vu1iw7QkjE9g8MDtqJMEr5SVdXF+94x7tpappJNDqdr3/9csDOB3A3A9XVwfz5Hie2whVaE7gEuM8YczDwAPCvo2zXb4x5qzHmSGPMewvcZ1WLROwIB+fC2M5oIe0fUH5ywQUX8+STB5BK9ZJKvcIVV9zAz3/+++H5ACL2HJk2zc4QVvuu0CCwErg+c/96YLQMfp/Gr/rVzJm2f8C5EIbTP9Df7226lCqVNWvWkEj8C3Zlm/n093+UBx9cQzJpm0prauw5oqOBCldoEJhtjNkOYIzZBox2ocSwiPxFRNaIyMoC9+kL8+bZqm4g8wklErBrl7dpUqpU5s6dDzyaeTREXd3jzJkzH2NsyT8ateeIKty4C8iJyL2Au9tFAAP82yT2s9gYs01E9gceEJHnjDEbR9
t41apVw/dbW1tpbW2dxK6qx6JFdkp8LJadP7B5sy4roarfddd9nxNPPB24laGhDpYsqeGccy4ang/g94UW29raaGtrK8p7FbRshIisA1qNMdtFZC7wJ2PMG8f5m+uA240xvx3ldd8tGzGWnh7bHxCL2cfO1PiFC71Nl1JTraOjg9/97iFqaho45phTqKurpa7ODp6IRr1OXXnxbO0gEfk20GmM+baIfBloNsZckrNNEzBgjEmIyEzgz8BKY8z6Ud5Tg0COri7o7rY1AcgulaujIlQ16+jIXovb6QhubrYLxKmRvAwCLcCvgEXAJuAcY0y3iLwNuMgYc6GIHAP8BEhj+yD+xxjzf2O8pwaBPHbtyk6QcdpFNRCoatXRwYiRQM7qoH64XvC+0FVEfWLnTjtcVAOBqnRPPfUUP/rRdRhj+NSnPsbb3/724dfyBYDp0+2oOZWfBgEfyRcIIhHtI1CV47HHHuOkk85kYOCLgBCNXsG99/6eY489lvZ22//lDgANDbYfQI1Og4DP7Nxpr6bkBAJn6rwGAlUJzjzzg9xxxzuBT2We+Sknn3wv1157y/BIOMhOBtMawPj0egI+M2tWdlaxe3mJTZv0WgSq/A0OxgF3724Tvb2J4QDgrAqqTUCloUGgQs2caU+ScDgbCOJxO5xUl5hQ5eyiiz5ENPqvwB+Be4hELuHccz/E0JCt1dbWQmOjzgYuFW0OqnB79tifeNw2DYFtR62v1/XVVfm64YYbufzyHzE0BJ/4xEW8730fJRCwhZrGRlvAUROnfQI+199v5xK4O9RqanSFRVW+tmyx39dUyj4OBm0AcK6ypyZHg4AC8p9YtbW2VNWYd5FvpUqru9vOd0kksu3/zsAGLbDsOw0Catj27baT2H0djtpaO81eh9kpL+3YYcf/p1LZAOBcFUybLgujQUCN0Nlph5A6F+EGe7LV1toOZa1uq1IaGLDNlfG4zfydYc21tTYA6CzgwmkQUHsZHLTBIB63/QSQPfF0HXZVKrt22SDgbv5xt/87F09ShdEgoEblNA+lUrYE5j4Jm5p0NUY1Nfr7s6PWnD4qZwZwJKLNP8WmQUCNqbs72zzkro6HQnYEkZ6Qqph27MgWPJzZv05zZEODDlKYChoE1IRs354dPeSumjtNRM3NXqdQVbKuLlsDcAobkO381dL/1NIgoCZsz56RtQJnlqYTDBobbWedUhPV12eXK0kk9m52DIXsEic6+WtqaRBQk7Zjh22vTSazpbZAwJbanP4CHUWkxjI4aAsViUR2MUNn3L8zWVGHJZeGBgG1T5yT2BlB5ASDYDAbDFpa7H2lHIlEdsin871x+pncgw7q6rxOqX9oEFAF6e4eOYzP3V/gnNTNzbZqr/wrHh854scpNIAtKNTUaMevVzQIqKLo7MzONnb3F7hLeNOm2U4+5R8DA9k2/3Q6O+8Espm/DizwlgYBVVS7d2dHEaXTI9t6nQ7k+npb6lPVq7c3O9rH3VzobvePRHTiYTnQIKCmxK5dI4eUOh+LEwycTEBLgNWlszO7Iq0z2sfhfO51dZr5lxMNAmpKuWsG7mDgLhE614LVGciVaWBgZKnfXfJ3mgSdyYW61k/50SCgSqKra+RMUPeoEKcj2T2qSJW3dHrvjl53e7+7xldXpzW+cqZBQJVUT092SWD3EEEY2ZHsnnOgykd3d3aOSO7nJ5Kd5RsK2ZrdtGnepleNT4OA8kQikb1AiJOZjBYQQiHboawBofSMGVnid2f8Tqe/u2nPuRCRrvBZOTQIKM/19GSbitxjyN2ZTDA4coRRJKLDTafKwIDtx0km89fYIJvxO0uMRyJa6q9UGgRU2TDG9h0468g4Q0xHy3ycZiOn9Cn79DVWqVS2Y9ed6Y+V8TvNddrWX/k0CKiyFI/bxcWc0uhYmRJkZyi726N1lnJ+iYQt7Tvt+u5M310Lczrt3TPAnSWdtbmnemgQUGUvkRg7ILj7EZzf7hFH7szLj3p7927aGe
1/6G7jz834NahWJ8+CgIi8H1gFvBE4yhjz9CjbrQC+CwSAa40x3x7jPTUI+EBPz94dlbD3SBXnd27G5m5KcmoN1aC/P9uv4pTwjdk7w3fWd3K4m3mcMf1OE5vjmmuu5corf4KI8KUvfYqPf/yCkh6bmjqFBIFC14d8Hngf8JPRNhCRAHAVcBKwBXhCRH5vjFlf4L5VBXNnTrHYyDWL3CVdyGaCIva1ZDKbATo1hu7ukbUHdyk4FCqfZbGdoZnucflO5p77G/buYHcEg3u37zuXbsx3rDfeeBOf+9x/MjBwDWD47Gc/SX19lA984NwpP2ZV3goKAsaYFwFExuzOOxrYYIx5PbPtamAloEFAAXYiknvZ4UQiGxRym47cK1caM3Jyk7vm4PzO9+NuJ3ean3Lvu3+c59wVVHdnd26m7bzm/sn3vLO9837u93afUe4mstzgFolMrG3/mmtWMzDwLeAEAAYGLuOaa1ZrEFAF1wQmYgGw2fW4HRsYlMqrtnbvjM098sUpQbszWBjZNp7bophbTMltShnt/mjvk5tpj3c/93Hu+zgZvXM/3wiqUGjfr/pWX18H7HY9s5toVBf8VxMIAiJyLzDH/RRggK8aY26fwD7y1RK00V9NSn393hlgPL73kgejlcQduZ2oMLJ2kRss8mXc+V7L97qzTW4Nw30/X4YfDNqaUTFH71x66Rd48MEzGBjYBQwRjX6Pf//3u4u3A1Wxxg0CxpiTC9xHO7DY9Xghtm9gVKtWrRq+39raSmtra4FJUNUoHB69rd+ZKJUvOLhrDqMFiVyjPZ+b8eeW6J1t8mX2tbWl66tYtmwZjzxyD9dccz0iwoUX3s9b3vKW0uxcFV1bWxttbW1Fea+iDBEVkT8BXzTGPJXntSDwIrZjeCvwF+CDxph1o7yXjg5SJeVuZsrXXp+Pu2Pa+e38lEsntPIPz0YHich7gR8AM4E7RGStMeZUEZkHXGOMOcMYkxaRzwD3kB0imjcAKOWFfH0QSvmFThZTSqkKV0hNIDD+JkoppaqVBgGllPIxDQJKKeVjGgSUUsrHNAgopZSPaRBQSikf0yCglFI+pkFAKaV8TIOAUkr5mAYBpZTyMQ0CSinlYxoElFLKxzQIKKWUj2kQUEopH9MgoJRSPqZBQCmlfEyDgFJK+ZgGAaWU8jENAkop5WMaBJRSysc0CCillI9pEFBKKR/TIKCUUj6mQUAppXxMg4BSSvmYBgGllPIxDQJKKeVjGgSUUsrHCgoCIvJ+EXlBRNIi8tYxtntNRJ4VkWdE5C+F7FMppVTxFFoTeB54H/DgONsNAa3GmCONMUcXuM+K1dbW5nUSppQeX2XT4/OngoKAMeZFY8wGQMbZVArdVzWo9i+hHl9l0+Pzp1JlzAb4o4g8ISL/UKJ9KqWUGkfNeBuIyL3AHPdT2Ez9q8aY2ye4n2ONMdtEZBZwr4isM8Y8MvnkKqWUKiYxxhT+JiJ/Ar5gjHl6AtteCvQaY74zyuuFJ0gppXzGGDNes3xe49YEJiFvAkQkCgSMMX0iUg+8G/j6aG+yrweilFJq8godIvpeEdkMLAfuEJG7M8/PE5E7MpvNAR4RkWeAx4DbjTH3FLJfpZRSxVGU5iCllFKVydNhm9U+2WwSx7dCRNaLyEsi8uVSprEQItIsIveIyIsi8kcRaRxlu7SIPJ35/G4tdTona7zPQ0RqRWS1iGwQkUdFZLEX6dwXEzi280VkR+bzelpEPu5FOveViFwrIttF5Lkxtvl+5rNbKyJHlDJ9hRrv+ETkXSLS7fr8/m3cNzXGePYDHAwsBR4A3jrGdq8CzV6mdaqODxuIXwaWACFgLXCI12mf4PF9G/hS5v6XgW+Nsl2P12mdxDGN+3kAnwKuztz/ALDa63QX8djOB77vdVoLOMbjgCOA50Z5/VTgzsz9ZcBjXqe5yMf3LuC2ybynpzUBU+WTzSZ4fEcDG4wxrxtjksBqYGVJEli4lcD1mfvXA+8dZbtK6u
yfyOfhPu5fAyeVMH2FmOh3rZI+rxGMHXreNcYmK4EbMts+DjSKyJwxti8rEzg+mOTnVykZazVPNlsAbHY9bs88VwlmG2O2AxhjtgGzRtkuLCJ/EZE1IlLuAW4in8fwNsaYNNAtIi2lSV5BJvpdOyvTVPIrEVlYmqSVTO7/oIPKOd8manmm6fVOETl0vI2LOUQ0r2qfbFaE48sXtcumt36M4xu/rTFrcebz2x94QESeM8ZsLGY6i2gin0fuNpJnm3I0kWO7DbjJGJMUkYuwNZ5KqelMRFmfb0XwFLDEGDMgIqcCtwIHjfUHUx4EjDEnF+E9tmV+7xSR32GrtWURBIpwfO2Au2NxIbClwPcsmrGOL9NBNccYs11E5gI7RnkP5/PbKCJtwJFAuQaBiXwem4FFwBYRCQLTjTHjVdHLwbjHlnMc12D7fapJO/azc5TV+VYoY0yf6/7dInK1iLQYYzpH+5tyag4adbKZiDRk7juTzV4oZcKKZLR2uieAA0VkiYjUAudhS2OV4Dbggsz984Hf524gIk2Z40JEZgLHAn8rVQL3wUQ+j9uxxwtwDrbjvxKMe2yZYO5YSXl/VqMRRj/fbgM+CiAiy4Fup0mzgox6fO7+DRE5GjsNYNQAAHg+Oui92FLVILAVuDvz/Dzgjsz9/bGjGJ7BLl19idc99MU8vszjFcCLwIYKO74W4L5M2u8FmjLPvw34aeb+McBzmc/vWeACr9M9gePa6/PAznI/I3M/DPwq8/pjwH5ep7mIx3Y5tpD1DHA/cJDXaZ7k8d2ELdnHgU3Ax4CLgAtd21yFHSX1LGOMSizHn/GOD7jY9fmtAZaN9546WUwppXysnJqDlFJKlZgGAaWU8jENAkop5WMaBJRSysc0CCillI9pEFBKKR/TIKCUUj6mQUAppXzs/wMp7kRlWyjUEwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fb737078bd0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Faint cloud: the test set; outlined markers: the labelled training points.\n",
"plt.scatter(X_test[:, 0], X_test[:, 1], c=Y_test, cmap='winter', s=20, alpha=0.005, edgecolor='none')\n",
"plt.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, cmap='winter', s=20, edgecolor='k')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def plot_model_predictions(m, margin=0.3, step=0.1):\n",
"    \"\"\"Shade the model's predicted class over a grid and overlay the test/train points.\n",
"\n",
"    The grid covers the bounding box of the module-level X_test and X_train arrays,\n",
"    padded by `margin` on every side, so the plotted region always encloses the data\n",
"    with equal padding instead of relying on hard-coded axis limits.\n",
"\n",
"    Args:\n",
"        m: fitted Keras model; predict() is expected to return per-class scores on the last axis.\n",
"        margin: padding added around the data's bounding box.\n",
"        step: grid resolution; smaller values give a smoother decision boundary.\n",
"    \"\"\"\n",
"    points = np.concatenate([X_test, X_train])\n",
"    x_min, y_min = points.min(axis=0) - margin\n",
"    x_max, y_max = points.max(axis=0) + margin\n",
"    xx, yy = np.meshgrid(np.arange(x_min, x_max, step),\n",
"                         np.arange(y_min, y_max, step))\n",
"\n",
"    # Predicted class index for every grid point, reshaped back onto the grid.\n",
"    Z = m.predict(np.c_[xx.ravel(), yy.ravel()]).argmax(-1)\n",
"    Z = Z.reshape(xx.shape)\n",
"\n",
"    plt.contourf(xx, yy, Z, alpha=0.3, cmap='Greens')\n",
"    plt.scatter(X_test[:, 0], X_test[:, 1], c=Y_test, s=20, cmap='winter', edgecolor='none', alpha=0.005)\n",
"    plt.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, s=20, cmap='winter', edgecolor='k')\n",
"\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model without VAT"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Baseline: 2-100-2 MLP trained with plain cross-entropy (no VAT term).\n",
"model = Sequential([\n",
"    Dense(100, activation='relu', input_shape=(2,)),\n",
"    Dense(2, activation='softmax'),\n",
"])\n",
"model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/1\n",
"160000/160000 [==============================] - 13s 83us/step - loss: 0.2050 - acc: 0.9870\n",
"('Test accruracy ', 0.9059)\n"
]
}
],
"source": [
"# Repeat the tiny labelled set 10000 times so a single epoch runs many SGD updates.\n",
"model.fit(np.concatenate([X_train] * 10000), np.concatenate([Y_train_cat] * 10000))\n",
"\n",
"y_pred = model.predict(X_test).argmax(-1)\n",
"# Format explicitly: under Python 2, print(a, b) prints a tuple rather than two values.\n",
"print(\"Test accuracy: %.4f\" % accuracy_score(Y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot the model outputs"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python2.7/dist-packages/numpy/ma/core.py:6442: MaskedArrayFutureWarning: In the future the default for ma.minimum.reduce will be axis=0, not the current None, to match np.minimum.reduce. Explicitly pass 0 or None to silence this warning.\n",
" return self.reduce(a)\n",
"/usr/local/lib/python2.7/dist-packages/numpy/ma/core.py:6442: MaskedArrayFutureWarning: In the future the default for ma.maximum.reduce will be axis=0, not the current None, to match np.maximum.reduce. Explicitly pass 0 or None to silence this warning.\n",
" return self.reduce(a)\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fb7392b4290>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_model_predictions(m=model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model with VAT"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def compute_kld(p_logit, q_logit):\n",
"    \"\"\"Per-example KL(p || q) between the softmax distributions of two logit tensors.\n",
"\n",
"    Uses log_softmax instead of log(softmax(x) + 1e-16): it is exact and numerically\n",
"    stable, whereas the epsilon version saturates near -37 for tiny probabilities.\n",
"    Assumes 2-D logits of shape (batch, n_classes); returns a (batch,) tensor.\n",
"    \"\"\"\n",
"    log_p = tf.nn.log_softmax(p_logit)\n",
"    log_q = tf.nn.log_softmax(q_logit)\n",
"    return tf.reduce_sum(tf.exp(log_p) * (log_p - log_q), axis=1)\n",
"\n",
"\n",
"def make_unit_norm(x):\n",
"    \"\"\"Scale each row of a 2-D tensor to unit L2 norm (1e-16 avoids division by zero).\"\"\"\n",
"    row_norm = tf.sqrt(tf.reduce_sum(tf.square(x), axis=1))\n",
"    return x / (tf.reshape(row_norm, [-1, 1]) + 1e-16)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# VAT hyperparameters.\n",
"# VAT_XI: step used to probe the model along the random direction r.\n",
"#   NOTE(review): the reference VAT implementation uses a tiny xi (1e-6) for this\n",
"#   finite-difference probe; 10 is far outside this dataset's unit scale. Kept to\n",
"#   preserve this notebook's results - confirm it is intentional.\n",
"# VAT_EPS: L2 radius of the final virtual adversarial perturbation.\n",
"VAT_XI = 10.0\n",
"VAT_EPS = 1.0 / 3.0\n",
"\n",
"# Shared 2-100-2 network emitting logits; applied to both clean and perturbed inputs.\n",
"network = Sequential()\n",
"network.add(Dense(100, activation='relu', input_shape=(2,)))\n",
"network.add(Dense(2))\n",
"\n",
"model_input = Input((2,))\n",
"p_logit = network(model_input)\n",
"p = Activation('softmax')(p_logit)\n",
"\n",
"# One power-iteration step: start from a random unit direction r; the gradient of\n",
"# KL(p(x) || p(x + xi*r)) w.r.t. r approximates the most sensitive direction.\n",
"r = tf.random_normal(shape=tf.shape(model_input))\n",
"r = make_unit_norm(r)\n",
"p_logit_r = network(model_input + VAT_XI * r)\n",
"\n",
"kl = tf.reduce_mean(compute_kld(p_logit, p_logit_r))\n",
"grad_kl = tf.gradients(kl, [r])[0]\n",
"r_vadv = tf.stop_gradient(grad_kl)\n",
"r_vadv = make_unit_norm(r_vadv) * VAT_EPS\n",
"\n",
"# VAT loss: KL between the (fixed) clean prediction and the prediction at x + r_vadv.\n",
"# It is computed on whatever batch is passed to fit(), i.e. only labelled data here.\n",
"p_logit_no_gradient = tf.stop_gradient(p_logit)\n",
"p_logit_r_adv = network(model_input + r_vadv)\n",
"vat_loss = tf.reduce_mean(compute_kld(p_logit_no_gradient, p_logit_r_adv))\n",
"\n",
"model_vat = Model(model_input, p)\n",
"model_vat.add_loss(vat_loss)\n",
"\n",
"model_vat.compile('sgd', 'categorical_crossentropy', metrics=['accuracy'])\n",
"\n",
"# Report vat_loss in the progress bar. metrics_names / metrics_tensors are Keras\n",
"# internals that may not exist in newer Keras releases.\n",
"model_vat.metrics_names.append('vat_loss')\n",
"model_vat.metrics_tensors.append(vat_loss)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/1\n",
"160000/160000 [==============================] - 22s 137us/step - loss: 0.3167 - acc: 0.9779 - vat_loss: 0.0746\n",
"('Test accruracy ', 1.0)\n"
]
}
],
"source": [
"# Same tiled training set as the baseline; the add_loss VAT term is added to the cross-entropy.\n",
"model_vat.fit(np.concatenate([X_train] * 10000), np.concatenate([Y_train_cat] * 10000))\n",
"\n",
"y_pred = model_vat.predict(X_test).argmax(-1)\n",
"# Format explicitly: under Python 2, print(a, b) prints a tuple rather than two values.\n",
"print(\"Test accuracy: %.4f\" % accuracy_score(Y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot the model outputs"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fb7362adc10>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_model_predictions(m=model_vat)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@fahad7033

This comment has been minimized.

Copy link

@fahad7033 fahad7033 commented Aug 14, 2019

Your explanation is really very helpful. It is clear now for supervised learning; could you please describe what changes to the code above would make it applicable to unlabeled data as well?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.