@kylemcdonald
Created September 4, 2020 19:56
Multitask Learning in Keras with an augmented Fashion MNIST.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multitask Learning in Keras\n",
"\n",
"[Multitask learning](https://en.wikipedia.org/wiki/Multi-task_learning) (MTL) involves solving multiple tasks simultaneously, using some shared parameters.\n",
"\n",
"MTL is closely related to the problem of missing labels: one way to view MTL is that some samples have one set of labels, while other samples are missing those labels."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>img{image-rendering: pixelated}</style>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"# helper function for displaying images\n",
"import PIL\n",
"from IPython.display import display, display_html, Image, HTML\n",
"from io import BytesIO\n",
"def imshow(img, size=112):\n",
" data = BytesIO()\n",
" PIL.Image.fromarray(img).save(data, 'png')\n",
" display(Image(data=data.getvalue(), width=size, height=size))\n",
"\n",
"# render images pixelated\n",
"display_html(HTML('<style>img{image-rendering: pixelated}</style>'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example we prepare an augmented version of [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) by randomly flipping the images left-right and/or upside-down. This gives us a dataset of images `(70000,28,28)` and two sets of labels: the one-hot class `(70000,10)` and the multilabel flip status along each axis `(70000,2)`.\n",
"\n",
"We set some of the flip labels to -1 to mask them. This is what we might see if we had missing labels in a multilabel classification problem.\n",
"\n",
"Then we \"split\" the dataset into two separate datasets by completely masking every other sample. To mask the multiclass output, we set all the values to 0. To mask the multilabel output, we set all the values to -1."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def build_multitask_mnist():\n",
" (xt,yt),(xv,yv) = tf.keras.datasets.fashion_mnist.load_data()\n",
" \n",
" # lump everything together for processing\n",
" x = np.vstack((xt,xv))\n",
" y = np.hstack((yt,yv))\n",
" \n",
" # create labels and augment accordingly\n",
" n = len(y)\n",
" y = tf.keras.utils.to_categorical(y)\n",
" y1 = y.copy()\n",
" y2 = (np.random.random((n, 2)) > 0.5).astype(np.float32)\n",
" for i,(lr,ud) in enumerate(y2):\n",
" if lr: # flip left-right\n",
" x[i] = x[i,:,::-1]\n",
" if ud: # flip up-down\n",
" x[i] = x[i,::-1]\n",
" y2[np.random.random((n,2)) > 0.5] = -1 # mask random y2 labels\n",
" mask = (np.arange(n) % 2) == 0 # create mask for every other sample\n",
" y1 *= mask.reshape(-1,1) # apply mask to multiclass by zeroing\n",
" y2[mask == 1] = -1 # apply mask to multilabel by setting to -1\n",
" \n",
" # train-valid split\n",
" valid = 10000\n",
" xt,xv = x[valid:],x[:valid]\n",
" y1t,y1v = y1[valid:],y1[:valid]\n",
" y2t,y2v = y2[valid:],y2[:valid]\n",
" \n",
" return xt,y1t,y2t, xv,y1v,y2v\n",
"\n",
"xt,y1t,y2t, xv,y1v,y2v = build_multitask_mnist()"
]
},
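{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a sketch added here, not run in the original notebook), we can confirm the masking pattern on the arrays returned above: the class labels are zeroed for every other sample, while the flip labels are set to -1 for the complementary samples plus a random subset of individual labels.\n",
"\n",
"```python\n",
"class_masked = (y1t.max(axis=1) == 0)  # rows with a fully zeroed class label\n",
"flip_masked = (y2t == -1).any(axis=1)  # rows with at least one masked flip label\n",
"print(class_masked.mean(), flip_masked.mean())\n",
"```"
]
},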
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## View the data\n",
"\n",
"Here are a few samples from the augmented data."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAACGklEQVR4nM3S7UsTAQDH8d/dbvO25e6mS21ua6JT0m0Jy0hxaEpRyxchRLmoYL0xUkhKYmUvUsog8o1IEEFlZVQERVFCTBe9CGerjB7owblo1VI2557avN1dr3vXq+jzL3y/wH/H2CDp6QU2dR3aB9RX0OXqPBY2VL5sBYUX2aTAbnQk1etv3jW7Aya6JHT0wOyChm/1EHi05kNEVasjSHHSKI3sfy642tonGWXQZ2+mIEA5q1eaFrl0WV7Q38JY2I/4USSe67ODAr2crru9dpdx4rphD2uwWbVFg9fWEZ4tJ1IAskOqK6faXG/OgkGB7+JoZwfOZBeozHQ/KLhHsl7yweaS+Pzu96lpJsEeQ6rbnwuHBwCAXl1WXf7wu1l81nTj3k7LEbGjnwX40wAFaAMghXdTrlV6jyJTu8Lj2NYFRgwJAIVux2MBkXD19oLmWHRJndyxrLv6qkmlKM1PEBixS2U5KsOJ6nhCkqMEkcnI+NzcSm1gK4U4H67kOJk8GY1INTTSS4mswOfLqYwcwHkfQCpkfVUnxeGkHpaBzku935xQFesBXBgDYAXuWN0kAdq5VyfxzdQD0IMEZ1dibGbooNcwSIjew6HREO9XATWvj4PC51tq27jfafo0DN6mCLR/nUMMpah4woFAT2GL9ufbL8WamHyxMeirq5I9jc43bKAvT9wngML0rz/qNxqjNebxwNQ/nu5v/AbjrcvMeRweBwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 0 0 0 0 0 1 0] [-1 -1]\n",
"class: Bag\n",
"flip lr: unknown\n",
"flip ud: unknown\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABP0lEQVR4nOWRMUvDQACFXy69JiQclpijXUqKVIpS41IspZNQcBFxcRcnV3+Kq46OLm6C4OAkSqEIWaRSkZZgqRwNoeVqyOmgCLoLgt/2HnzLe8D/QPuR856QAwDANjshXzUBAHf5bmsTANCsFb+ZBueVm1b9eXr0Qqp16d1qn5ICY+7G+PjQ6ViFai9ALmbdDIiCAuCXXC7bb+cHfvbyKW9ySpSTgQIIL7NCL41Yp1idEwssDP0EzEIGLisR8TiiCVkZ75xGTZVeWF5rBJtA83mMiC5VzgK4+3zoWSJ42A3uDSjKtb2spPoYq56Ues28bre7OqktprZD5q+0xtrMQK6z3phg0B/2FyomsUU8ceLQ7mlA1SnbGMVUzqx81ojCWFCapDq1xMcIlOQcm0rdEDE11atKkE7TdPZLl/053gGO0m6gw/AjmQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 0 0 0 0 0 0 0] [0 0]\n",
"class: unknown\n",
"flip lr: no\n",
"flip ud: no\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAACUklEQVR4nGNkgAJOMc6/8q9k/v277TiPgfE/AwMDAwMLVE467B+Tz46vDIGPX2tzT4bIwSX513Awn/Na9/WbyYlJrtxfGVBAxkJFPT5LdS11P4HCmwwMjCg69S68NRK8zf7r2R+l2ysFPkDMZYJKTlzHFhXwzIL/tOCf2wfZoDohJANj2OV/n1jZhE5mfN5voDzrJ8S5UGMVjPbLMEg7yjDNUBWoPPILahwjAwMDA+N/VvsrnNzXve+x8J9R/cjyDNm1jAwMRuqz9OTFpJn/14jJr5SG2gZz0Necj+9ffWT6q8LFrvrsLwNqIHBO4v4q1vZ95u1moa9v3zL8Z2D8D9X5n4GB5cUjBkGx929ORv6/epQJIgYHisXhYozdUgwMymLZJzlR/ckgm/p2Ctc3I3f7eIZ/rI9RHctgIMBgai7NsLCBwUvZVhrZnwwMPq8/cLKI3L7G9Vfito6McxMkhJgYGBh8TJltmb79eH/wbn0qg+XxZ99P8drIwiQN3kmJvxH99EiOect57jPBgp9dpZ9ExUGNZTR+xKLoybD4SdCzY/x/Ipez3ws/YrH3OzQQ6n5POabw4j2TkPKTJ5ZhHHMLN5rt/gk19v+OY3f8LiqLMMyT8Pka8YT59y11uZ8Iz/BOa+ATYpRiaNFhYOAWMo3L2oUUtuorr4uqa2+3fvKqk2H3ud8snHlIsXLmhyu3rc8VKfWQv49KtMVeLL2BHGUnV7F/LbRW8rtVZbiJTebYC5QQYmBgYHBzaA9g+PznzAu4CBNC8tyJz/fMHjx4gdAAAPoo0DPV04jbAAAAAElFTkSuQmCC\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 0 0 0 1 0 0 0] [-1 -1]\n",
"class: Shirt\n",
"flip lr: unknown\n",
"flip ud: unknown\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABy0lEQVR4nGNgGDaAkcHlwS/JL8xvPzFxStz6xvPjD0Sc7RcDAwOjbdev22Zv+FUuvbY9+VeJ7wWbzD4fllnGn/6IvCtg8nvfLslyt+w6h8CyQ6uPvfzEcFBy+1r7zx8kuI8yMl02+t/236tnidyrq81ezFe47/KwvLlxnO3dt406jIwtAjkbvx1tmin2Ro5dziTqQ7Hc2z9/2tO+nLM3M2L6ppqp9+HGe/azAX8kn6QrirA+fvmdY8LRDUmiN7+zPNXwe/qq+shr328CZ+7qOd7Zoc72+r69rvBr9p9/WEzkJqkb84kzPNqkdYjT64880wsdbiPJl59V7kiJML1e4M32SPjzru9hv/zNFna8/Xb14X+mNzvsGO5vfcX0mF/a+u8L9ei4Bw56Imka776w/vl2c27c0ocK3tZM1k6nf8l/ZRDc9ufN7wdfbR/LRfCrcvPZ2/3kO3OLacu115fZ9tx/Lin++6FG3TNFnkdvH320WMqjyHqCl3Gexguzv3tlTH8/eSH35Z8Q228O5vMiSgdN2M6xtbNsZZ1rqcJ8ZKvuKUa924KM4l+ZmLj3s/9/9lzCRY0RLZZM2d/zM7J9+ff1u9KX0/ROIpQBANVoxvFUYMU1AAAAAElFTkSuQmCC\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 0 0 0 0 0 0 0] [1 0]\n",
"class: unknown\n",
"flip lr: yes\n",
"flip ud: no\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABlUlEQVR4nNWRy0obYRiG3+/P/DNJJnGSSTJpo5IEsvAEiggFBaXgwkJ3tl5AoLRi6QW48xIsDboQdOVGXLgSFHTnwhYKiniiHpq0pok1TOJhksyhCxcG6gW0z/ZZvc8L/H8QEHK7vQrJBAUgryQEuQgC3IKSuE1fX//qKi35ZF6J8NRhfT9agC54VEkPlQan5coJAK1sf5gPxZtyP9XqjycUvgTQV7uLBbxkV38bObPEg8oBWrLBK0JSa0mIbc893wNccoVMyyDTutVOLUt/R8hYRX1zl02MSFHJdtx3zGQOoxtf9eY1AerAYGvkqGy8kGskc7vusZynk6b80t9D91u0se4rf7InK4llBU54JhdbT6eJgQQOFD69UYeWp8hXUbmXieGd2NuoAbpvwGtA3+ZCvvnVJeN1l2dxPFvLj1BjqNLqWbiX+cmwgnoxsPWeNSbs6B/9eky2GW+2BcXex4N0gIvEl4/fPp+fz40l2o9S23/1n91b2TgEgEim0QgAOLA6vzYMF5Me/67zGUAAE+kR/y/zB+3XhGtlltTtAAAAAElFTkSuQmCC\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 0 0 0 0 1 0 0] [-1 -1]\n",
"class: Sneaker\n",
"flip lr: unknown\n",
"flip ud: unknown\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABXUlEQVR4nN2Rvy8DYRzGn+/de+ldq9ecH01bGj9Kol2ERAwisYjYVNgZRJD4N2xGm3/A5B8wkYhNDARBkIij6a/rVe967/saMIiYLOJZnu1Jns8H+Dehj1ZYXzFtJ+6YsBNPqceM3d1gFwR0TUzGrLovyoYTkWG9mSnF7zoLmpvcJKzHxSvT7rUWHSG7UK0HRQRaRQY88sDyI5Roj4ZFI0a+aUrpebwUbXVrDqV72W3P85lN9Zmk55HzogacQnrzMQlPqYUJy11eXF6H1ExmoAJDmFWVCVdlgcbHCACodXako7Q3lI/5aMKSXDiC6bnk5xUMr1ju8ehgTfSpEkHAVWhZAkkiEAespanz/qgZ5ULnFVUx9tcUkpBCcAJKW4s5We5lDQ1K1mg4BSjyfVUCgL0xnr7yDKXtdHVbpFywLzAvp3f8J704fwKYg7vfWM8dHd5EfjSxcGD9yuQfzRsmY4l/x/DWvwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 0 0 0 0 0 0 0] [ 0 -1]\n",
"class: unknown\n",
"flip lr: no\n",
"flip ud: unknown\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAACI0lEQVR4nFXQy2oTcRQG8G/ut0wmaSdtqr2oqTWo1YWK3RQUfALfwIUv4FLd9x3cuBdcCy4ExYXULlpEUQgEhbRJZqaTuf7nHhe2M+O3PD8+DucAAO48X8J/2d3bAgAWAHbuDd7gahpKci4m5CTB4zgs8UN3gEfPRkxKp1nO6E+DhXejEpuyBm7fxKBjj29FRxs/hAYfA6ABIPMuw8vDoC8K/YuHSxrokMI5CrmEDmjOINx85F4/omgmKjF1RUhOmzgLPjvsSmTTP6ZLnJk99BWL0R2jMV0jcy302RJFPqVzLtYtZnPm98aMNm7lJYYp1iwqMfjCmXctQ1pu0hUqsX3bV/MYkzRWXKBjcygxkb2b2URaZrd0g+MS9tqUrjBy1d0pyDGGgRp4E3T4uELVbUoeRyglkknKoT3u+Sjfl65ijWG4NGydcCsknZFpu2p6q86Ftq01imQligmb9I6FqklF4wOXo12eKdSIVxKT/lY1ffvKgVBErJUbNBvPtQHJKsy9S7/bYSt7qBcbWaJo76HXdjKqw4TFnwfC6IUfp0owrN2J+6c2TwehacSnSk4rSpHU0NjfMPwGOm3TFgSHyLFYw59iKMKYqyNFjqw+QwSnhmi5eoYOFa4HZtdrLTasGn6WFZ+KSBZItkfFTRFeDaevZ0TQtz++/WSniUCGh/+Qwlme7HwvFrcV9F56X82zGXuOX9Rf6eTVjZO71rpmUnMAwF9XAgLPPbIScwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 1 0 0 0 0 0 0 0] [-1 -1]\n",
"class: Pullover\n",
"flip lr: unknown\n",
"flip ud: unknown\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAACOElEQVR4nHWSz0tUURTHv/e+Nz+aeTpjTjZpgmgUVIswiX7hqiSwdtEiEKKSalcLN7VoIRW0if4Co8hlG2lhRCkU1JC0qLDC8QfZ6It52vx+7753z23x3ji66CzuPdwvn/M953AZ/Dh6cV86AsZVwck9/IRNEcawIqJq+e/aWq3k0AVEAQAcABhOFBbW7VAkGjOUWPpzMoB0ABBYLpdszVWSSAPiYTgNkSkrFopFJBF5OonIT3C5qWxeMcE4mMaVZLQCttmzRpwICmCKcbkWeHL/KBJ3JcC5phjzTP/Z9+QQSgMD8ziT0LwKVIMkrJOCYwtGbiytXBPUIAlz1WjKbbJXk83fFjoLxa1khb8aSt36sPf3gnXOdBCIAQmrPTz5OXX83vzEjVxqiwhgsX9n4ZSY7jiS7UJPfec8uD1zJvnC6rLexJ4kzbrok7r3pbdlYKS/29h++PbB99Ak/OUAOnNx+pmJA4u7VKV5xl66AnAKynouQtccZ8gdfTQx9Pbpcpvfoy+y4dnimVI6HtUyfR27j/04W5i7u9HQ1APDye8Ydyv729LL3qXFWZG4OTUYiOPlglfLv7w6YjkrxrvHLWEhlnovBw1BzCfNpLPn47Z2oySb1hVxle+rzzmWKCphZHs6a6vCtcjWhHGnPgq6X9uISR06k2CSC4rn+jaWMJ8Z/OVwd0VUWkOJiKBo/DqYCkhg8pBeLcabW4qy6ibDuftjjU/NgfbRzNfnva2tA9+z0+fx32CN9B9J/f/QbyRHBAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 0 0 0 0 0 0 0] [ 1 -1]\n",
"class: unknown\n",
"flip lr: yes\n",
"flip ud: unknown\n"
]
}
],
"source": [
"classes = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']\n",
"labels = ('flip lr', 'flip ud')\n",
"for x, y1, y2 in zip(xt[:8], y1t, y2t):\n",
" imshow(x)\n",
" print(y1.astype(int), y2.astype(int))\n",
" if y1.max() == 0:\n",
" print('class: unknown')\n",
" else:\n",
" print('class:', classes[np.argmax(y1)])\n",
" for label, value in zip(labels, y2):\n",
" desc = 'no' if value < 0.5 else 'yes'\n",
" if value < 0:\n",
" desc = 'unknown'\n",
" print(f'{label}: {desc}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the network\n",
"\n",
"Next we build the model. `shared` is a simple convolutional network that serves as the body: it extracts the features used by the two \"heads\", `out_class` and `out_flip`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.models import Sequential, Model\n",
"from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout, Input\n",
"\n",
"# build the body\n",
"input_shape = (*xt.shape[1:], 1)\n",
"shared = Sequential([\n",
" Conv2D(32, kernel_size=(3, 3), activation='relu'),\n",
" MaxPooling2D(pool_size=(2, 2)),\n",
" Conv2D(64, kernel_size=(3, 3), activation='relu'),\n",
" MaxPooling2D(pool_size=(2, 2)),\n",
" Flatten(),\n",
" Dropout(0.5),\n",
"])\n",
"\n",
"# build the entire model\n",
"inputs = Input(shape=input_shape)\n",
"x = shared(inputs)\n",
"out_class = Dense(y1t.shape[1], activation='softmax', name='class')(x)\n",
"out_flip = Dense(y2t.shape[1], activation='sigmoid', name='flip')(x)\n",
"\n",
"model = Model(inputs=inputs, outputs={\n",
" 'out_class': out_class,\n",
" 'out_flip': out_flip\n",
"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom loss functions\n",
"\n",
"We need custom loss functions to handle the two independent outputs, as well as the missing data.\n",
"\n",
"For the case of the multiclass output with softmax activation, we multiply the loss by 0 when the entire row for `y_true` is equal to 0.\n",
"\n",
"For the case of the multilabel output with sigmoid activation, we multiply the loss by 0 on a per-label basis when `y_true` is equal to -1."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras import backend as K\n",
"\n",
"def masked_categorical_crossentropy(y_true, y_pred):\n",
" # rows of y_true that are all zeros get a mask of 0 and drop out of the loss\n",
" mask = K.max(y_true, axis=-1)\n",
" loss = K.categorical_crossentropy(y_true, y_pred) * mask\n",
" return K.sum(loss) / K.sum(mask)\n",
"\n",
"def masked_binary_crossentropy(y_true, y_pred):\n",
" # labels equal to -1 get a mask of 0 and drop out of the loss\n",
" mask = K.cast(K.not_equal(y_true, -1), K.floatx())\n",
" loss = K.binary_crossentropy(y_true * mask, y_pred * mask)\n",
" return K.sum(loss) / K.sum(mask)\n",
"\n",
"model.compile(\n",
" loss={\n",
" 'out_class': masked_categorical_crossentropy,\n",
" 'out_flip': masked_binary_crossentropy\n",
" },\n",
" optimizer='adam',\n",
" metrics=['accuracy']\n",
")"
]
},
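{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how the masked losses behave, here is a minimal sketch (added here, not part of the original notebook) using tiny hand-made tensors: a fully zeroed class row contributes nothing to the categorical loss, and a flip label equal to -1 is excluded from the binary loss average.\n",
"\n",
"```python\n",
"# second class row is fully masked (all zeros), so only the first row counts\n",
"print(masked_categorical_crossentropy(\n",
"    tf.constant([[1., 0.], [0., 0.]]),\n",
"    tf.constant([[0.9, 0.1], [0.5, 0.5]])).numpy())\n",
"\n",
"# second flip label is -1, so only the first label counts\n",
"print(masked_binary_crossentropy(\n",
"    tf.constant([[1., -1.]]),\n",
"    tf.constant([[0.9, 0.1]])).numpy())\n",
"```"
]
},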
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model\n",
"\n",
"Finally, we run `model.fit()` for a few epochs. Keras will report the loss for each output separately."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 1.7379 - class_loss: 1.2258 - flip_loss: 0.5121 - class_accuracy: 0.3699 - flip_accuracy: 0.5132 - val_loss: 0.9782 - val_class_loss: 0.6575 - val_flip_loss: 0.3207 - val_class_accuracy: 0.4422 - val_flip_accuracy: 0.5451\n",
"Epoch 2/5\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 1.1119 - class_loss: 0.7560 - flip_loss: 0.3559 - class_accuracy: 0.4160 - flip_accuracy: 0.5201 - val_loss: 0.8651 - val_class_loss: 0.5776 - val_flip_loss: 0.2875 - val_class_accuracy: 0.4471 - val_flip_accuracy: 0.5572\n",
"Epoch 3/5\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 1.0291 - class_loss: 0.6966 - flip_loss: 0.3325 - class_accuracy: 0.4244 - flip_accuracy: 0.5287 - val_loss: 0.7857 - val_class_loss: 0.5177 - val_flip_loss: 0.2680 - val_class_accuracy: 0.4640 - val_flip_accuracy: 0.5498\n",
"Epoch 4/5\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.9780 - class_loss: 0.6581 - flip_loss: 0.3199 - class_accuracy: 0.4310 - flip_accuracy: 0.5327 - val_loss: 0.7884 - val_class_loss: 0.5157 - val_flip_loss: 0.2727 - val_class_accuracy: 0.4633 - val_flip_accuracy: 0.5342\n",
"Epoch 5/5\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.9405 - class_loss: 0.6299 - flip_loss: 0.3106 - class_accuracy: 0.4360 - flip_accuracy: 0.5289 - val_loss: 0.7445 - val_class_loss: 0.4824 - val_flip_loss: 0.2621 - val_class_accuracy: 0.4609 - val_flip_accuracy: 0.5294\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7fe408a00b50>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(\n",
" xt[...,np.newaxis], {\n",
" 'out_class': y1t, \n",
" 'out_flip': y2t\n",
" },\n",
" validation_data=(\n",
" xv[...,np.newaxis], {\n",
" 'out_class': y1v,\n",
" 'out_flip': y2v\n",
" }\n",
" ),\n",
" epochs=5\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## View the predictions\n",
"\n",
"Looking at the output, we can see that both tasks have been learned despite the missing and non-overlapping labels.\n",
"\n",
"\"flip lr\" is predicted essentially at random for left-right symmetric items. But Fashion MNIST originally has all the shoes pointing to the left, so the network correctly predicts the flip when they appear pointing to the right."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAACDElEQVR4nLWSu2tTUQDGv/O4NyePJq2mBNNGMY1ChBaFRmwHQUGqg4uLo4uTuw4O/gcOToIIOhcRcaio4OBi8YEPxLe2Yh9J07S9Te7NzTn3nOPWJuLi4Lf++H7wwQf8pxACAJwAp/oA0oOYBrfGAofFsfjNb0x3Qa4TtAUgd34wODD84w+YmZh88bw8Oq5eTxUKlRuyR3v8wp2KU7VuqQLqbFx6TGwXxOLcA77vEBVwucX8EQJsYY7Zk0OGGcpEJ+jE19FVBEfBaOYbS9/uyDgs0bOSjsfeuX4yZjK+dGSYT4FtT6X8w/2dwqiYLCQh6K6zYNtiPlu6uIA0UWFRtKgJ30PyaKsJb4wYBSOCuuz4wQwQATTmcgJQnPFdBBDa9yMTNdovrx0FTEdGloHgY05l6NPWZOA2pMq3HSM+313+srQGgCDYbLUTfc0i5teb49PluHUD3q/ZRv3KM4KfTDaV2/ha3Cvy1x/dXvWAuEaC8aUJXuCKpKXZPfzqoT8X3qsaP1/wFDqbrZFzfD/RloZMO1On1+ZHFzzBvkfxTdlKpn/V+MGqrbYd7mSrA/HyqnBVFFuJVCrbqGcXeY3PBX426Xp8mVGeIFFSyLCztJYu3XpDMD1WpRI5C6ukUsTqSGrWP1i7OgMCTJZPmIRDoFN0iOQsi7TxVi4/AWjP2wCMuHsWBuqf/vrTf8hvFHHmGSwKvBEAAAAASUVORK5CYII=\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"class: Ankle boot\n",
"flip lr: yes\n",
"flip ud: no\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAB90lEQVR4nF2QvWvTURSGn3PuTUwaSVMtlLZDBT+g0CragnQQwcnFQRAcdHNy0KngXyAIDg5O4qKuDrrp4ChFOthBsNrSxZZ0MMXYD5tffvfe45CkX2d9eF7e9wgAYsD487cL7Xzi5srTJvsnwMUnyz/WGmZmPxfT+qvJPQDVN+d1ezdPvn8nGZTKxc939+CnsY3kg6C5ooDY8PVF8ABTYw3vyiN9mvsohbC1Foj3Zrvm7MNGcvFFfW1kXfPi8UsPGr7qT4EC3ArOSuFlPvXhcnlocPNZ9KV/I+e6sRdW3TGqfNwZn313w3+dCpWYfs0s4YHJ30FdeYOJbPix5DJDfTSl1pXXeOBReTuWW2H65InCUN4q1m4P7PZLcboTOzd0plpZjl9Sis5L1K2litP6+97OgbP3r672NwsOEG31f7sDdAvxZz67ZsWKS4ikUrs01/mrAkgR23LJujmOJk560HJW/voMExCyApuo9SDKbtuCExNRbXuJB2IxUjRTMRVRS5qwfQiMijMRAM3FHTaJFBERJBWCFY6YLRc0mUvmfaR20AQSCKhgYqHcMX0PKlgHqoU+DnzIAAcmOEw0Hiok0BYhSVQwie5IIdRQFaTjH55SJ2RZ3s6ydo65I4VqFT+oWoTgVvtOo2kPirHwvVlAt01CymvzJID/CjnVF438ZwEAAAAASUVORK5CYII=\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"class: T-shirt/top\n",
"flip lr: no\n",
"flip ud: no\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABP0lEQVR4nGXRsU5CMRTG8f85LXAFEiSiAXExvoCJibuDu6NP4urkGzg5OfoEru6uTu7EaDQYApJ7c+k9DnIxaTs1/fU77WlhPRqXxwAMr/bqJXw9uTi9mRTzZnPQmj4sYiwfe+fdfvHxdr21SJL56F5O5sX3bdFLy4rv3e0cfk70YJGimR8vX/wurFL0JrTbhGCaYqYeLDgy2eBmm1NzKk5FfIoUGHh1S01xUHhQKkeKWSeXClDE9WPMZw4VDKwo0jNbVl/NYtSgYBVQ4WJsmomIoxKzLEkiAIKJSIzeqQkY8F+2fo6e8NeKWEiSmQX+6oo24qSUAQUM9/+4dbJTrrurjGaMR+V6aitL8HMlKEiFsh/jsltRigTnrJHH+NoZwSp457f1ucZNT+2zsQ/48PP1lMcoRjYcyux9ClJ/yy+0/mgcKyN6uwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"class: Dress\n",
"flip lr: no\n",
"flip ud: no\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAB3ElEQVR4nG3QQWsTQRQH8H9fXiezm2w327CGpUStsSAePHjyoIh48CB+ASkePfgpxLMg+An8Ah5ExFMpIiiIN4MoUiPGEEKI26Sbze5m+vDQTbtR32WG95v35s2sYBH3t+2d518ub9dfvOnkqZV8fXqzl7oBnx20D3RmPXldxMe1Xp3ZJUKXSeLGw70TvH2vrbmElBKXbU4i0g8AgAAAr7puks1mcPzpbPR7oIJnOEE8spUVkekOM+MyX3z/YWmgOzfG8dwoA6dK/vUWipV4SSCxSrBMDFwFLyF+klTCeTJLub7Xh1nGAQ6RJcqUxTnO5RtGj+OZJiUpUW+RPT5lKDNwJpmn6MdflYBCVULRJJD4H0SUDaxxiDAbGsjisqOmOB00oMaxj8RbA0kRgfW3wUS78lHH4XC5kuRU4+u3vrL3Dxw9v/YJRRTcja4Ak1pnen6kzzW7R33zgbYaIUb8S6hXHePzrbxvjmcmVjMRAaNczrJLy0/x00aMshhYhiLd3yziulta6zAbD3xoK5ZmET2vGnzfSE1VJCvbMXEROVD7gwuJcoyKas5Q1YsYTCWMAJqGalpxR+QX0RGSlu9XhhucaXhxpYi7HalRe4dX363aBE1b+G9scgu0+PE/Mbq6BpavQREAAAAASUVORK5CYII=\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"class: Shirt\n",
"flip lr: yes\n",
"flip ud: no\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABZ0lEQVR4nG3PsUsjQRgF8PfNzmT3iJewRMWo4KGNHqTR+torxNJ/QTu5P+RKC2s7e20sbO+uO4xIUBByhRD1ILcrUZPMzLsmUdnJ13zD+/GGGeB1Ng+/o360+wXhfGXOJ2zzIWdjnOnRXjrtPtoM9+hqNGUUqtH+lvcMp6GgSnm+U8BGEqGkMYVnYbxVuPZyI/e9yp+0Y3pZclNoXuiOZjbkc6TM9EkBW/ZKix+8eFDz9ygcPyz+1+krp9JMrKSzhWb/tucAcRCv2yggfsYvHnCgM78CPFgRByHo5SzAH3fGgw40i60A8VeT9KAaXIfYjhwIehATUFuKJ9+hfj3dR+I9CPgJzbJQSHkfvTUrjDQEgnhCM7XO0g2GHnMhmogG1iRoz4TIuKrmPieKpXqIVbqF4/X60OvFECuWH9AslcWthfg01NknMI7scviVmktVY7/zUfq1oFl+TLqr6d584/zBIZz5IPkPk6WXa2tDaxAAAAAASUVORK5CYII=\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"class: Dress\n",
"flip lr: no\n",
"flip ud: no\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAACL0lEQVR4nE3PvWtTYRQG8Oec933vzU2aJrltaVpUihWsWApF/JgUwUVFBwcHBwUXJ1eH/gGuKjjqJK6ik6BS2sGiUHSwih9VpBD6YbSNSW6Se/O+x8Hc0DOeH8/DOQQAIAGC2yeWdvyp4oP5rd4GBABA6ebJsDW2x6D+Oc52Xt2rgiTFSzeycdz9FfuqEfrKZEYvfO/jwNMkdpCkFvu+z4aSsHoZADQAXPTqqqC0araVYmcjLyqPrad4xisOQwkPWnLCHd2My6U+5oYa7eDZh47JZXITk+OZwKmZTynup9HKNA7k2s1WKwKm5lYeV9B/xR8auD/5sSZkikODcrXAp/lWLxkuPSl5d/KH2Uu61Fitnb/2ovqyn7zeOFc+VmrGVpPyDd5vNsfN0V5y9uHCqbvP6yarYVuJ3bc1c3Z1Pv3zy6Ok8OZKma0IMSytvJ5tHUqxfHAnH43UTCTEVrt8tLGXwhTHvRBV5ZFiYVbWZeu27QEAA8iLbW+ycglEIOyCbUecorJMHYhlZnIEBBtOIa3NgHYC6nZhhbpKxGwTJE2S098CjsWJgJlcBr1hAEr0GlMsikjghAIw1VM01vzU4kgDYgRk8JfFS2tFf81bQeKInHOiUNEiKRJvFBCYbOD5gW+EUWFJr9WOKseb6x0RcQo6n8XakT562muP5XToBKL1b99HHUQ9ZNvF8sQPzSIWyo4sg7qx6mGRBIuL2DVF5qD1H99NL+wWgrwdjv4A+Ad8Xu17Oute7gAAAABJRU5ErkJggg==\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"class: Pullover\n",
"flip lr: yes\n",
"flip ud: no\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABQ0lEQVR4nOWRv0sCYRzGn+/d+554pyZoDiU22OrUf9DU1t6fELQ1NrbYGBRBW0QNUhHRUjRFVBCBQ1EZKJGC2KKed+/53r1vQxBRa0PQZ3x+8AwP8G8gDgAGAAYyQfQz0pwHALBv8pIFYLd6X1+ZA4AvRQKOKltA97WRyhRPLis1gPSnPVWUCzksVjdvrs6vX97W0gBgwQBAyLnisBRvdnKuHfd6bDyzs3cGwFQaNLldWF/GwUTQj6UsGGizZL69sfqxWZ72RvOuftJkxSwt7UjormO5LfF4ekvl2V6MaVNJAR5oYYSGo0PNpaX9fXY8EyZVnJjlJKDCYWR0lDJDf1C7u3ggkJOFHsK20yPpTCJrKC1ajed6AIdTqcCTXshl3xdCDiOhKAIXNoPuDQhjjhmZMtRJLhVMFUEGXiT9X7nt7/AOEpeHEnFSTpAAAAAASUVORK5CYII=\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"class: Sneaker\n",
"flip lr: yes\n",
"flip ud: no\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAACEklEQVR4nFXSy6uOURTH8e9ae+/nvTic407J5T+QlGRgIklJlLGhCRmaGcrQzMTUXKTMDAxQUgYo5SiXjstxfc95z3PZey2D5wi/0W592pPfWsDRha+nCfxNZP+Js9K/j61MzxH/wTXc988ACueX0l4KGmNKVVRY5v2k3QnAJZtfmuf/5G+LPgfKoXrQ7gFYt2nzls1DAEI35QYod0bZOfWo7l48f/7i5Wv//vYqYBMB5TZBuXignoTStY189e0XTqHLa1uIfIJU755PlZpbzLX80usMEvdAAdUiw85KEatCzla1X6qKuz1K7LJRUAklC4NoQYXDPcbSzQxDMSSYq6n9cBuQQKHEPFk3rgARRCXlobSBB6DwTvIw1QNTsWyqpqbVdDj9AIrsbovkpcpX2xGXXU0zeAMofmu8MuoaDbaq3nwjpit98XVls43k5ADumI8rY6bHa00sTdUIXqxISQ5JzYEIj2faFaOYIIigzSiZdE/7n7wPsQTDzczcvPOwLCWvYozfPTohqAYNoaqrerT4bBVvjn1AqlKsBilESU7wUX8mwpO6DS4CIKpC8JaHAELM2xY+ehRxQMQ1L8/lTUIoKJmPx2c71yBgueApTTaehgIC4tw6udgx6sLYplKk3frkIOIggBo3j2xYWhlPF+bW12vGzat96J8yUeBind3d3X9e3tFP/suZPcCOw/SLBfgNLtz6Yy4zlAwAAAAASUVORK5CYII=\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"image/png": {
"height": 112,
"width": 112
}
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"class: Bag\n",
"flip lr: yes\n",
"flip ud: yes\n"
]
}
],
"source": [
"predicted = model.predict(xv[:8][...,np.newaxis])\n",
"y1p,y2p = (predicted[output] for output in ('out_class', 'out_flip'))\n",
"for x, y1, y2 in zip(xv[:8], y1p, y2p):\n",
" imshow(x)\n",
" print('class:', classes[np.argmax(y1)])\n",
" for label, value in zip(labels, y2):\n",
" desc = 'no' if value < 0.5 else 'yes'\n",
" print(f'{label}: {desc}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tf2-gpu",
"language": "python",
"name": "tf2-gpu"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}