Skip to content

Instantly share code, notes, and snippets.

@Shirataki2
Last active March 31, 2018 14:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Shirataki2/50aad08a3dc48c3e33fe68a32b682b85 to your computer and use it in GitHub Desktop.
Save Shirataki2/50aad08a3dc48c3e33fe68a32b682b85 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cycle GAN\n",
"最近話題の生成モデルである、GANのうち、写真を絵画調に変換したり、馬をシマウマにしたりとなかなかすごいことをやってのけるとして話題になったモデルです。\n",
"https://arxiv.org/abs/1703.10593"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import keras\n",
"import keras.backend as K\n",
"import keras.layers as L\n",
"from keras.models import Sequential,Model\n",
"from tqdm import tnrange\n",
"import tensorflow as tf\n",
"import tensorboard as tb\n",
"import cv2\n",
"import os"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instance Normalization\n",
"平均0、分散1にさせて、学習を安定化させる。\n",
"ほぼ写経。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from keras.engine.topology import Layer\n",
"\n",
"class InstanceNormalization2D(Layer):\n",
" def __init__(self,\n",
" beta_initializer='zeros',\n",
" gamma_initializer='ones',\n",
" epsilon=1e-3,\n",
" **kwargs):\n",
" super(InstanceNormalization2D, self).__init__(**kwargs)\n",
" if K.image_data_format() is 'channels_first':\n",
" self.axis = 1\n",
" else:\n",
" self.axis = 3\n",
" self.epsilon = epsilon\n",
" self.beta_initializer = beta_initializer\n",
" self.gamma_initializer = gamma_initializer\n",
"\n",
" def build(self, input_shape):\n",
" self.gamma = self.add_weight(shape=(input_shape[self.axis],),\n",
" initializer=self.gamma_initializer,\n",
" trainable=True,\n",
" name='gamma')\n",
" self.beta = self.add_weight(shape=(input_shape[self.axis],),\n",
" initializer=self.beta_initializer,\n",
" trainable=True,\n",
" name='beta')\n",
" super(InstanceNormalization2D, self).build(input_shape)\n",
"\n",
" def call(self, x):\n",
" if K.image_data_format() is 'channels_first':\n",
" x_w, x_h = (2, 3)\n",
" else:\n",
" x_w, x_h = (1, 2)\n",
"\n",
" hw = K.cast(K.shape(x)[x_h]* K.shape(x)[x_w], K.floatx())\n",
"\n",
" mu = K.sum(x, axis=x_w)\n",
" mu = K.sum(mu, axis=x_h)\n",
" mu = mu / hw\n",
" mu = K.reshape(mu, (K.shape(mu)[0], K.shape(mu)[1], 1, 1))\n",
"\n",
" sig2 = K.square(x - mu)\n",
" sig2 = K.sum(sig2, axis=x_w)\n",
" sig2 = K.sum(sig2, axis=x_h)\n",
" sig2 = K.reshape(sig2, (K.shape(sig2)[0], K.shape(sig2)[1], 1, 1))\n",
"\n",
" y = (x - mu) / K.sqrt(sig2 + self.epsilon)\n",
"\n",
" if K.image_data_format() is 'channels_first':\n",
" gamma = K.reshape(self.gamma, (1, K.shape(self.gamma)[0], 1, 1))\n",
" beta = K.reshape(self.beta, (1, K.shape(self.beta)[0], 1, 1))\n",
" else:\n",
" gamma = K.reshape(self.gamma, (1, 1, 1, K.shape(self.gamma)[0]))\n",
" beta = K.reshape(self.beta, (1, 1, 1, K.shape(self.beta)[0]))\n",
" return gamma * y + beta\n",
" \n",
"def norm(**kwargs):\n",
" return InstanceNormalization2D()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def load_image_from_dir(dir_name):\n",
" img_list = []\n",
" directory = os.listdir(dir_name)\n",
" for _,file in zip(tnrange(len(directory)),directory):\n",
" img = cv2.imread(dir_name + '/' +file)\n",
" img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)\n",
" img = img.astype(np.float32)/255.0\n",
" img_list.append(img)\n",
" return np.array(img_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"tqdmというパッケージを使うと進捗が可視化されて分かりやすい。\n",
"\n",
"カレントディレクトリ下のhorse2zebra下に解凍された画像を置く。\n",
"馬とシマウマの画像はhttps://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/ より"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7f7a7cb0f5c840b8a815aae7b1c8b244",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, max=1067), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "882069cbbd8d4e9d84c3b5d395c3032a",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, max=1334), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4b466ea3b7034eeab00e05d1e5329aed",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, max=120), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c515de2a39a14f2597c7e3e57e9cbef6",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, max=140), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"img_trainA = load_image_from_dir('./horse2zebra/trainA')\n",
"img_trainB = load_image_from_dir('./horse2zebra/trainB')\n",
"img_testA = load_image_from_dir('./horse2zebra/testA')\n",
"img_testB = load_image_from_dir('./horse2zebra/testB')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"s = tf.InteractiveSession()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Discriminator\n",
"双方の画像群から生成された画像を識別するので2つ生成します。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from keras.initializers import RandomNormal\n",
"gamma_init = RandomNormal(1.0,0.02)\n",
"conv_init = RandomNormal(0.0,0.02)\n",
"\n",
"def c2d(dim,k=3,*args, **kwargs):\n",
" return L.Conv2D(dim,k,kernel_initializer=conv_init,*args,**kwargs)\n",
"def c2d_t(dim,k=3,*args, **kwargs):\n",
" return L.Conv2DTranspose(dim,k,kernel_initializer=conv_init,*args,**kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"IMG_SHAPE = img_trainA.shape[1:]\n",
"\n",
"\n",
"def comp(n,*a,**k):\n",
" return c2d(n,4,strides=2,padding='same',*a,**k)\n",
"\n",
"def discriminator(input_shape, ndf, itr=3, last_layer_activation=None):\n",
" input = L.Input(input_shape)\n",
" \n",
" _ = comp(ndf)(input)\n",
" _ = L.LeakyReLU(alpha=0.2)(_)\n",
" for i in range(itr):\n",
" n = ndf * 2 ** (i+1)\n",
" _ = comp(n,use_bias=False)(_)\n",
" _ = norm()(_)\n",
" _ = L.LeakyReLU(alpha=0.2)(_)\n",
" n = ndf * 2 ** itr\n",
" _ = L.ZeroPadding2D((1,1))(_)\n",
" _ = c2d(n,4,use_bias = False)(_)\n",
" _ = norm()(_)\n",
" _ = L.LeakyReLU(alpha=0.2)(_)\n",
" _ = L.ZeroPadding2D((1,1))(_)\n",
" _ = c2d(1,4,activation=last_layer_activation)(_)\n",
" disc = Model(inputs=[input], outputs=_)\n",
" return disc"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_1 (InputLayer) (None, 256, 256, 3) 0 \n",
"_________________________________________________________________\n",
"conv2d_1 (Conv2D) (None, 128, 128, 32) 1568 \n",
"_________________________________________________________________\n",
"leaky_re_lu_1 (LeakyReLU) (None, 128, 128, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_2 (Conv2D) (None, 64, 64, 64) 32768 \n",
"_________________________________________________________________\n",
"instance_normalization2d_1 ( (None, 64, 64, 64) 128 \n",
"_________________________________________________________________\n",
"leaky_re_lu_2 (LeakyReLU) (None, 64, 64, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_3 (Conv2D) (None, 32, 32, 128) 131072 \n",
"_________________________________________________________________\n",
"instance_normalization2d_2 ( (None, 32, 32, 128) 256 \n",
"_________________________________________________________________\n",
"leaky_re_lu_3 (LeakyReLU) (None, 32, 32, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_4 (Conv2D) (None, 16, 16, 256) 524288 \n",
"_________________________________________________________________\n",
"instance_normalization2d_3 ( (None, 16, 16, 256) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_4 (LeakyReLU) (None, 16, 16, 256) 0 \n",
"_________________________________________________________________\n",
"zero_padding2d_1 (ZeroPaddin (None, 18, 18, 256) 0 \n",
"_________________________________________________________________\n",
"conv2d_5 (Conv2D) (None, 15, 15, 256) 1048576 \n",
"_________________________________________________________________\n",
"instance_normalization2d_4 ( (None, 15, 15, 256) 512 \n",
"_________________________________________________________________\n",
"leaky_re_lu_5 (LeakyReLU) (None, 15, 15, 256) 0 \n",
"_________________________________________________________________\n",
"zero_padding2d_2 (ZeroPaddin (None, 17, 17, 256) 0 \n",
"_________________________________________________________________\n",
"conv2d_6 (Conv2D) (None, 14, 14, 1) 4097 \n",
"=================================================================\n",
"Total params: 1,743,777\n",
"Trainable params: 1,743,777\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"D_a = discriminator(IMG_SHAPE,32)\n",
"D_b = discriminator(IMG_SHAPE,32)\n",
"D_a.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generator\n",
"ある程度畳み込みをかけた後、ResNet(https://arxiv.org/abs/1512.03385 )を通し、逆畳込みで戻していきました。"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def resnet_block(input, dim,*args, **kwargs):\n",
" _ = c2d(dim,padding='same',*args, **kwargs)(input)\n",
" _ = norm()(_)\n",
" _ = L.Activation('relu')(_)\n",
" _ = c2d(dim,padding='same',*args, **kwargs)(_)\n",
" _ = norm()(_)\n",
" res = L.add([input,_])\n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def generator(input_size, ndf, n_resnet, itr=4):\n",
" input = L.Input(input_size)\n",
" \n",
" _ = L.ZeroPadding2D((3,3))(input)\n",
" _ = c2d(ndf,7)(_)\n",
" _ = norm()(_)\n",
" _ = L.Activation('relu')(_)\n",
" for i in range(itr):\n",
" n = ndf * 2 **(i + 1)\n",
" _ = c2d(n,strides=2,padding='same')(_)\n",
" _ = norm()(_)\n",
" _ = L.Activation('relu')(_)\n",
" n = ndf * 2 **(itr)\n",
" for i in range(n_resnet):\n",
" _ = resnet_block(_, n)\n",
" for i in range(itr):\n",
" n = n//2\n",
" _ = c2d_t(n,strides=2,padding='same')(_)\n",
" _ = norm()(_)\n",
" _ = L.Activation('relu')(_)\n",
" _ = L.ZeroPadding2D((3,3))(_)\n",
" _ = c2d(3,7)(_)\n",
" _ = L.Activation('tanh')(_)\n",
" generator = Model(inputs=[input],outputs=_)\n",
" return generator"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"____________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"====================================================================================================\n",
"input_4 (InputLayer) (None, 256, 256, 3) 0 \n",
"____________________________________________________________________________________________________\n",
"zero_padding2d_7 (ZeroPadding2D) (None, 262, 262, 3) 0 input_4[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_37 (Conv2D) (None, 256, 256, 8) 1184 zero_padding2d_7[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_36 (Ins (None, 256, 256, 8) 16 conv2d_37[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_20 (Activation) (None, 256, 256, 8) 0 instance_normalization2d_36[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_38 (Conv2D) (None, 128, 128, 16) 1168 activation_20[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_37 (Ins (None, 128, 128, 16) 32 conv2d_38[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_21 (Activation) (None, 128, 128, 16) 0 instance_normalization2d_37[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_39 (Conv2D) (None, 64, 64, 32) 4640 activation_21[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_38 (Ins (None, 64, 64, 32) 64 conv2d_39[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_22 (Activation) (None, 64, 64, 32) 0 instance_normalization2d_38[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_40 (Conv2D) (None, 32, 32, 64) 18496 activation_22[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_39 (Ins (None, 32, 32, 64) 128 conv2d_40[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_23 (Activation) (None, 32, 32, 64) 0 instance_normalization2d_39[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_41 (Conv2D) (None, 16, 16, 128) 73856 activation_23[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_40 (Ins (None, 16, 16, 128) 256 conv2d_41[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_24 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_40[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_42 (Conv2D) (None, 16, 16, 128) 147584 activation_24[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_41 (Ins (None, 16, 16, 128) 256 conv2d_42[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_25 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_41[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_43 (Conv2D) (None, 16, 16, 128) 147584 activation_25[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_42 (Ins (None, 16, 16, 128) 256 conv2d_43[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_10 (Add) (None, 16, 16, 128) 0 activation_24[0][0] \n",
" instance_normalization2d_42[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_44 (Conv2D) (None, 16, 16, 128) 147584 add_10[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_43 (Ins (None, 16, 16, 128) 256 conv2d_44[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_26 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_43[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_45 (Conv2D) (None, 16, 16, 128) 147584 activation_26[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_44 (Ins (None, 16, 16, 128) 256 conv2d_45[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_11 (Add) (None, 16, 16, 128) 0 add_10[0][0] \n",
" instance_normalization2d_44[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_46 (Conv2D) (None, 16, 16, 128) 147584 add_11[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_45 (Ins (None, 16, 16, 128) 256 conv2d_46[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_27 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_45[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_47 (Conv2D) (None, 16, 16, 128) 147584 activation_27[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_46 (Ins (None, 16, 16, 128) 256 conv2d_47[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_12 (Add) (None, 16, 16, 128) 0 add_11[0][0] \n",
" instance_normalization2d_46[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_48 (Conv2D) (None, 16, 16, 128) 147584 add_12[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_47 (Ins (None, 16, 16, 128) 256 conv2d_48[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_28 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_47[0][0]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"____________________________________________________________________________________________________\n",
"conv2d_49 (Conv2D) (None, 16, 16, 128) 147584 activation_28[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_48 (Ins (None, 16, 16, 128) 256 conv2d_49[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_13 (Add) (None, 16, 16, 128) 0 add_12[0][0] \n",
" instance_normalization2d_48[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_50 (Conv2D) (None, 16, 16, 128) 147584 add_13[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_49 (Ins (None, 16, 16, 128) 256 conv2d_50[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_29 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_49[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_51 (Conv2D) (None, 16, 16, 128) 147584 activation_29[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_50 (Ins (None, 16, 16, 128) 256 conv2d_51[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_14 (Add) (None, 16, 16, 128) 0 add_13[0][0] \n",
" instance_normalization2d_50[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_52 (Conv2D) (None, 16, 16, 128) 147584 add_14[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_51 (Ins (None, 16, 16, 128) 256 conv2d_52[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_30 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_51[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_53 (Conv2D) (None, 16, 16, 128) 147584 activation_30[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_52 (Ins (None, 16, 16, 128) 256 conv2d_53[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_15 (Add) (None, 16, 16, 128) 0 add_14[0][0] \n",
" instance_normalization2d_52[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_54 (Conv2D) (None, 16, 16, 128) 147584 add_15[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_53 (Ins (None, 16, 16, 128) 256 conv2d_54[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_31 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_53[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_55 (Conv2D) (None, 16, 16, 128) 147584 activation_31[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_54 (Ins (None, 16, 16, 128) 256 conv2d_55[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_16 (Add) (None, 16, 16, 128) 0 add_15[0][0] \n",
" instance_normalization2d_54[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_56 (Conv2D) (None, 16, 16, 128) 147584 add_16[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_55 (Ins (None, 16, 16, 128) 256 conv2d_56[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_32 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_55[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_57 (Conv2D) (None, 16, 16, 128) 147584 activation_32[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_56 (Ins (None, 16, 16, 128) 256 conv2d_57[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_17 (Add) (None, 16, 16, 128) 0 add_16[0][0] \n",
" instance_normalization2d_56[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_58 (Conv2D) (None, 16, 16, 128) 147584 add_17[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_57 (Ins (None, 16, 16, 128) 256 conv2d_58[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_33 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_57[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_59 (Conv2D) (None, 16, 16, 128) 147584 activation_33[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_58 (Ins (None, 16, 16, 128) 256 conv2d_59[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_18 (Add) (None, 16, 16, 128) 0 add_17[0][0] \n",
" instance_normalization2d_58[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_transpose_5 (Conv2DTransp (None, 32, 32, 64) 73792 add_18[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_59 (Ins (None, 32, 32, 64) 128 conv2d_transpose_5[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_34 (Activation) (None, 32, 32, 64) 0 instance_normalization2d_59[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_transpose_6 (Conv2DTransp (None, 64, 64, 32) 18464 activation_34[0][0] \n",
"____________________________________________________________________________________________________\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"instance_normalization2d_60 (Ins (None, 64, 64, 32) 64 conv2d_transpose_6[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_35 (Activation) (None, 64, 64, 32) 0 instance_normalization2d_60[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_transpose_7 (Conv2DTransp (None, 128, 128, 16) 4624 activation_35[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_61 (Ins (None, 128, 128, 16) 32 conv2d_transpose_7[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_36 (Activation) (None, 128, 128, 16) 0 instance_normalization2d_61[0][0]\n",
"____________________________________________________________________________________________________\n",
"conv2d_transpose_8 (Conv2DTransp (None, 256, 256, 8) 1160 activation_36[0][0] \n",
"____________________________________________________________________________________________________\n",
"instance_normalization2d_62 (Ins (None, 256, 256, 8) 16 conv2d_transpose_8[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_37 (Activation) (None, 256, 256, 8) 0 instance_normalization2d_62[0][0]\n",
"____________________________________________________________________________________________________\n",
"zero_padding2d_8 (ZeroPadding2D) (None, 262, 262, 8) 0 activation_37[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_60 (Conv2D) (None, 256, 256, 3) 1179 zero_padding2d_8[0][0] \n",
"____________________________________________________________________________________________________\n",
"activation_38 (Activation) (None, 256, 256, 3) 0 conv2d_60[0][0] \n",
"====================================================================================================\n",
"Total params: 2,860,419\n",
"Trainable params: 2,860,419\n",
"Non-trainable params: 0\n",
"____________________________________________________________________________________________________\n"
]
}
],
"source": [
"G_b = generator(IMG_SHAPE,8,9)\n",
"G_a = generator(IMG_SHAPE,8,9)\n",
"G_a.summary()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import keras.optimizers as op"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"双方の誤差関数を出さなければならなく、どちらのものか判別しやすいよう、ヘルパー関数を定義した。\n",
"それでも混乱する。"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"mse_fn = lambda o, t: K.mean(K.abs(K.square(o - t)))\n",
"r_f_fn = lambda r, f: mse_fn(r,K.ones_like(r)) + mse_fn(f,K.zeros_like(f))\n",
"\n",
"def cycle_val(g1, g2):\n",
"    \"\"\"Chain two generators and expose the tensors of one translation cycle.\n",
"\n",
"    Returns a tuple of g1's first input tensor, g1's first output tensor,\n",
"    and the result of calling g2 on that output.\n",
"    \"\"\"\n",
"    source = g1.inputs[0]\n",
"    translated = g1.outputs[0]\n",
"    return source, translated, g2([translated])\n",
"\n",
"r_a, f_b, c_a =cycle_val(G_b,G_a)\n",
"r_b, f_a, c_b =cycle_val(G_a,G_b)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"通常のGANの誤差に加え、生成器2つを回ってきた $G_{b \\to a}(G_{a \\to b}(a))$ と元の画像 $a$ との誤差を加える。"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"LAMBDA = 10.0\n",
"\n",
"def cycle_loss(D, r, f, cyc):\n",
"    \"\"\"Build the three loss tensors that belong to one image domain.\n",
"\n",
"    D is called on the real tensor `r` and on the generated tensor `f`;\n",
"    `cyc` is compared against `r` element-wise.\n",
"    Returns (discriminator loss, generator adversarial loss, cycle loss).\n",
"    \"\"\"\n",
"    score_real, score_fake = D([r]), D([f])\n",
"    # Discriminator target: real scores toward 1, fake scores toward 0.\n",
"    d_loss = r_f_fn(score_real, score_fake)\n",
"    # Generator target: its fakes should be scored as 1.\n",
"    g_loss = mse_fn(score_fake, K.ones_like(score_fake))\n",
"    # Mean absolute difference between the cycled tensor and the real one.\n",
"    cyc_loss = K.mean(K.abs(cyc - r))\n",
"    return d_loss, g_loss, cyc_loss\n",
"\n",
"loss_DA, loss_GA, loss_CA = cycle_loss(D_a,r_a, f_a, c_a)\n",
"loss_DB, loss_GB, loss_CB = cycle_loss(D_b,r_b, f_b, c_b)\n",
"loss_cyc = loss_CA+loss_CB\n",
"\n",
"loss_G = loss_GA + loss_GB + LAMBDA * loss_cyc\n",
"loss_D = loss_DA + loss_DB"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"train_param_D = D_a.trainable_weights + D_b.trainable_weights\n",
"train_param_G = G_a.trainable_weights + G_b.trainable_weights"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"training_updates = op.Adam(0.0002, beta_1=0.5).get_updates(train_param_D,[],loss_D)\n",
"D_train = K.function([r_a,r_b], [loss_DA / 2, loss_DB / 2],training_updates)\n",
"training_updates = op.Adam(0.0002, beta_1=0.5).get_updates(train_param_G,[],loss_G)\n",
"G_train = K.function([r_a,r_b], [loss_GA , loss_GB, loss_cyc],training_updates)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def sample_images(data,gen,nrow,ncol):\n",
"    \"\"\"Show up to nrow*ncol images in a grid of subplots.\n",
"\n",
"    data: batch of images, or the model inputs when `gen` is given.\n",
"    gen: model whose predict() output is displayed; None displays `data` itself.\n",
"    nrow, ncol: number of rows and columns of the subplot grid.\n",
"    \"\"\"\n",
"    # Identity test for None instead of `!=`, which can be overridden by __ne__.\n",
"    images = data if gen is None else gen.predict(data)\n",
"    # Keep pixel values inside the [0, 1] range before display.\n",
"    images = images.clip(0.0,1.0)\n",
"    # Never index past the batch when it holds fewer images than grid cells.\n",
"    for i in range(min(nrow*ncol, len(images))):\n",
"        plt.subplot(nrow,ncol,i+1)\n",
"        plt.imshow(images[i].reshape(IMG_SHAPE),cmap='gray', interpolation=\"none\")\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"s.run(tf.global_variables_initializer())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 学習\n",
"それなりのGPUを使っても、10時間くらい掛かりそう"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython import display\n",
"\n",
"EPOCHS = 90\n",
"# Number of training steps between two progress reports.\n",
"UI_UPDATE = 50\n",
"\n",
"# inv counts training steps; the errs_* values accumulate losses between reports.\n",
"inv = 0\n",
"errs_DA = errs_DB = 0\n",
"errs_GA = errs_GB = 0\n",
"errs_C = 0\n",
"\n",
"for epoch in range(EPOCHS):\n",
"    # Only as many samples as the smaller of the two datasets are visited per epoch.\n",
"    l = min(img_trainA.shape[0],img_trainB.shape[0])\n",
"    # New random visiting order for every epoch.\n",
"    batch = np.random.permutation(l)\n",
"    for i in range(l):\n",
"        # Follow the shuffled order; slicing with `i` itself would ignore the permutation.\n",
"        idx = batch[i]\n",
"        in_A = img_trainA[idx:idx+1]\n",
"        in_B = img_trainB[idx:idx+1]\n",
"        assert(in_A.shape == in_B.shape)\n",
"        # One discriminator step followed by one generator step on the same pair.\n",
"        err_DA, err_DB = D_train([in_A,in_B])\n",
"        err_GA, err_GB, err_C = G_train([in_A,in_B])\n",
"        errs_DA += err_DA\n",
"        errs_DB += err_DB\n",
"        errs_GA += err_GA\n",
"        errs_GB += err_GB\n",
"        errs_C += err_C\n",
"        # Count the step before the check so every report averages exactly UI_UPDATE steps\n",
"        # (checking first would divide a single step by UI_UPDATE in the first report).\n",
"        inv += 1\n",
"        if inv%UI_UPDATE==0:\n",
"            display.clear_output(wait=True)\n",
"            # Random samples for the preview images.\n",
"            select = np.random.permutation(l)\n",
"            print('Epoch: %d/%d\\n batch: (%d/%d)' %(epoch, EPOCHS, i, l))\n",
"            print('\\t| A\\t| B\\t|')\n",
"            print('Dis\\t|%.4f\\t|%.4f\\t|'% (errs_DA/UI_UPDATE,errs_DB/UI_UPDATE))\n",
"            print('Gen\\t|%.4f\\t|%.4f\\t|'% (errs_GA/UI_UPDATE,errs_GB/UI_UPDATE))\n",
"            print('Cyc\\t|%.4f\\t|-------|\\n'% (LAMBDA * errs_C/UI_UPDATE))\n",
"            print('A')\n",
"            sample_images(img_trainA[select[:3]],None,1,3)\n",
"            print('A -> B')\n",
"            sample_images(img_trainA[select[:3]],G_b,1,3)\n",
"            print('B')\n",
"            sample_images(img_trainB[select[:3]],None,1,3)\n",
"            print('B -> A')\n",
"            sample_images(img_trainB[select[:3]],G_a,1,3)\n",
"            # Start a fresh accumulation window for the next report.\n",
"            errs_DA = errs_DB = 0\n",
"            errs_GA = errs_GB = 0\n",
"            errs_C = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment