Skip to content

Instantly share code, notes, and snippets.

@interactivetech
Created November 30, 2017 17:14
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save interactivetech/86580f130ca47d893da7afb020a47e31 to your computer and use it in GitHub Desktop.
Save interactivetech/86580f130ca47d893da7afb020a47e31 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import edward as ed\n",
"from edward.models import Normal, Empirical,Categorical\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting MNIST_data\\train-images-idx3-ubyte.gz\n",
"Extracting MNIST_data\\train-labels-idx1-ubyte.gz\n",
"Extracting MNIST_data\\t10k-images-idx3-ubyte.gz\n",
"Extracting MNIST_data\\t10k-labels-idx1-ubyte.gz\n",
"[7 3 4 ..., 5 6 8]\n",
"\n",
"Image Shape: (28, 28, 1)\n",
"\n",
"Training Set: 55000 samples\n",
"Validation Set: 5000 samples\n",
"Test Set: 10000 samples\n"
]
}
],
"source": [
"from tensorflow.examples.tutorials.mnist import input_data\n",
"mnist = input_data.read_data_sets('MNIST_data', one_hot=True,reshape=False)\n",
"#mnist = input_data.read_data_sets(\"MNIST_data\", reshape=False)\n",
"X_train, y_train = mnist.train.images, mnist.train.labels\n",
"X_validation, y_validation = mnist.validation.images, mnist.validation.labels\n",
"X_test, y_test = mnist.test.images, mnist.test.labels\n",
"\n",
"\n",
"assert(len(X_train) == len(y_train))\n",
"assert(len(X_validation) == len(y_validation))\n",
"assert(len(X_test) == len(y_test))\n",
"\n",
"#y_train,y_test=tf.one_hot(y_train,10),tf.one_hot(y_test,10)\n",
"print(np.argmax(y_train,1))\n",
"print()\n",
"print(\"Image Shape: {}\".format(X_train[0].shape))\n",
"print()\n",
"print(\"Training Set: {} samples\".format(len(X_train)))\n",
"print(\"Validation Set: {} samples\".format(len(X_validation)))\n",
"print(\"Test Set: {} samples\".format(len(X_test)))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Updated Image Shape: (32, 32, 1)\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"# Pad images with 0s\n",
"X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')\n",
"X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant')\n",
"X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')\n",
" \n",
"print(\"Updated Image Shape: {}\".format(X_train[0].shape))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.utils import shuffle\n",
"X_train, y_train = shuffle(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"EPOCHS = 10\n",
"BATCH_SIZE = 128"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from tensorflow.contrib.layers import flatten\n",
"\n",
"def LeNet(xtrain, dic):\n",
"    \"\"\"Build LeNet-5 logits for a batch of 32x32x1 images.\n",
"\n",
"    Args:\n",
"        xtrain: image tensor of shape (batch, 32, 32, 1).\n",
"        dic: dict of weight tensors / Edward random variables keyed by\n",
"            'conv1_W', 'conv1_b', 'conv2_W', 'conv2_b',\n",
"            'fc1_W', 'fc1_b', 'fc2_W', 'fc2_b', 'fc3_W', 'fc3_b'.\n",
"\n",
"    Returns:\n",
"        Logits tensor of shape (batch, 10); no softmax applied.\n",
"    \"\"\"\n",
"    # Layer 1: Convolutional. Input = 32x32x1. Output = 28x28x6.\n",
"    conv1 = tf.nn.conv2d(xtrain, dic['conv1_W'], strides=[1, 1, 1, 1],\n",
"                         padding='VALID') + dic['conv1_b']\n",
"    conv1 = tf.nn.relu(conv1)\n",
"    # Pooling. Input = 28x28x6. Output = 14x14x6.\n",
"    conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1],\n",
"                           strides=[1, 2, 2, 1], padding='VALID')\n",
"\n",
"    # Layer 2: Convolutional. Output = 10x10x16.\n",
"    conv2 = tf.nn.conv2d(conv1, dic['conv2_W'], strides=[1, 1, 1, 1],\n",
"                         padding='VALID') + dic['conv2_b']\n",
"    conv2 = tf.nn.relu(conv2)\n",
"    # Pooling. Input = 10x10x16. Output = 5x5x16.\n",
"    conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1],\n",
"                           strides=[1, 2, 2, 1], padding='VALID')\n",
"\n",
"    # Flatten. Input = 5x5x16. Output = 400.\n",
"    fc0 = flatten(conv2)\n",
"\n",
"    # Layer 3: Fully Connected. Input = 400. Output = 120.\n",
"    fc1 = tf.nn.relu(tf.matmul(fc0, dic['fc1_W']) + dic['fc1_b'])\n",
"\n",
"    # Layer 4: Fully Connected. Input = 120. Output = 84.\n",
"    fc2 = tf.nn.relu(tf.matmul(fc1, dic['fc2_W']) + dic['fc2_b'])\n",
"\n",
"    # Layer 5: Fully Connected. Input = 84. Output = 10 logits.\n",
"    logits = tf.matmul(fc2, dic['fc3_W']) + dic['fc3_b']\n",
"    return logits"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"with tf.name_scope(\"model\"):\n",
"    def make_prior(shape, name):\n",
"        \"\"\"Normal prior over a weight tensor of the given shape.\n",
"\n",
"        NOTE(review): priors are centred at 1 with unit scale; a zero-mean\n",
"        prior is more conventional for NN weights -- confirm this is intended.\n",
"        \"\"\"\n",
"        return Normal(loc=tf.ones(shape), scale=tf.ones(shape), name=name)\n",
"\n",
"    # Priors over every LeNet weight/bias tensor.\n",
"    conv1_W = make_prior([5, 5, 1, 6], 'conv1_W')   # 32x32x1 -> 28x28x6\n",
"    conv1_b = make_prior([6], 'conv1_b')\n",
"    conv2_W = make_prior([5, 5, 6, 16], 'conv2_W')  # -> 10x10x16\n",
"    conv2_b = make_prior([16], 'conv2_b')\n",
"    fc1_W = make_prior([400, 120], 'fc1_W')\n",
"    fc1_b = make_prior([120], 'fc1_b')\n",
"    fc2_W = make_prior([120, 84], 'fc2_W')\n",
"    fc2_b = make_prior([84], 'fc2_b')\n",
"    fc3_W = make_prior([84, 10], 'fc3_W')\n",
"    fc3_b = make_prior([10], 'fc3_b')\n",
"\n",
"    # Model input placeholder; fed (indirectly) at inference time.\n",
"    X = tf.placeholder(tf.float32, (None, 32, 32, 1), name=\"X\")\n",
"\n",
"    dic = {'conv1_W': conv1_W, 'conv1_b': conv1_b,\n",
"           'conv2_W': conv2_W, 'conv2_b': conv2_b,\n",
"           'fc1_W': fc1_W, 'fc1_b': fc1_b,\n",
"           'fc2_W': fc2_W, 'fc2_b': fc2_b,\n",
"           'fc3_W': fc3_W, 'fc3_b': fc3_b}\n",
"    # Likelihood: class label drawn from a Categorical over LeNet logits.\n",
"    y = tf.identity(Categorical(LeNet(X, dic)), name=\"y\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"with tf.name_scope(\"posterior\"):\n",
"    # Number of SGHMC samples retained per weight tensor.\n",
"    Nsamples = 10000\n",
"\n",
"    def make_empirical(name, shape):\n",
"        \"\"\"Empirical posterior holding Nsamples draws of the given shape.\"\"\"\n",
"        with tf.name_scope(name):\n",
"            return Empirical(params=tf.Variable(\n",
"                tf.random_normal([Nsamples] + shape)))\n",
"\n",
"    # One Empirical approximation per prior in the model cell above.\n",
"    qconv1_W = make_empirical(\"qconv1_W\", [5, 5, 1, 6])\n",
"    qconv1_b = make_empirical(\"qconv1_b\", [6])\n",
"    qconv2_W = make_empirical(\"qconv2_W\", [5, 5, 6, 16])\n",
"    qconv2_b = make_empirical(\"qconv2_b\", [16])\n",
"    qfc1_W = make_empirical(\"qfc1_W\", [400, 120])\n",
"    qfc1_b = make_empirical(\"qfc1_b\", [120])\n",
"    qfc2_W = make_empirical(\"qfc2_W\", [120, 84])\n",
"    qfc2_b = make_empirical(\"qfc2_b\", [84])\n",
"    qfc3_W = make_empirical(\"qfc3_W\", [84, 10])\n",
"    qfc3_b = make_empirical(\"qfc3_b\", [10])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"N = 100\n",
"x = tf.placeholder(tf.float32, (None, 32, 32, 1))\n",
"y_ph = tf.placeholder(tf.int32, [None])\n",
"\n",
"# Bind each prior to its Empirical posterior, and bind the model's input\n",
"# placeholder X to the feed placeholder x. Without the X: x binding, the\n",
"# values fed to x in inference.update() never reach the model and X is\n",
"# left unfed.\n",
"inference = ed.SGHMC({conv1_W: qconv1_W, conv1_b: qconv1_b,\n",
"                      conv2_W: qconv2_W, conv2_b: qconv2_b,\n",
"                      fc1_W: qfc1_W, fc1_b: qfc1_b,\n",
"                      fc2_W: qfc2_W, fc2_b: qfc2_b,\n",
"                      fc3_W: qfc3_W, fc3_b: qfc3_b},\n",
"                     data={y: y_ph, X: x})"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"inference.initialize(n_iter=1000, n_print=100,step_size=1e-5)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# We will use an interactive session.\n",
"sess = tf.InteractiveSession()\n",
"# Initialise all the variables in the session.\n",
"tf.global_variables_initializer().run()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training...\n",
"10000\n",
"10000/10000 [100%] ██████████████████████████████ Elapsed: 107s | Acceptance Rate: 1.000\n"
]
}
],
"source": [
"with tf.Session() as sess:\n",
"    sess.run(tf.global_variables_initializer())\n",
"    num_examples = len(X_train)\n",
"\n",
"    # Convert one-hot labels to class indices once -- this is loop-invariant,\n",
"    # so there is no reason to recompute it on every iteration.\n",
"    y_train_new = np.argmax(y_train, axis=1)\n",
"\n",
"    print(\"Training...\")\n",
"    print(inference.n_iter)\n",
"    for i in range(inference.n_iter):\n",
"        # Full-batch update: one SGHMC step over the whole training set.\n",
"        info_dict = inference.update(feed_dict={x: X_train, y_ph: y_train_new})\n",
"        inference.print_progress(info_dict)\n",
"\n",
"    saver = tf.train.Saver()\n",
"    saver.save(sess, './bayesianlenet')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[8 2 2 ..., 2 2 2]\n",
"1051\n",
"[0 0 0 ..., 0 0 0]\n",
"854\n",
"[5 6 8 ..., 5 6 1]\n",
"1645\n",
"[4 9 4 ..., 4 4 4]\n",
"1388\n",
"[7 4 4 ..., 4 7 1]\n",
"974\n",
"[9 9 9 ..., 9 9 5]\n",
"494\n",
"[7 4 4 ..., 4 9 9]\n",
"952\n",
"[3 3 3 ..., 2 3 3]\n",
"974\n",
"[8 8 7 ..., 8 8 8]\n",
"1833\n",
"[7 9 8 ..., 7 7 9]\n",
"983\n",
"[4 4 4 ..., 4 4 4]\n",
"649\n",
"[5 5 2 ..., 5 2 2]\n",
"1001\n",
"[9 3 9 ..., 2 9 3]\n",
"1093\n",
"[0 0 0 ..., 0 0 4]\n",
"1430\n",
"[2 2 2 ..., 2 2 2]\n",
"944\n",
"[2 5 2 ..., 2 2 2]\n",
"1049\n",
"[0 9 0 ..., 2 2 0]\n",
"919\n",
"[6 6 6 ..., 6 6 6]\n",
"995\n",
"[5 6 9 ..., 6 6 9]\n",
"969\n",
"[3 3 3 ..., 3 2 2]\n",
"494\n",
"[6 8 8 ..., 8 4 8]\n",
"1214\n",
"[2 2 2 ..., 2 2 2]\n",
"973\n",
"[4 7 8 ..., 4 1 4]\n",
"987\n",
"[7 7 7 ..., 1 7 7]\n",
"1110\n",
"[5 5 5 ..., 5 5 5]\n",
"1159\n",
"[0 3 3 ..., 3 3 3]\n",
"908\n",
"[8 4 5 ..., 5 4 4]\n",
"1298\n",
"[6 6 6 ..., 6 6 2]\n",
"1150\n",
"[1 1 1 ..., 1 1 1]\n",
"968\n",
"[3 3 9 ..., 3 3 4]\n",
"1027\n",
"[4 4 4 ..., 4 4 2]\n",
"979\n",
"[5 5 5 ..., 5 5 5]\n",
"986\n",
"[4 4 4 ..., 7 4 7]\n",
"948\n",
"[3 3 3 ..., 9 3 7]\n",
"1392\n",
"[1 3 9 ..., 2 3 6]\n",
"446\n",
"[6 2 7 ..., 6 6 6]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-14-34e4a4f27141>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 17\u001b[0m \u001b[0mypred\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mCategorical\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mLeNet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_test\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mdic\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mypred\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 19\u001b[1;33m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mypred\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m==\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_test\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;32mC:\\Users\\syale\\Anaconda3\\lib\\site-packages\\edward\\models\\random_variable.py\u001b[0m in \u001b[0;36meval\u001b[1;34m(self, session, feed_dict)\u001b[0m\n\u001b[0;32m 214\u001b[0m \u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 215\u001b[0m \"\"\"\n\u001b[1;32m--> 216\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 217\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 218\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mC:\\Users\\syale\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\ops.py\u001b[0m in \u001b[0;36meval\u001b[1;34m(self, feed_dict, session)\u001b[0m\n\u001b[0;32m 568\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 569\u001b[0m \"\"\"\n\u001b[1;32m--> 570\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_eval_using_default_session\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgraph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msession\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 571\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 572\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_dup\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mC:\\Users\\syale\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\ops.py\u001b[0m in \u001b[0;36m_eval_using_default_session\u001b[1;34m(tensors, feed_dict, graph, session)\u001b[0m\n\u001b[0;32m 4453\u001b[0m \u001b[1;34m\"the tensor's graph is different from the session's \"\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4454\u001b[0m \"graph.\")\n\u001b[1;32m-> 4455\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0msession\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 4456\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4457\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mC:\\Users\\syale\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 887\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 888\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[1;32m--> 889\u001b[1;33m run_metadata_ptr)\n\u001b[0m\u001b[0;32m 890\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 891\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mC:\\Users\\syale\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[1;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 1118\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1119\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[1;32m-> 1120\u001b[1;33m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[0;32m 1121\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1122\u001b[0m \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mC:\\Users\\syale\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 1315\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1316\u001b[0m return self._do_call(_run_fn, self._session, feeds, fetches, targets,\n\u001b[1;32m-> 1317\u001b[1;33m options, run_metadata)\n\u001b[0m\u001b[0;32m 1318\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1319\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mC:\\Users\\syale\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m 1321\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1322\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1323\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1324\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1325\u001b[0m \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mC:\\Users\\syale\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[1;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[0;32m 1300\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[0;32m 1301\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1302\u001b[1;33m status, run_metadata)\n\u001b[0m\u001b[0;32m 1303\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1304\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"with tf.name_scope(\"testing\"):\n",
"    nTestSamples = 100\n",
"    result = []\n",
"    for _ in range(nTestSamples):\n",
"        # Draw one posterior sample of every weight and run LeNet with it.\n",
"        dic = {'conv1_W': qconv1_W.sample(), 'conv1_b': qconv1_b.sample(),\n",
"               'conv2_W': qconv2_W.sample(), 'conv2_b': qconv2_b.sample(),\n",
"               'fc1_W': qfc1_W.sample(), 'fc1_b': qfc1_b.sample(),\n",
"               'fc2_W': qfc2_W.sample(), 'fc2_b': qfc2_b.sample(),\n",
"               'fc3_W': qfc3_W.sample(), 'fc3_b': qfc3_b.sample()}\n",
"        ypred = Categorical(LeNet(X_test, dic))\n",
"        # Evaluate once and reuse: every .eval() on a Categorical draws a\n",
"        # fresh sample, so evaluating twice would print one set of\n",
"        # predictions and score a different one.\n",
"        pred = ypred.eval()\n",
"        n_correct = (pred == np.argmax(y_test, 1)).sum()\n",
"        print(pred)\n",
"        print(n_correct)\n",
"        # Record the per-sample correct count so the histogram cell below\n",
"        # has data (previously `result` was never appended to).\n",
"        result.append(n_correct)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD8CAYAAACSCdTiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADWdJREFUeJzt3W2MpfVdxvHvJUstpa1AmK5rAQcTUkNMCmQkKI3RAgZL\n0+UVaZPqmmA2TTShatNs9VVfmKyJaaqJMdlAdbUPSloqG1of1i2NMUHs0EdgaekDWHBhpyhCfdGW\n9ueLuTHjuuucx5kzP76fZHLux7mvzOz/2v/cc86ZVBWSpJ5+aLsDSJLmx5KXpMYseUlqzJKXpMYs\neUlqzJKXpMYseUlqzJKXpMYseUlqbNdWXuzCCy+s5eXlrbykJO14DzzwwLeqammSc7e05JeXl1ld\nXd3KS0rSjpfk8UnP9XaNJDVmyUtSY5a8JDVmyUtSY5a8JDVmyUtSY5a8JDVmyUtSY5a8JDW2pa94\nlTazfOAT23Ldxw7etC3XlebNmbwkNWbJS1JjlrwkNWbJS1JjlrwkNWbJS1JjlrwkNWbJS1Jjlrwk\nNWbJS1JjlrwkNWbJS1JjlrwkNWbJS1JjI73VcJLHgOeB7wMvVNVKkguAvwKWgceAW6rqP+YTU5I0\niXFm8r9QVVdU1cqwfgA4VlWXAceGdUnSApnmds1e4PCwfBi4efo4kqRZGrXkC/iHJA8k2T9s211V\nJ4blp4DdM08nSZrKqH/+7w1V9WSS1wBHkzyycWdVVZI63YnDfwr7AS655JKpwkqSxjPSTL6qnhwe\nTwIfB64Gnk6yB2B4PHmGcw9V1UpVrSwtLc0mtSRpJJuWfJJzk7zqxWXgF4EHgSPAvuGwfcDd8wop\nSZrMKLdrdgMfT/Li8R+uqr9N8hngziS3Ao8Dt8wvpiRpEpuWfFV9HXj9abY/A1w3j1CSpNnwFa+S\n1JglL0mNWfKS1JglL0mNWfKS1JglL0mNWfKS1JglL0mNWfKS1JglL0mNWfKS1JglL0mNWfKS1Jgl\nL0mNWfKS1JglL0mNWfKS1JglL0mNWfKS1JglL0mNWfKS1JglL0mNWfKS1JglL0mNWfKS1JglL0mN\nWfKS1JglL0mNWfKS1NjIJZ/krCSfS3LPsH5BkqNJHh0ez59fTEnSJMaZyd8GHN+wfgA4VlWXAceG\ndUnSAhmp5JNcBNwE3L5h817g8LB8GLh5ttEkSdMadSb/fuDdwA82bNtdVSeG5aeA3ac7Mcn+JKtJ\nVtfW1iZPKkka26Yln+TNwMmqeuBMx1RVAXWGfYeqaqWqVpaWliZPKkka264RjrkWeEuSNwEvB16d\n5IPA00n2VNWJJHuAk/MMKkka36Yz+ap6T1VdVFXLwFuBT1XV24EjwL7hsH3A3XNLKUmayDTPkz8I\n3JDkUeD6YV2StEBGuV3zP6rq08Cnh+VngOtmH0mSNCu+4lWSGrPkJakxS16SGrPkJakxS16SGrPk\nJakxS16SGrPkJakxS16SGrPkJakxS16SGrPkJakxS16SGrPkJakxS16SGrPkJakxS16SGrPkJakx\nS16SGrPkJakxS16SGrPkJakxS16SGrPkJakxS16SGrPkJakxS16SGrPkJamxTUs+ycuT/EuSLyR5\nKMl7h+0XJDma5NHh8fz5x5UkjWOUmfx3gDdW1euBK4Abk1wDHACOVdVlwLFhXZK0QDYt+Vr37WH1\n7OGjgL3A4WH7YeDmuSSUJE1spHvySc5K8nngJHC0qu4HdlfVieGQp4Ddc8ooSZrQSCVfVd+vqiuA\ni4Crk/zUKfuL9dn9/5Fkf5LVJKtra2tTB5YkjW6sZ9dU1bPAvcCNwNNJ9gAMjyfPcM6hqlqpqpWl\npaVp80qSxjDKs2uWkpw3LJ8D3AA8AhwB9g2H7QPunldISdJkdo1wzB7gcJKzWP9P4c6quifJfcCd\nSW4FHgdumWNOSdIENi35qvoicOVptj8DXDeP
UJKk2fAVr5LUmCUvSY1Z8pLUmCUvSY1Z8pLUmCUv\nSY1Z8pLUmCUvSY1Z8pLUmCUvSY1Z8pLUmCUvSY1Z8pLUmCUvSY1Z8pLUmCUvSY1Z8pLUmCUvSY1Z\n8pLUmCUvSY1Z8pLUmCUvSY1Z8pLUmCUvSY1Z8pLUmCUvSY1Z8pLUmCUvSY1Z8pLU2KYln+TiJPcm\neTjJQ0luG7ZfkORokkeHx/PnH1eSNI5RZvIvAL9dVZcD1wC/nuRy4ABwrKouA44N65KkBbJpyVfV\niar67LD8PHAceC2wFzg8HHYYuHleISVJkxnrnnySZeBK4H5gd1WdGHY9Bew+wzn7k6wmWV1bW5si\nqiRpXCOXfJJXAh8D3llVz23cV1UF1OnOq6pDVbVSVStLS0tThZUkjWekkk9yNusF/6GqumvY/HSS\nPcP+PcDJ+USUJE1qlGfXBLgDOF5V79uw6wiwb1jeB9w9+3iSpGnsGuGYa4FfBr6U5PPDtt8BDgJ3\nJrkVeBy4ZT4RJUmT2rTkq+qfgJxh93WzjSNJmiVf8SpJjVnyktSYJS9JjVnyktSYJS9JjVnyktSY\nJS9JjVnyktSYJS9JjVnyktSYJS9JjVnyktSYJS9JjVnyktSYJS9JjVnyktSYJS9JjVnyktSYJS9J\njVnyktSYJS9JjVnyktSYJS9JjVnyktSYJS9JjVnyktSYJS9JjVnyktTYpiWf5ANJTiZ5cMO2C5Ic\nTfLo8Hj+fGNKkiYxykz+z4AbT9l2ADhWVZcBx4Z1SdKC2bTkq+ofgX8/ZfNe4PCwfBi4eca5JEkz\nMOk9+d1VdWJYfgrYPaM8kqQZmvoXr1VVQJ1pf5L9SVaTrK6trU17OUnSGCYt+aeT7AEYHk+e6cCq\nOlRVK1W1srS0NOHlJEmTmLTkjwD7huV9wN2ziSNJmqVRnkL5EeA+4HVJnkhyK3AQuCHJo8D1w7ok\nacHs2uyAqnrbGXZdN+Ms0rZZPvCJbbv2Ywdv2rZrqz9f8SpJjVnyktSYJS9JjW16T17bZzvvE0vq\nwZm8JDVmyUtSY96ukbbZdt2W86mbLw3O5CWpMUtekhqz5CWpMe/JS9py/h5i6ziTl6TGLHlJasyS\nl6TGLHlJasySl6TGLHlJasySl6TGLHlJasySl6TGLHlJasySl6TGfO8a6SXKPy/50uBMXpIas+Ql\nqTFv14zAH2sl7VTO5CWpMUtekhqz5CWpsanuySe5EfhD4Czg9qo6OJNUp+F9cUnT2s4e2a4/PTjx\nTD7JWcAfA78EXA68LcnlswomSZreNLdrrga+WlVfr6rvAn8J7J1NLEnSLExT8q8Fvrlh/YlhmyRp\nQcz9efJJ9gP7h9VvJ/nyvK85oguBb213iBHslJywc7LulJywc7LulJywTVnz+2OfsjHnj0963WlK\n/kng4g3rFw3b/peqOgQcmuI6c5FktapWtjvHZnZKTtg5WXdKTtg5WXdKTtg5WWeVc5rbNZ8BLkty\naZKXAW8FjkwbSJI0OxPP5KvqhSS/Afwd60+h/EBVPTSzZJKkqU11T76qPgl8ckZZttrC3UI6g52S\nE3ZO1p2SE3ZO1p2SE3ZO1pnkTFXN4vNIkhaQb2sgSY21K/kkNyb5cpKvJjlwmv1J8kfD/i8mueqU\n/Wcl+VySexY5a5Lzknw0ySNJjif5mQXN+ZtJHkryYJKPJHn5vHKOmPUnk9yX5DtJ3jXOuYuQM8nF\nSe5N8vDwdb1tnjmnybph/5aMqSm/91s2nmaQdbwxVVVtPlj/BfDXgJ8AXgZ8Abj8lGPeBPwNEOAa\n4P5T9v8W8GHgnkXOChwGfm1Yfhlw3qLlZP3Fcd8AzhnW7wR+dZu/pq8Bfhr4PeBd45y7IDn3AFcN\ny68CvjKvnNNm3bB/7mNq2pxbNZ5m8P0fe0x1m8mP8lYLe4E/r3X/DJyXZA9AkouAm4DbFzlrkh8B\nfg64A6Cq
vltVzy5azmHfLuCcJLuAVwD/NqecI2WtqpNV9Rnge+Oeuwg5q+pEVX12WH4eOM58X2k+\nzdd0K8fUxDm3eDxNlXUw1pjqVvKjvNXC/3fM+4F3Az+YV8ARc2x2zKXAGvCnw4/Btyc5d9FyVtWT\nwB8A/wqcAP6zqv5+TjlHzTqPc8c1k2slWQauBO6fSarTmzbrVo2paXJu5XiCKbJOMqa6lfzEkrwZ\nOFlVD2x3lhHsAq4C/qSqrgT+C5jrPeRJJDmf9RnKpcCPAecmefv2puohySuBjwHvrKrntjvP6eyg\nMbUjxhNMNqa6lfwob7VwpmOuBd6S5DHWf3x6Y5IPzi/qVFmfAJ6oqhdncB9l/R/pouW8HvhGVa1V\n1feAu4CfnVPOUbPO49xxTXWtJGezXvAfqqq7ZpztVNNk3coxNU3OrRxPMF3WscdUt5If5a0WjgC/\nMjwj5BrWf9w5UVXvqaqLqmp5OO9TVTXPWec0WZ8CvpnkdcNx1wEPL1pO1n+kvCbJK5JkyHl8TjlH\nzTqPc7cs5/B1vAM4XlXvm1O+jSbOusVjapqcWzmeYLp/a+OPqXn9Bnm7Plh/psdXWP/t9e8O294B\nvGNYDut/7ORrwJeAldN8jp9nzs+umTYrcAWwCnwR+Gvg/AXN+V7gEeBB4C+AH97mr+mPsj5zew54\ndlh+9ZnOXbScwBuAGr7vnx8+3rSIWU/5HHMfU1N+77dsPM0g61hjyle8SlJj3W7XSJI2sOQlqTFL\nXpIas+QlqTFLXpIas+QlqTFLXpIas+QlqbH/BoauZuHNcsD2AAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x16f125a4fd0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.hist(np.array(result)/10000)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment