Skip to content

Instantly share code, notes, and snippets.

@divamgupta
Created May 31, 2019 16:23
Show Gist options
  • Star 13 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save divamgupta/c778c17459c1f162e789560d5e0b2f0b to your computer and use it in GitHub Desktop.
Save divamgupta/c778c17459c1f162e789560d5e0b2f0b to your computer and use it in GitHub Desktop.
Simple Keras implementation of Virtual Adversarial Training.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Couldn't import dot_parser, loading of dot files will not be possible.\n",
"/usr/local/lib/python2.7/dist-packages/cryptography/hazmat/primitives/constant_time.py:26: CryptographyDeprecationWarning: Support for your Python version is deprecated. The next version of cryptography will remove support. Please upgrade to a 2.7.x release that supports hmac.compare_digest as soon as possible.\n",
" utils.DeprecatedIn23,\n"
]
}
],
"source": [
"from numpy.random import seed\n",
"seed(0)\n",
"from tensorflow import set_random_seed\n",
"set_random_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import matplotlib\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import keras\n",
"from keras.models import * \n",
"from keras.layers import *\n",
"from sklearn.metrics import accuracy_score\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Make the datasets"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"\n",
"circles = datasets.make_circles(n_samples=1000 , noise=.05 , factor=0.3 ,random_state=3 )\n",
"circles_test = datasets.make_circles(n_samples=10000 , noise=0 , factor=0.3 ,random_state=1 )\n",
"\n",
"n_poionts = 8\n",
"inds = list (np.where(circles[1] == 0)[0][:n_poionts]) + list (np.where(circles[1] == 1)[0][:n_poionts])\n",
"\n",
"X_train = circles[0][inds]\n",
"Y_train = circles[1][inds]\n",
"Y_train_cat = keras.utils.to_categorical( circles[1][inds] )\n",
"\n",
"X_test = circles_test[0] \n",
"Y_test = circles_test[1] \n",
"Y_test_cat = keras.utils.to_categorical( circles_test[1] )\n",
"\n",
"n_classes = int( np.max(Y_train) + 1 )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the dataset"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fb737078bd0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(X_test[:, 0], X_test[:, 1], c=Y_test, s=20 , cmap='winter' , edgecolor='none' , alpha=0.005)\n",
"plt.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, s=20 , cmap='winter' , edgecolor='k')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def plot_model_predictions( m ):\n",
" \n",
" xx, yy = np.meshgrid(np.arange(-1.4, 1.4, 0.1),\n",
" np.arange(-1.8, 1.4, 0.1))\n",
"\n",
" Z = m.predict(np.c_[xx.ravel(), yy.ravel()]).argmax(-1)\n",
" Z = Z.reshape(xx.shape)\n",
"\n",
" plt.contourf(xx, yy, Z, alpha=0.3, cmap='Greens' )\n",
" plt.scatter(X_test[:, 0], X_test[:, 1], c=Y_test, s=20 , cmap='winter' , edgecolor='none' , alpha=0.005)\n",
" plt.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, s=20 , cmap='winter' , edgecolor='k')\n",
" \n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model without VAT"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"model = Sequential()\n",
"model.add( Dense(100 ,activation='relu' , input_shape=(2,)))\n",
"model.add( Dense(2 , activation='softmax' ))\n",
"model.compile( 'sgd' , 'categorical_crossentropy' , metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/1\n",
"160000/160000 [==============================] - 13s 83us/step - loss: 0.2050 - acc: 0.9870\n",
"('Test accruracy ', 0.9059)\n"
]
}
],
"source": [
"model.fit( np.concatenate([X_train]*10000) , np.concatenate([Y_train_cat]*10000) )\n",
"\n",
"y_pred = model.predict( X_test ).argmax(-1)\n",
"print(\"Test accruracy \" , accuracy_score(Y_test , y_pred ))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot the model outputs"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python2.7/dist-packages/numpy/ma/core.py:6442: MaskedArrayFutureWarning: In the future the default for ma.minimum.reduce will be axis=0, not the current None, to match np.minimum.reduce. Explicitly pass 0 or None to silence this warning.\n",
" return self.reduce(a)\n",
"/usr/local/lib/python2.7/dist-packages/numpy/ma/core.py:6442: MaskedArrayFutureWarning: In the future the default for ma.maximum.reduce will be axis=0, not the current None, to match np.maximum.reduce. Explicitly pass 0 or None to silence this warning.\n",
" return self.reduce(a)\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fb7392b4290>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_model_predictions( model )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model with VAT"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def compute_kld(p_logit, q_logit):\n",
" p = tf.nn.softmax(p_logit)\n",
" q = tf.nn.softmax(q_logit)\n",
" return tf.reduce_sum(p*(tf.log(p + 1e-16) - tf.log(q + 1e-16)), axis=1)\n",
"\n",
"\n",
"def make_unit_norm(x):\n",
" return x/(tf.reshape(tf.sqrt(tf.reduce_sum(tf.pow(x, 2.0), axis=1)), [-1, 1]) + 1e-16)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"network = Sequential()\n",
"network.add( Dense(100 ,activation='relu' , input_shape=(2,)))\n",
"network.add( Dense(2 ))\n",
"\n",
"model_input = Input((2,))\n",
"p_logit = network( model_input )\n",
"p = Activation('softmax')( p_logit )\n",
"\n",
"r = tf.random_normal(shape=tf.shape( model_input ))\n",
"r = make_unit_norm( r )\n",
"p_logit_r = network( model_input + 10*r )\n",
"\n",
"kl = tf.reduce_mean(compute_kld( p_logit , p_logit_r ))\n",
"grad_kl = tf.gradients( kl , [r ])[0]\n",
"r_vadv = tf.stop_gradient(grad_kl)\n",
"r_vadv = make_unit_norm( r_vadv )/3.0\n",
"\n",
"\n",
"p_logit_no_gradient = tf.stop_gradient(p_logit)\n",
"p_logit_r_adv = network( model_input + r_vadv )\n",
"vat_loss = tf.reduce_mean(compute_kld( p_logit_no_gradient, p_logit_r_adv ))\n",
"\n",
"\n",
"model_vat = Model(model_input , p )\n",
"model_vat.add_loss( vat_loss )\n",
"\n",
"model_vat.compile( 'sgd' , 'categorical_crossentropy' , metrics=['accuracy'])\n",
"\n",
"model_vat.metrics_names.append('vat_loss')\n",
"model_vat.metrics_tensors.append( vat_loss )\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/1\n",
"160000/160000 [==============================] - 22s 137us/step - loss: 0.3167 - acc: 0.9779 - vat_loss: 0.0746\n",
"('Test accruracy ', 1.0)\n"
]
}
],
"source": [
"model_vat.fit( np.concatenate([X_train]*10000) , np.concatenate([Y_train_cat]*10000) )\n",
"\n",
"y_pred = model_vat.predict( X_test ).argmax(-1)\n",
"print( \"Test accruracy \" , accuracy_score(Y_test , y_pred ))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot the model outputs"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fb7362adc10>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_model_predictions( model_vat )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@fahad7033
Copy link

Your explanation is really helpful. It is now clear for supervised learning; could you please mention what changes to the code above would make it applicable to unlabeled data as well?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment