Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Facial Emotion Recognition - WSL
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dataset Generation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read and process images into matrices"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"'/home/psh/OneDrive/DeepLearning/advanced-tensorflow/basic'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"os.getcwd()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['../../img_dataset/face_emotion/1_disgust',\n",
" '../../img_dataset/face_emotion/2_fear',\n",
" '../../img_dataset/face_emotion/4_sad',\n",
" '../../img_dataset/face_emotion/0_angry',\n",
" '../../img_dataset/face_emotion/5_surprise',\n",
" '../../img_dataset/face_emotion/3_happy',\n",
" '../../img_dataset/face_emotion/6_neutral']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"base_path = \"../../img_dataset/face_emotion/\"\n",
"paths = map(lambda x: os.path.join(base_path, x), os.listdir(base_path))\n",
"if '../../img_dataset/face_emotion/.DS_Store' in paths:\n",
" paths.remove('../../img_dataset/face_emotion/.DS_Store')\n",
"paths"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['.tiff', '.jpg', '.jpeg', '.png', '.gif', '.tga', '.TIFF', '.JPG', '.JPEG', '.PNG', '.GIF', '.TGA']\n"
]
}
],
"source": [
"valid_exts = [\".tiff\", \".jpg\", \".jpeg\", \".png\", \".gif\", \".tga\"]\n",
"valid_exts += map(lambda x: x.upper(), valid_exts)\n",
"print valid_exts"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TOTAL IMAGES: 35886\n"
]
}
],
"source": [
"imgsize = (48, 48)\n",
"grayscale = False\n",
"\n",
"n_imgs = 0\n",
"one_hot_labels = np.eye(len(paths))\n",
"images = []\n",
"labels = []\n",
"\n",
"for i, path in enumerate(paths):\n",
" lbl = one_hot_labels[i]#.reshape(1, -1) # one-hot label for current folder \n",
" for fname in os.listdir(path):\n",
" whole_fname = os.path.join(path, fname)\n",
" if os.path.splitext(fname)[-1] not in valid_exts:\n",
" continue\n",
" \n",
" # load and process image\n",
" img = imread(whole_fname)\n",
" img = img.flatten()#reshape(1, -1)\n",
" \n",
" # save as numpy.array object\n",
" #if n_imgs == 0:\n",
" # images = img\n",
" # labels = lbl\n",
" #images = np.concatenate((images, img), axis=0)\n",
" #labels = np.concatenate((labels, lbl), axis=0)\n",
" images.append(img)\n",
" labels.append(lbl)\n",
" n_imgs += 1\n",
" \n",
"images = np.array(images)\n",
"labels = np.array(labels)\n",
"print \"TOTAL IMAGES:\", n_imgs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## train-test split"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"seed = np.random.randint(100, size=1)[0]\n",
"\n",
"np.random.seed(seed)\n",
"np.random.shuffle(images)\n",
"np.random.seed(seed)\n",
"np.random.shuffle(labels)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TRAIN SHAPE: (25120, 2304) (25120, 7)\n",
"TEST SHAPE: (5382, 2304) (5382, 7)\n",
"VAL SHAPE: (5384, 2304) (5384, 7)\n"
]
}
],
"source": [
"train_test_val_ratio = (.7, .15, .15)\n",
"train_idx, test_idx, _ = np.array(map(lambda x: int(x * n_imgs), train_test_val_ratio)).cumsum()\n",
"\n",
"X_train = images[:train_idx]\n",
"X_test = images[train_idx:test_idx]\n",
"X_val = images[test_idx:]\n",
"y_train = labels[:train_idx]\n",
"y_test = labels[train_idx:test_idx]\n",
"y_val = labels[test_idx:]\n",
"\n",
"print \"TRAIN SHAPE:\", X_train.shape, y_train.shape\n",
"print \"TEST SHAPE:\", X_test.shape, y_test.shape\n",
"print \"VAL SHAPE:\", X_val.shape, y_val.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save as .npz"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SAVED!\n"
]
}
],
"source": [
"categories = [os.path.basename(path) for path in paths]\n",
"dataname = \"face_emotion\"\n",
"savepath = os.path.join(os.getcwd(), \"data\", dataname+\".npz\")\n",
"\n",
"np.savez(\n",
" savepath, \n",
" xtrain=X_train, xtest=X_test, xval=X_val,\n",
" ytrain=y_train, ytest=y_test, yval=y_val,\n",
" imgsize=imgsize, categories=categories\n",
")\n",
"print \"SAVED!\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"# Using ConvNet to classify custom data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load custom data"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['ytest', 'imgsize', 'yval', 'ytrain', 'xtest', 'xtrain', 'xval', 'categories']\n"
]
}
],
"source": [
"fer = np.load(\"data/face_emotion.npz\")\n",
"print fer.files"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((25120, 2304), (25120, 7))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fer[\"xtrain\"].shape, fer[\"ytrain\"].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CNN implementation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Without TF-slim"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"img_w, img_h = (48, 48)\n",
"img_ch = 1\n",
"n_class = 7\n",
"\n",
"c1_ch = 64\n",
"c2_ch = 64\n",
"c3_ch = 64\n",
"d1_n = 1024\n",
"d2_n = 1024\n",
"\n",
"X = tf.placeholder(tf.float32, [None, img_w*img_h*img_ch])\n",
"y = tf.placeholder(tf.float32, [None, n_class])\n",
"_W = {\n",
" \"c1\": tf.Variable(tf.random_normal([3, 3, img_ch, c1_ch], stddev=.01)),\n",
" \"c2\": tf.Variable(tf.random_normal([3, 3, c1_ch, c2_ch], stddev=.01)),\n",
" \"c3\": tf.Variable(tf.random_normal([3, 3, c2_ch, c3_ch], stddev=.01)),\n",
" \"d1\": tf.Variable(tf.random_normal([img_w*img_h*c3_ch//64, d1_n], stddev=.01)),\n",
" \"d2\": tf.Variable(tf.random_normal([d1_n, d2_n], stddev=.01)),\n",
" \"d3\": tf.Variable(tf.random_normal([d2_n, n_class], stddev=.01))\n",
"}\n",
"_b = {\n",
" \"c1\": tf.Variable(tf.zeros([c1_ch])),\n",
" \"c2\": tf.Variable(tf.zeros([c2_ch])),\n",
" \"c3\": tf.Variable(tf.zeros([c3_ch])),\n",
" \"d1\": tf.Variable(tf.zeros([d1_n])),\n",
" \"d2\": tf.Variable(tf.zeros([d2_n])),\n",
" \"d3\": tf.Variable(tf.zeros([n_class]))\n",
"}\n",
"\n",
"def convnet(X, _W, _b):\n",
" input_r = tf.reshape(X, [-1, img_w, img_h, img_ch])\n",
" conv1 = tf.nn.conv2d(input_r, _W[\"c1\"], [1, 1, 1, 1], \"SAME\") + _b[\"c1\"]\n",
" relu1 = tf.nn.relu(conv1)\n",
" pool1 = tf.nn.max_pool(relu1, [1, 2, 2, 1], [1, 2, 2, 1], \"SAME\")\n",
" \n",
" conv2 = tf.nn.conv2d(pool1, _W[\"c2\"], [1, 1, 1, 1], \"SAME\") + _b[\"c2\"]\n",
" relu2 = tf.nn.relu(conv2)\n",
" pool2 = tf.nn.max_pool(relu2, [1, 2, 2, 1], [1, 2, 2, 1], \"SAME\")\n",
" \n",
" conv3 = tf.nn.conv2d(pool2, _W[\"c3\"], [1, 1, 1, 1], \"SAME\") + _b[\"c3\"]\n",
" relu3 = tf.nn.relu(conv3)\n",
" pool3 = tf.nn.max_pool(relu3, [1, 2, 2, 1], [1, 2, 2, 1], \"SAME\")\n",
"\n",
" dense = tf.reshape(pool3, [-1, img_w*img_h*c3_ch//64])\n",
" fc1 = tf.matmul(dense, _W[\"d1\"]) + _b[\"d1\"]\n",
" actv1 = tf.nn.relu(fc1)\n",
" fc2 = tf.matmul(actv1, _W[\"d2\"]) + _b[\"d2\"]\n",
" actv2 = tf.nn.relu(fc2)\n",
" fc3 = tf.matmul(actv2, _W[\"d3\"]) + _b[\"d3\"]\n",
" score = tf.nn.softmax(fc3)\n",
"\n",
" out = {\n",
" \"input_r\": input_r,\n",
" \"conv1\": conv1, \"relu1\": relu1, \"pool1\": pool1,\n",
" \"conv2\": conv2, \"relu2\": relu2, \"pool2\": pool2,\n",
" \"conv3\": conv3, \"relu3\": relu3, \"pool3\": pool3,\n",
" \"dense\": dense,\n",
" \"fc1\": fc1, \"actv1\": actv1,\n",
" \"fc2\": fc2, \"actv2\": actv2,\n",
" \"fc3\": fc3,\n",
" \"score\": score\n",
" }\n",
"\n",
" return out"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### With TF-Slim"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def lrelu(x, leak=.2, scope=\"lrelu\"):\n",
" with tf.variable_scope(scope):\n",
" f1 = (1-leak) / 2.\n",
" f2 = (1+leak) / 2.\n",
" return f1*x + f2*tf.abs(x)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def slim_convnet(X, is_training):\n",
" init_w = tf.truncated_normal_initializer(stddev=.01)\n",
" batchnorm_params = {\"is_training\": is_training, \"decay\": 0.9, \"updates_collections\": None}\n",
" \n",
" X_r = tf.reshape(X, [-1, 48, 48, 1])\n",
" net = slim.conv2d(X_r, 128, [3,3], padding=\"SAME\",\n",
" normalizer_fn=slim.batch_norm,\n",
" normalizer_params=batchnorm_params,\n",
" activation_fn=lrelu,\n",
" weights_initializer=init_w,\n",
" scope=\"CONV1\")\n",
" net = slim.conv2d(net, 128, [3,3], padding=\"SAME\",\n",
" normalizer_fn=slim.batch_norm,\n",
" normalizer_params=batchnorm_params,\n",
" activation_fn=lrelu,\n",
" weights_initializer=init_w,\n",
" scope=\"CONV2\")\n",
" net = slim.max_pool2d(net, [2,2], scope=\"POOL1\")\n",
" \n",
" net = slim.conv2d(X_r, 128, [3,3], padding=\"SAME\",\n",
" normalizer_fn=slim.batch_norm,\n",
" normalizer_params=batchnorm_params,\n",
" activation_fn=lrelu,\n",
" weights_initializer=init_w,\n",
" scope=\"CONV3\")\n",
" net = slim.conv2d(net, 128, [3,3], padding=\"SAME\",\n",
" normalizer_fn=slim.batch_norm,\n",
" normalizer_params=batchnorm_params,\n",
" activation_fn=lrelu,\n",
" weights_initializer=init_w,\n",
" scope=\"CONV4\")\n",
" net = slim.max_pool2d(net, [2,2], scope=\"POOL2\")\n",
" \n",
" net = slim.conv2d(net, 128, [3,3], padding=\"SAME\",\n",
" normalizer_fn=slim.batch_norm,\n",
" normalizer_params=batchnorm_params,\n",
" activation_fn=lrelu,\n",
" weights_initializer=init_w,\n",
" scope=\"CONV5\")\n",
" net = slim.conv2d(net, 128, [3,3], padding=\"SAME\",\n",
" normalizer_fn=slim.batch_norm,\n",
" normalizer_params=batchnorm_params,\n",
" activation_fn=lrelu,\n",
" weights_initializer=init_w,\n",
" scope=\"CONV6\")\n",
" net = slim.max_pool2d(net, [2,2], scope=\"POOL3\")\n",
" \n",
" net = slim.flatten(net, \"FLAT\")\n",
" net = slim.fully_connected(net, 1024,\n",
" normalizer_fn=slim.batch_norm,\n",
" normalizer_params=batchnorm_params,\n",
" weights_initializer=init_w,\n",
" activation_fn=lrelu,\n",
" scope=\"DENSE1\")\n",
" net = slim.dropout(net, .7, is_training=is_training, scope=\"DROP1\")\n",
" net = slim.fully_connected(net, 1024,\n",
" normalizer_fn=slim.batch_norm,\n",
" normalizer_params=batchnorm_params,\n",
" weights_initializer=init_w,\n",
" activation_fn=lrelu,\n",
" scope=\"DENSE2\")\n",
" net = slim.dropout(net, .7, is_training=is_training, scope=\"DROP2\")\n",
" out = slim.fully_connected(net, 7,\n",
" weights_initializer=init_w,\n",
" activation_fn=tf.nn.softmax,\n",
" scope=\"LOGIT\")\n",
" \n",
" return out"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loss function & optimizer & eval metrics\n",
"* Use weighted loss to compensate for label skewness"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"X = tf.placeholder(tf.float32, [None, 48*48*1])\n",
"y = tf.placeholder(tf.float32, [None, 7])\n",
"is_training = tf.placeholder(tf.bool)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"logit = slim_convnet(X, is_training)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 5.74696866, 7.0324748 , 3.99173685, 5.90225564, 8.93950178,\n",
" 7.31082654, 65.7591623 ])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_counts = np.unique(np.argmax(fer[\"ytrain\"], axis=1), return_counts=True)[1]\n",
"class_weights = (class_counts.sum() / class_counts.astype(float))\n",
"class_weights"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"learning_rate = .005\n",
"global_step = tf.Variable(0, trainable=False)\n",
"decay_steps = 50.\n",
"decay_rate = .93\n",
"learning_rate = tf.train.inverse_time_decay(learning_rate, global_step, \n",
" decay_steps, decay_rate, staircase=True)\n",
"\n",
"weights = tf.placeholder(tf.float32, [None])\n",
"loss = tf.losses.softmax_cross_entropy(y, logit, weights=weights)\n",
"#optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=.9, use_nesterov=True).minimize(loss)\n",
"optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=.0001).minimize(loss, global_step=global_step)\n",
"\n",
"pred = tf.argmax(logit, 1)\n",
"truth = tf.argmax(y, 1)\n",
"acc = tf.reduce_mean(tf.cast(tf.equal(truth, pred), tf.float32))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data augmentation"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from scipy import ndimage"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def augment_data(images):\n",
" images_r = images.reshape(-1, 48, 48)\n",
" images_aug = []\n",
" for img in images_r:\n",
" # FLIP\n",
" if np.random.randint(0, 2) == 1:\n",
" img_flipped = np.fliplr(img)\n",
" else:\n",
" img_flipped = img\n",
" # ROTATE\n",
" angle = np.random.randint(-10, 10)\n",
" img_rotated = ndimage.rotate(img_flipped, angle, reshape=False)\n",
" # scaled\n",
" scale = np.random.uniform(1., 1.1)\n",
" img_scaled = ndimage.zoom(img_rotated, scale)\n",
" center = img_scaled.shape[0] // 2\n",
" st = (center-24); en = (center+24)\n",
" img_scaled = img_scaled[st:en, st:en]\n",
" # SHIFT\n",
" shift = np.random.randint(-2, 2)\n",
" img_shifted = ndimage.shift(img_scaled, shift)\n",
" # add to the list\n",
" img_flattend = img_shifted.flatten()\n",
" images_aug.append(img_flattend)\n",
" images_aug = np.array(images_aug)\n",
" return images_aug"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/psh/.local/lib/python2.7/site-packages/scipy/ndimage/interpolation.py:583: UserWarning: From scipy 0.13.0, the output shape of zoom() is calculated with round() instead of int() - for these inputs the size of the returned array has changed.\n",
" \"the returned array has changed.\", UserWarning)\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fc188250b90>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"idx = 12\n",
"plt.subplot(121)\n",
"plt.imshow(fer[\"xtrain\"][idx].reshape(48, 48), cmap=mpl.cm.gray)\n",
"plt.subplot(122)\n",
"plt.imshow(augment_data(fer[\"xtrain\"][idx])[0].reshape(48, 48), cmap=mpl.cm.gray)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"xtrain = fer[\"xtrain\"]\n",
"ytrain = fer[\"ytrain\"]\n",
"xtest = fer[\"xtest\"]\n",
"ytest = fer[\"ytest\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Whiten inputs"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"train_mean = np.mean(xtrain, axis=0)\n",
"train_std = np.std(xtrain, axis=0)\n",
" \n",
"def whiten(x, train_mean=train_mean, train_std=train_std):\n",
" return (x - train_mean) / train_std"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training process"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/psh/.local/lib/python2.7/site-packages/matplotlib/axes/_axes.py:6462: UserWarning: The 'normed' kwarg is deprecated, and has been replaced by the 'density' kwarg.\n",
" warnings.warn(\"The 'normed' kwarg is deprecated, and has been \"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0x7fc188280250>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAD1FJREFUeJzt3W+MHdV5x/Gvu8amhZaSXFRFtgt246JAUeyydqloSRT8L0rk5YURjkTlVAiXCFetUJSapjKqo6qBSGnfOA0muCJpU8dxkmZVOXVRgLZR5GTX4EBt4mIcjNdKZTZ2QtMQOwu3L86xMlx2k3P33t3Zvc/3I412/pwz9xlZ/s3smbmzc5rNJpKkGH6h7gIkSdPH0JekQAx9SQrE0JekQAx9SQrE0JekQAx9SQrE0JekQAx9SQpkbt0FtHrppZeaJ06cqLsMSZpV+vv7R4Erfl67GRf6J06cYMWKFXWXIUmzSrPZLLpaLh3eWQccBY4BW8fZfhfwDHAI+BpwTWXbvbnfUWBt4edJkqZASej3ATuAd5PC/H28PtQBPgtcBywDHgA+ntdfA2wEriWdOD6R9ydJqkFJ6K8kXakfB84Du4GBljYvV+YvAS68unMgtz8HfCfvZ2UH9UqSOlAypr8AOFlZHgF+Z5x2dwP3APOAd1X6Hmjpu6D9MiVJ3dDNRzZ3AL8B/BnwF2323QwMA8ONRqOLJUmSqkpC/xSwqLK8MK+byG7gljb77gT6gf7R0dGCkiRJk1ES+kPAUmAxaehmIzDY0mZpZf49wHN5fjC3n5/7LwW+2UG9kqQOlIzpjwFbgP2kJ292AYeB7aQhmcG8fRXwE+AssCn3PQzsAY7k/dwNvNq98iVJ7Zgz0/5G7vDwcNMvZ0lSe5rN5kHSMPnPNOO+kauZ668/+aG6Syh2710P1F2CNCP5wjVJCsTQl6RADH1JCsTQl6RADH1JCsTQl6RADH1JCsTQl6RADH1JCsTQl6RADH1JCsTQl6RADH1JCsTQl6RADH1JCsTQl6RADH1JCsTQl6RADH1JCsTQl6RADH1JCsTQl6RADH1JCsTQl6RADH1JCqQ09NcBR4FjwNZxtt8DHAGeBr4KXFnZ9ipwKE+Dk65UktSxuQVt+oAdwGpgBBgihfeRSpungH7gR8AHgAeA2/K2V4BlXapXktSBkiv9laQr/OPAeWA3MNDS5nFS4AMcABZ2q0BJUveUhP4C4GRleSSvm8gdwFcqyxcDw6STwS3tFihJ6p6S4Z123E4a5nlHZd2VwClgCfAY8AzwfEu/zXmi0Wh0uSRJ0gUlV/qngEWV5YV5XatVwIeB9cC5lv6QhoeeAJaP03cn6WTRPzo6WlCSJGkySkJ/CFgKLAbmARt541M4y4EHSYF/urL+cmB+nm8AN/L6G8CSpGlUMrwzBmwB9pOe5NkFHAa2k8bqB4GPAZcCn899XiSdAN5GOhm8RjrBfBRDX5JqUzqmvy9PVdsq86sm6Pd14Lp2i5IkTQ2/kStJgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgRj6khSIoS9JgZSG/jrgKHAM2DrO9nuAI8DTwFeBKyvbNgHP5WnTpCuVJHVsbkGbPmAHsBoYAYaAQVLIX/AU0A/8CPgA8ABwG/Am4L68rQkczH3Pdqf8N7phw8BU7brrDuz9ct0lSAqm5Ep/JekK/zhwHtgNtCbr46TABzgALMzza4FHgTOkoH+U9FuDJKkGJaG/ADhZWR7J6yZyB/CVNvtuBoaB4UajUVCSJGkySoZ32nE7aSjnHW3225knRkdHm12uSQHdeefauktoy0MP7a+7BAVREvqngEWV5YV5XatVwIdJgX+u0vedLX2faLfIdgysunoqd99VB/bWXUHvWnL92+suoT2GvqZJyfDOELAUWAzMAzaSbsZWLQceBNYDpyvr9wNrgMvztCavkyTVoORKfwzYQgrrPmAXcBjYThqHHwQ+BlwKfD73eZF0AjgDfIR04iD3OdOl2iVJbSod09+Xp6ptlflVP6PvrjxJkmrmN3IlKRBDX5ICMfQlKRBDX5ICMfQlKRBDX5ICMfQlKRBDX5ICMfQlKRBDX5ICMfQlKRBDX5ICMfQlKRBDX5ICMfQlKRBDX5ICMfQlKRBDX5ICKf1ziZoCN2wYqLsEScEY+jUaWHV13SVIbbvzzrV1l9CWhx7aX3cJM4rDO5IUiKEvSYEY+pIUiGP6ktqy5Pq3111CexzTfx2v9CUpEENfkgIpDf11wFHgGLB1nO03AU8CY8CGlm2vAofyNDi5MiVJ3VAypt8H7ABWAyPAECm8j1TavAi8H/jgOP1fAZZ1VKUkqStKQn8l6Qr/eF7eDQzw+tB/If98rWuVSZK6rmR4ZwFwsrI8kteVuhgYBg4At7TRT5LUZdPxyOaVwClgCfAY8AzwfEubzXmi0WhMQ0mSFFPJlf4pYFFleWFeV+pC2+PAE8DycdrsBPqB/tHR0TZ2LUlqR0noDwFLgcXAPGAj5U/hXA7Mz/MN4EZefy9AkjSNSkJ/DNgC7AeeBfYAh4HtwPrcZgVprP9W4MG8HeBtpPH8bwGPAx/F0Jek2pSO6e/LU9W2yvwQadin1deB6yZRlyRpCviNXEkKxNCXpEAMfUkKxNCXpEAMfUkKxNCXpEAMfUkKxNCXpEAMfUkKxD+MLs0AN2wYqLsEBWHoSzPAwKqr6y5BQTi8I0mBGPqSFIihL0mBGPqSFIihL0mBGPqSFIihL0mBGPqSFIihL0mBGPqSFIihL0mBGPqSFIihL0mBGPqSFIihL0mBlIb+OuAocAzYOs72m4AngTFgQ8u2TcBzedo0uTIlSd1Q8kdU+oAdwGpgBBgCBoEjlTYvAu8HPtjS903AfUA/0AQO5r5nOylakjQ5JVf6K0lX+MeB88BuoPVvu70APA281rJ+LfAocIYU9I+SfmuQJNWgJPQXACcryyN5XYlO+kqSumym/I3czXmi0WjUXIok9a6SK/1TwKLK8sK8rkRp352kcf/+0dHRwl1LktpVEvpDwFJgMTAP2Ei6GVtiP7AGuDxPa/I6SVINSkJ/DNhCCutngT3AYWA7sD63WUEar78VeDBvh3QD9yOkE8dQ7nOmS7VLktpUOqa/L09V2yrzQ6Shm/HsypMkqWZ+I1eSAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAjH0JSkQQ1+SAikN/XXAUeAYsHWc7fOBz+Xt3wCuyuuvAl4BDuXpk5MvVZLUqbkFbfqAHcBqYAQYAgaBI5U2dwBngbcCG4H7gdvytueBZV2qV5LUgZIr/ZWkK/jjwHlgNzDQ0mYAeCTP7wVuBuZ0qUZJUpeUhP4C4GRleSSvm6jNGPAD4M15eTHwFPDvwO9P8BmbgWFguNFoFJQkSZqMkuGdTnwX+HXge8D1wD8D1wIvt7TbmSdGR0ebU1yTJIVVcqV/ClhUWV6Y103UZi5wGSnoz+WfAAdJ4/u/OdliJUmdKQn9IWApaZhmHulG7WBLm0FgU57fADwGNIErSDeCAZbk/RzvrGRJ0mSVDO+MAVuA/aQA3wUcBraTxuEHgYeBz5Bu+J4hnRgAbsrtfgK8BtyVt0uSalA6pr8vT1XbKvM/Bm4dp98X8iRJmgH8Rq4kBWLoS1Ighr4kBWLoS1Ighr4kBWLoS1Ighr4kBWLoS1Ighr4kBWLoS1Ighr4kBWLoS1Ighr4kBWLoS1Ighr4kBWLoS1Ighr4kBVL6l7MkaVa6YcNA3SUUO7D3y1P+GYa+pJ42sOrquksodmDv1H+GwzuSFIihL0mBGPqSFIihL0mBGPqSFIihL0mBGPqSFEhp6K8DjgLHgK3jbJ8PfC5v/wZwVWXbvXn9UWDtZAuVJHWuJPT7gB3Au4FrgPfln1V3AGeBtwJ/A9yf118DbASuJZ04PpH3J0mqQUnoryRdqR8HzgO7gdbvNQ8Aj+T5vcDNwJy8fjdwDvhO3s/KjquWJE1KSegvAE5WlkfyuonajAE/AN5c2FeSNE1myrt3NueJ/v7+HzabzaMd7KsBjHalqnr1ynGAxzJT9cqx9MpxsPWP7u/kWK4saVQS+qeARZXlhXndeG1G8j4vA75X2BdgZ566YRjo79K+6tQrxwEey0zVK8fSK8cB03AsJcM7Q8BSYDEwj3RjdrClzSCwKc9vAB4Dmnn9RtLTPYvzfr7ZcdWSpEkpudIfA7YA+0lP3uwCDgPbSWelQeBh4DOkG7VnSEFPbrcHOJL3czfwavfKlyS1o3RMf1+eqrZV5n8M3DpB37/K03Tp1jBR3XrlOMBjmal65Vh65ThgGo5lTrPZnOrPkCTNEL6GQZIC6aXQ/3mvipgtdgGngf+qu5AuWAQ8Trqncxj4k3rLmbSLSQ8gfIt0HH9Zbzld0Qc8BfxL3YV06AXgGeAQ6R7jbParpC+3fht4FvjdqfiQXhne6QP+G1hNemx0iPS6iCN1FjVJNwE/BD4N/FbNtXTqLXl6Evhl4CBwC7Pv32UOcAnp3+Ui4GukE9iBOovq0D2kRwN/BXhvzbV04gXScfTCc/qPAP8JfIr0pOQvAd/v9of0ypV+yasiZov/ID0B1Qu+Swp8gP8lXb3Mxm9kN0mBDyn0L8rrZquFwHtI4aKZ4TLSBd/Defk8UxD40Duh7+seZr6rgOWkt7DORn2kIYTTwKPM3uMA+FvgQ8BrdRfSBU3g30i/RW6uuZZOLAZeAv6eNOz2KdJvl13XK6Gvme1S4AvAnwIv11zLZL0KLCNdJa9k9g69vZd04jpYdyFd8nvAb5PeAnw36Wp5NppLOo6/I10c/R9TdG+yV0K/9HUPmn4XkQL/H4Ev1lxLN3yfdHN6Xd2FTNKNwHrSWPhu4F3AP9RZUIcu/D8/DXyJ2fsW35E8XfgNci/pJNB1vRL6Ja+K0PSbQxqjfBb4eM21dOIK0pMVAL9IemDg2/WV05F7SRdFV5H+nzwG3F5nQR24hPSAwIX5Nczep97+hzREfXVevpkpeuBhprxls1MTvSpiNvon4J2kNweOAPfx05s7s82NwB/w00fqAP6cN367e6Z7C+nJij7ShdIeZv+jjr3g10hX95Cy7LPAv9ZXTsf+mPQb8TzSQyl/OBUf0iuPbEqSCvTK8I4kqYChL0mBGPqSFIihL0mBGPqSFIihL0mBGPqSFIihL0mB/D93gZNSAEZZ5QAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sns.distplot(np.argmax(ytrain, 1), bins=7, norm_hist=True, kde=False)\n",
"sns.distplot(np.argmax(fer[\"ytest\"], 1), bins=7, norm_hist=True, kde=False)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"sess = tf.Session()\n",
"init = tf.global_variables_initializer()\n",
"sess.run(init)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"saver = tf.train.Saver(max_to_keep=10)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"005th: loss 11.98410, train_acc 0.522, test_acc 0.541\n",
"010th: loss 11.60389, train_acc 0.588, test_acc 0.559\n",
"015th: loss 11.37983, train_acc 0.625, test_acc 0.580\n",
"020th: loss 11.23594, train_acc 0.649, test_acc 0.586\n",
"025th: loss 11.12284, train_acc 0.668, test_acc 0.576\n",
"030th: loss 11.04105, train_acc 0.681, test_acc 0.584\n",
"035th: loss 10.95620, train_acc 0.696, test_acc 0.588\n",
"040th: loss 10.89005, train_acc 0.706, test_acc 0.594\n",
"045th: loss 10.82822, train_acc 0.717, test_acc 0.592\n",
"050th: loss 10.78203, train_acc 0.723, test_acc 0.590\n",
"055th: loss 10.74615, train_acc 0.728, test_acc 0.592\n",
"060th: loss 10.70861, train_acc 0.736, test_acc 0.586\n",
"065th: loss 10.66698, train_acc 0.742, test_acc 0.580\n",
"070th: loss 10.63692, train_acc 0.746, test_acc 0.594\n",
"075th: loss 10.62256, train_acc 0.748, test_acc 0.592\n",
"080th: loss 10.59369, train_acc 0.752, test_acc 0.594\n",
"085th: loss 10.54309, train_acc 0.760, test_acc 0.605\n",
"090th: loss 10.52803, train_acc 0.764, test_acc 0.605\n",
"095th: loss 10.51590, train_acc 0.766, test_acc 0.604\n",
"100th: loss 10.50317, train_acc 0.767, test_acc 0.609\n",
"105th: loss 10.47712, train_acc 0.772, test_acc 0.607\n",
"110th: loss 10.44771, train_acc 0.777, test_acc 0.604\n",
"115th: loss 10.43361, train_acc 0.779, test_acc 0.605\n",
"120th: loss 10.43110, train_acc 0.779, test_acc 0.615\n",
"125th: loss 10.40474, train_acc 0.784, test_acc 0.611\n",
"130th: loss 10.37671, train_acc 0.788, test_acc 0.611\n",
"135th: loss 10.37927, train_acc 0.787, test_acc 0.619\n",
"140th: loss 10.35974, train_acc 0.789, test_acc 0.619\n",
"145th: loss 10.33949, train_acc 0.793, test_acc 0.611\n",
"150th: loss 10.33469, train_acc 0.794, test_acc 0.619\n",
"155th: loss 10.32647, train_acc 0.796, test_acc 0.615\n",
"160th: loss 10.31313, train_acc 0.798, test_acc 0.623\n",
"165th: loss 10.30494, train_acc 0.800, test_acc 0.623\n",
"170th: loss 10.28598, train_acc 0.802, test_acc 0.625\n",
"175th: loss 10.27716, train_acc 0.804, test_acc 0.615\n",
"180th: loss 10.27948, train_acc 0.805, test_acc 0.623\n",
"185th: loss 10.25671, train_acc 0.807, test_acc 0.629\n",
"190th: loss 10.27322, train_acc 0.803, test_acc 0.625\n",
"195th: loss 10.23416, train_acc 0.810, test_acc 0.625\n",
"200th: loss 10.24615, train_acc 0.809, test_acc 0.623\n",
"DONE\n"
]
}
],
"source": [
"n_epochs = 200\n",
"batch_size = 64\n",
"display_step = 5\n",
"save_step = 5\n",
"\n",
"n_batches = int(xtrain.shape[0] / batch_size)\n",
"losses, train_accs, test_accs = [], [], []\n",
"\n",
"ylabels = np.argmax(ytrain, 1)\n",
"yweights = class_weights[ylabels]\n",
"ytestweights = class_weights[np.argmax(ytest[:512], 1)]\n",
"\n",
"for epoch in range(n_epochs):\n",
" avg_loss, avg_train_acc = 0., 0.\n",
" for i in range(n_batches):\n",
" batch_xs = xtrain[(i * batch_size):((i+1) * batch_size)]\n",
" batch_ys = ytrain[(i * batch_size):((i+1) * batch_size)]\n",
" batch_weights = yweights[(i * batch_size):((i+1) * batch_size)]\n",
" batch_xs = whiten(augment_data(batch_xs))\n",
" \n",
" if (epoch+1) % display_step != 0:\n",
" sess.run(optimizer, feed_dict={X: batch_xs, y: batch_ys, \n",
" is_training: True,\n",
" weights: batch_weights})\n",
" else:\n",
" _, curr_loss, curr_train_acc = sess.run([optimizer, loss, acc], feed_dict={X: batch_xs, y: batch_ys, \n",
" is_training: True,\n",
" weights: batch_weights})\n",
" avg_loss += curr_loss / n_batches\n",
" avg_train_acc += curr_train_acc / n_batches\n",
" \n",
" if (epoch+1) % display_step == 0:\n",
" test_acc = sess.run(acc, feed_dict={X: whiten(xtest[:512]), y: ytest[:512], \n",
" is_training: False,\n",
" weights: ytestweights})\n",
" losses.append(avg_loss)\n",
" train_accs.append(avg_train_acc)\n",
" test_accs.append(test_acc)\n",
" \n",
" if (epoch+1) % display_step == 0:\n",
" print \"{:03d}th: loss {:0.5f}, train_acc {:0.3f}, test_acc {:0.3f}\".format(epoch+1, avg_loss, avg_train_acc, test_acc)\n",
" \n",
" if (epoch+1) % save_step == 0:\n",
" saver.save(sess, \"fe_challenge/fe_cnn-NAG-whiten\", global_step=epoch+1)\n",
" \n",
"print \"DONE\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fc17c67a3d0>]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(losses)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fc188343390>]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(train_accs)\n",
"plt.plot(test_accs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## EVAL"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<tf.Variable 'CONV1/weights:0' shape=(3, 3, 1, 128) dtype=float32_ref>,\n",
" <tf.Variable 'CONV1/BatchNorm/beta:0' shape=(128,) dtype=float32_ref>,\n",
" <tf.Variable 'CONV2/weights:0' shape=(3, 3, 128, 128) dtype=float32_ref>,\n",
" <tf.Variable 'CONV2/BatchNorm/beta:0' shape=(128,) dtype=float32_ref>,\n",
" <tf.Variable 'CONV3/weights:0' shape=(3, 3, 1, 128) dtype=float32_ref>,\n",
" <tf.Variable 'CONV3/BatchNorm/beta:0' shape=(128,) dtype=float32_ref>,\n",
" <tf.Variable 'CONV4/weights:0' shape=(3, 3, 128, 128) dtype=float32_ref>,\n",
" <tf.Variable 'CONV4/BatchNorm/beta:0' shape=(128,) dtype=float32_ref>,\n",
" <tf.Variable 'CONV5/weights:0' shape=(3, 3, 128, 128) dtype=float32_ref>,\n",
" <tf.Variable 'CONV5/BatchNorm/beta:0' shape=(128,) dtype=float32_ref>,\n",
" <tf.Variable 'CONV6/weights:0' shape=(3, 3, 128, 128) dtype=float32_ref>,\n",
" <tf.Variable 'CONV6/BatchNorm/beta:0' shape=(128,) dtype=float32_ref>,\n",
" <tf.Variable 'DENSE1/weights:0' shape=(18432, 1024) dtype=float32_ref>,\n",
" <tf.Variable 'DENSE1/BatchNorm/beta:0' shape=(1024,) dtype=float32_ref>,\n",
" <tf.Variable 'DENSE2/weights:0' shape=(1024, 1024) dtype=float32_ref>,\n",
" <tf.Variable 'DENSE2/BatchNorm/beta:0' shape=(1024,) dtype=float32_ref>,\n",
" <tf.Variable 'LOGIT/weights:0' shape=(1024, 7) dtype=float32_ref>,\n",
" <tf.Variable 'LOGIT/biases:0' shape=(7,) dtype=float32_ref>]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainable_vars = tf.trainable_variables()\n",
"trainable_vars"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"weights = sess.run(trainable_vars[0])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x648 with 15 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Filters are in shape: (3, 3, 128)\n"
]
}
],
"source": [
"weights = weights.reshape(3, 3, -1)\n",
"\n",
"plt.figure(figsize=(12, 9))\n",
"for i in range(15):\n",
" plt.subplot(3, 5, i+1)\n",
" plt.imshow(weights[:, :, i], cmap=mpl.cm.gray)\n",
" plt.title(\"{}\".format(i))\n",
" \n",
"plt.suptitle(\"CONV filters\", fontsize=16)\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print \"Filters are in shape:\", weights.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.15rc1"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment