"cells": [
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cycle GAN\n",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import keras\n",
"import keras.backend as K\n",
"import keras.layers as L\n",
"from keras.models import Sequential,Model\n",
"from tqdm import tnrange\n",
"import tensorflow as tf\n",
"import tensorboard as tb\n",
"import cv2\n",
"import os"
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instance Normalization\n",
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from keras.engine.topology import Layer\n",
"class InstanceNormalization2D(Layer):\n",
"    \"\"\"Instance normalization for 4-D image tensors.\n",
"\n",
"    Every sample is normalized per channel with the mean and variance taken\n",
"    over its two spatial axes, then scaled by a learned per-channel ``gamma``\n",
"    and shifted by a learned per-channel ``beta``.\n",
"\n",
"    Args:\n",
"        beta_initializer: initializer of the per-channel offset.\n",
"        gamma_initializer: initializer of the per-channel scale.\n",
"        epsilon: constant added to the variance before the square root.\n",
"        **kwargs: forwarded to the base ``Layer`` constructor.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 beta_initializer='zeros',\n",
"                 gamma_initializer='ones',\n",
"                 epsilon=1e-3,\n",
"                 **kwargs):\n",
"        super(InstanceNormalization2D, self).__init__(**kwargs)\n",
"        # Compare by value: ``is`` against a string literal only succeeds\n",
"        # when the interpreter happens to intern both strings.\n",
"        if K.image_data_format() == 'channels_first':\n",
"            self.axis = 1\n",
"        else:\n",
"            self.axis = 3\n",
"        self.epsilon = epsilon\n",
"        self.beta_initializer = beta_initializer\n",
"        self.gamma_initializer = gamma_initializer\n",
"\n",
"    def build(self, input_shape):\n",
"        \"\"\"Create one ``gamma`` and one ``beta`` weight per channel.\"\"\"\n",
"        n_channels = input_shape[self.axis]\n",
"        self.gamma = self.add_weight(shape=(n_channels,),\n",
"                                     initializer=self.gamma_initializer,\n",
"                                     trainable=True,\n",
"                                     name='gamma')\n",
"        self.beta = self.add_weight(shape=(n_channels,),\n",
"                                    initializer=self.beta_initializer,\n",
"                                    trainable=True,\n",
"                                    name='beta')\n",
"        super(InstanceNormalization2D, self).build(input_shape)\n",
"\n",
"    def call(self, x):\n",
"        \"\"\"Normalize ``x`` per sample and channel, then apply gamma and beta.\"\"\"\n",
"        # The two non-batch, non-channel axes are the spatial ones.\n",
"        spatial_axes = [2, 3] if self.axis == 1 else [1, 2]\n",
"        # Reduce both spatial axes in a single call: reducing them one after\n",
"        # the other shifts the axis indices after the first reduction.\n",
"        # keepdims=True leaves size-1 axes so the statistics broadcast\n",
"        # against ``x`` in either data format.\n",
"        mu = K.mean(x, axis=spatial_axes, keepdims=True)\n",
"        sig2 = K.mean(K.square(x - mu), axis=spatial_axes, keepdims=True)\n",
"        y = (x - mu) / K.sqrt(sig2 + self.epsilon)\n",
"        # Put the per-channel parameters on the channel axis, 1 elsewhere.\n",
"        param_shape = [1, 1, 1, 1]\n",
"        param_shape[self.axis] = -1\n",
"        gamma = K.reshape(self.gamma, param_shape)\n",
"        beta = K.reshape(self.beta, param_shape)\n",
"        return gamma * y + beta\n",
" \n",
"def norm(**kwargs):\n",
"    \"\"\"Return a new InstanceNormalization2D layer.\n",
"\n",
"    Keyword arguments are forwarded to the layer constructor instead of\n",
"    being silently discarded.\n",
"    \"\"\"\n",
"    return InstanceNormalization2D(**kwargs)"
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def load_image_from_dir(dir_name):\n",
"    \"\"\"Load every decodable image found in ``dir_name``.\n",
"\n",
"    Files are visited in sorted name order so repeated runs see the same\n",
"    ordering. Each image is converted from BGR to RGB, cast to float32 and\n",
"    divided by 255. Entries OpenCV cannot decode are skipped.\n",
"\n",
"    Args:\n",
"        dir_name: path of the directory holding the image files.\n",
"\n",
"    Returns:\n",
"        ``np.array`` built from the list of converted images.\n",
"    \"\"\"\n",
"    img_list = []\n",
"    # os.listdir returns entries in arbitrary order; sort for reproducibility.\n",
"    directory = sorted(os.listdir(dir_name))\n",
"    for idx in tnrange(len(directory)):\n",
"        img = cv2.imread(os.path.join(dir_name, directory[idx]))\n",
"        if img is None:\n",
"            # cv2.imread signals an unreadable file by returning None.\n",
"            continue\n",
"        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
"        img = img.astype(np.float32) / 255.0\n",
"        img_list.append(img)\n",
"    return np.array(img_list)"
"cell_type": "markdown",
"metadata": {},
"source": [
"馬とシマウマの画像は horse2zebra データセットより取得（TODO: 出典 URL が欠落しているため追記する）。"
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"# Load the train/test splits of both image domains (A and B).\n",
"img_trainA, img_trainB, img_testA, img_testB = (\n",
"    load_image_from_dir(split_dir)\n",
"    for split_dir in ('./horse2zebra/trainA',\n",
"                      './horse2zebra/trainB',\n",
"                      './horse2zebra/testA',\n",
"                      './horse2zebra/testB')\n",
")"
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Open a TF1 interactive session for the rest of the notebook.\n",
"s = tf.InteractiveSession()"
"cell_type": "markdown",
"metadata": {},
"source": [
"### Discriminator\n",
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from keras.initializers import RandomNormal\n",
"# Shared initializers: normal(mean 1.0, stddev 0.02) and normal(mean 0.0, stddev 0.02).\n",
"gamma_init = RandomNormal(1.0, 0.02)\n",
"conv_init = RandomNormal(0.0, 0.02)\n",
"\n",
"\n",
"def c2d(dim, k=3, *args, **kwargs):\n",
"    \"\"\"Build a Conv2D layer with ``dim`` filters and kernel size ``k``.\n",
"\n",
"    ``kernel_initializer`` defaults to ``conv_init`` but may be overridden by\n",
"    the caller; all other arguments are forwarded to ``L.Conv2D``.\n",
"    \"\"\"\n",
"    # setdefault (rather than a hard-coded keyword) lets callers pass their\n",
"    # own kernel_initializer without a duplicate-keyword TypeError.\n",
"    kwargs.setdefault('kernel_initializer', conv_init)\n",
"    return L.Conv2D(dim, k, *args, **kwargs)\n",
"\n",
"\n",
"def c2d_t(dim, k=3, *args, **kwargs):\n",
"    \"\"\"Build a Conv2DTranspose layer with ``dim`` filters and kernel size ``k``.\n",
"\n",
"    ``kernel_initializer`` defaults to ``conv_init`` but may be overridden by\n",
"    the caller; all other arguments are forwarded to ``L.Conv2DTranspose``.\n",
"    \"\"\"\n",
"    kwargs.setdefault('kernel_initializer', conv_init)\n",
"    return L.Conv2DTranspose(dim, k, *args, **kwargs)"
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Shape of one training image: every axis after the leading batch axis.\n",
"IMG_SHAPE = img_trainA.shape[1:]\n",
"\n",
"\n",
"def comp(n, *a, **k):\n",
"    \"\"\"Convolution with ``n`` filters, kernel size 4, stride 2 and 'same' padding.\"\"\"\n",
"    return c2d(n, 4, *a, strides=2, padding='same', **k)\n",
"def discriminator(input_shape, ndf, itr=3, last_layer_activation=None):\n",
"    \"\"\"Assemble the convolutional discriminator as a Keras ``Model``.\n",
"\n",
"    Layout: one stride-2 conv + LeakyReLU; ``itr`` stride-2 conv /\n",
"    instance-norm / LeakyReLU stages whose filter count doubles each time;\n",
"    one zero-padded stride-1 conv stage; a final zero-padded 1-filter conv.\n",
"\n",
"    Args:\n",
"        input_shape: shape handed to ``L.Input``.\n",
"        ndf: filter count of the first convolution.\n",
"        itr: number of filter-doubling stride-2 stages.\n",
"        last_layer_activation: activation of the final 1-filter convolution.\n",
"\n",
"    Returns:\n",
"        The assembled ``Model``.\n",
"    \"\"\"\n",
"    image_in = L.Input(input_shape)\n",
"\n",
"    # Stem: no normalization on the first stage.\n",
"    x = comp(ndf)(image_in)\n",
"    x = L.LeakyReLU(alpha=0.2)(x)\n",
"\n",
"    # Stride-2 stages with ndf*2, ndf*4, ... filters.\n",
"    for stage in range(1, itr + 1):\n",
"        x = comp(ndf * 2 ** stage, use_bias=False)(x)\n",
"        x = norm()(x)\n",
"        x = L.LeakyReLU(alpha=0.2)(x)\n",
"\n",
"    # Stride-1 stage at the widest filter count.\n",
"    widest = ndf * 2 ** itr\n",
"    x = L.ZeroPadding2D((1, 1))(x)\n",
"    x = c2d(widest, 4, use_bias=False)(x)\n",
"    x = norm()(x)\n",
"    x = L.LeakyReLU(alpha=0.2)(x)\n",
"\n",
"    # Single-filter output map.\n",
"    x = L.ZeroPadding2D((1, 1))(x)\n",
"    x = c2d(1, 4, activation=last_layer_activation)(x)\n",
"    return Model(inputs=[image_in], outputs=x)"
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Layer (type) Output Shape Param # \n",
"input_1 (InputLayer) (None, 256, 256, 3) 0 \n",
"conv2d_1 (Conv2D) (None, 128, 128, 32) 1568 \n",
"leaky_re_lu_1 (LeakyReLU) (None, 128, 128, 32) 0 \n",
"conv2d_2 (Conv2D) (None, 64, 64, 64) 32768 \n",
"instance_normalization2d_1 ( (None, 64, 64, 64) 128 \n",
"leaky_re_lu_2 (LeakyReLU) (None, 64, 64, 64) 0 \n",
"conv2d_3 (Conv2D) (None, 32, 32, 128) 131072 \n",
"instance_normalization2d_2 ( (None, 32, 32, 128) 256 \n",
"leaky_re_lu_3 (LeakyReLU) (None, 32, 32, 128) 0 \n",
"conv2d_4 (Conv2D) (None, 16, 16, 256) 524288 \n",
"instance_normalization2d_3 ( (None, 16, 16, 256) 512 \n",
"leaky_re_lu_4 (LeakyReLU) (None, 16, 16, 256) 0 \n",
"zero_padding2d_1 (ZeroPaddin (None, 18, 18, 256) 0 \n",
"conv2d_5 (Conv2D) (None, 15, 15, 256) 1048576 \n",
"instance_normalization2d_4 ( (None, 15, 15, 256) 512 \n",
"leaky_re_lu_5 (LeakyReLU) (None, 15, 15, 256) 0 \n",
"zero_padding2d_2 (ZeroPaddin (None, 17, 17, 256) 0 \n",
"conv2d_6 (Conv2D) (None, 14, 14, 1) 4097 \n",
"Total params: 1,743,777\n",
"Trainable params: 1,743,777\n",
"Non-trainable params: 0\n",
"source": [
"# One discriminator per image domain, built with identical settings.\n",
"D_a, D_b = (discriminator(IMG_SHAPE, 32) for _ in range(2))\n",
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generator\n",
"ある程度畳み込みをかけた後、ResNet ブロック（残差ブロック）を通し、逆畳み込み（転置畳み込み）で元の解像度に戻していきました。"
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def resnet_block(input, dim, *args, **kwargs):\n",
"    \"\"\"Residual block: conv-norm-ReLU-conv-norm, added back onto ``input``.\n",
"\n",
"    Both convolutions use ``dim`` filters and 'same' padding; extra\n",
"    arguments are forwarded to ``c2d``.\n",
"    \"\"\"\n",
"    branch = input\n",
"    # Two conv + norm passes; only the first is followed by a ReLU.\n",
"    for with_relu in (True, False):\n",
"        branch = c2d(dim, padding='same', *args, **kwargs)(branch)\n",
"        branch = norm()(branch)\n",
"        if with_relu:\n",
"            branch = L.Activation('relu')(branch)\n",
"    return L.add([input, branch])"
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def generator(input_size, ndf, n_resnet, itr=4):\n",
"    \"\"\"Assemble the encoder / residual / decoder generator as a Keras ``Model``.\n",
"\n",
"    Layout: zero-padded 7x7 conv stem; ``itr`` stride-2 convolutions that\n",
"    double the filter count; ``n_resnet`` residual blocks; ``itr`` stride-2\n",
"    transposed convolutions that halve it again; zero-padded 7x7 conv to\n",
"    3 channels followed by tanh.\n",
"\n",
"    Args:\n",
"        input_size: shape handed to ``L.Input``.\n",
"        ndf: filter count of the stem convolution.\n",
"        n_resnet: number of residual blocks.\n",
"        itr: number of downsampling (and matching upsampling) stages.\n",
"\n",
"    Returns:\n",
"        The assembled ``Model``.\n",
"    \"\"\"\n",
"    def norm_relu(tensor):\n",
"        # Instance norm followed by ReLU, shared by every conv stage below.\n",
"        tensor = norm()(tensor)\n",
"        return L.Activation('relu')(tensor)\n",
"\n",
"    image_in = L.Input(input_size)\n",
"\n",
"    # Stem\n",
"    x = L.ZeroPadding2D((3, 3))(image_in)\n",
"    x = norm_relu(c2d(ndf, 7)(x))\n",
"\n",
"    # Encoder\n",
"    for stage in range(1, itr + 1):\n",
"        x = norm_relu(c2d(ndf * 2 ** stage, strides=2, padding='same')(x))\n",
"\n",
"    # Residual trunk at the widest filter count\n",
"    width = ndf * 2 ** itr\n",
"    for _ in range(n_resnet):\n",
"        x = resnet_block(x, width)\n",
"\n",
"    # Decoder\n",
"    for _ in range(itr):\n",
"        width //= 2\n",
"        x = norm_relu(c2d_t(width, strides=2, padding='same')(x))\n",
"\n",
"    # Output head\n",
"    x = L.ZeroPadding2D((3, 3))(x)\n",
"    x = c2d(3, 7)(x)\n",
"    x = L.Activation('tanh')(x)\n",
"    return Model(inputs=[image_in], outputs=x)"
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Layer (type) Output Shape Param # Connected to \n",
"input_4 (InputLayer) (None, 256, 256, 3) 0 \n",
"zero_padding2d_7 (ZeroPadding2D) (None, 262, 262, 3) 0 input_4[0][0] \n",
"conv2d_37 (Conv2D) (None, 256, 256, 8) 1184 zero_padding2d_7[0][0] \n",
"instance_normalization2d_36 (Ins (None, 256, 256, 8) 16 conv2d_37[0][0] \n",
"activation_20 (Activation) (None, 256, 256, 8) 0 instance_normalization2d_36[0][0]\n",
"conv2d_38 (Conv2D) (None, 128, 128, 16) 1168 activation_20[0][0] \n",
"instance_normalization2d_37 (Ins (None, 128, 128, 16) 32 conv2d_38[0][0] \n",
"activation_21 (Activation) (None, 128, 128, 16) 0 instance_normalization2d_37[0][0]\n",
"conv2d_39 (Conv2D) (None, 64, 64, 32) 4640 activation_21[0][0] \n",
"instance_normalization2d_38 (Ins (None, 64, 64, 32) 64 conv2d_39[0][0] \n",
"activation_22 (Activation) (None, 64, 64, 32) 0 instance_normalization2d_38[0][0]\n",
"conv2d_40 (Conv2D) (None, 32, 32, 64) 18496 activation_22[0][0] \n",
"instance_normalization2d_39 (Ins (None, 32, 32, 64) 128 conv2d_40[0][0] \n",
"activation_23 (Activation) (None, 32, 32, 64) 0 instance_normalization2d_39[0][0]\n",
"conv2d_41 (Conv2D) (None, 16, 16, 128) 73856 activation_23[0][0] \n",
"instance_normalization2d_40 (Ins (None, 16, 16, 128) 256 conv2d_41[0][0] \n",
"activation_24 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_40[0][0]\n",
"conv2d_42 (Conv2D) (None, 16, 16, 128) 147584 activation_24[0][0] \n",
"instance_normalization2d_41 (Ins (None, 16, 16, 128) 256 conv2d_42[0][0] \n",
"activation_25 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_41[0][0]\n",
"conv2d_43 (Conv2D) (None, 16, 16, 128) 147584 activation_25[0][0] \n",
"instance_normalization2d_42 (Ins (None, 16, 16, 128) 256 conv2d_43[0][0] \n",
"add_10 (Add) (None, 16, 16, 128) 0 activation_24[0][0] \n",
" instance_normalization2d_42[0][0]\n",
"conv2d_44 (Conv2D) (None, 16, 16, 128) 147584 add_10[0][0] \n",
"instance_normalization2d_43 (Ins (None, 16, 16, 128) 256 conv2d_44[0][0] \n",
"activation_26 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_43[0][0]\n",
"conv2d_45 (Conv2D) (None, 16, 16, 128) 147584 activation_26[0][0] \n",
"instance_normalization2d_44 (Ins (None, 16, 16, 128) 256 conv2d_45[0][0] \n",
"add_11 (Add) (None, 16, 16, 128) 0 add_10[0][0] \n",
" instance_normalization2d_44[0][0]\n",
"conv2d_46 (Conv2D) (None, 16, 16, 128) 147584 add_11[0][0] \n",
"instance_normalization2d_45 (Ins (None, 16, 16, 128) 256 conv2d_46[0][0] \n",
"activation_27 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_45[0][0]\n",
"conv2d_47 (Conv2D) (None, 16, 16, 128) 147584 activation_27[0][0] \n",
"instance_normalization2d_46 (Ins (None, 16, 16, 128) 256 conv2d_47[0][0] \n",
"add_12 (Add) (None, 16, 16, 128) 0 add_11[0][0] \n",
" instance_normalization2d_46[0][0]\n",
"conv2d_48 (Conv2D) (None, 16, 16, 128) 147584 add_12[0][0] \n",
"instance_normalization2d_47 (Ins (None, 16, 16, 128) 256 conv2d_48[0][0] \n",
"activation_28 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_47[0][0]\n"
"name": "stdout",
"output_type": "stream",
"text": [
"conv2d_49 (Conv2D) (None, 16, 16, 128) 147584 activation_28[0][0] \n",
"instance_normalization2d_48 (Ins (None, 16, 16, 128) 256 conv2d_49[0][0] \n",
"add_13 (Add) (None, 16, 16, 128) 0 add_12[0][0] \n",
" instance_normalization2d_48[0][0]\n",
"conv2d_50 (Conv2D) (None, 16, 16, 128) 147584 add_13[0][0] \n",
"instance_normalization2d_49 (Ins (None, 16, 16, 128) 256 conv2d_50[0][0] \n",
"activation_29 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_49[0][0]\n",
"conv2d_51 (Conv2D) (None, 16, 16, 128) 147584 activation_29[0][0] \n",
"instance_normalization2d_50 (Ins (None, 16, 16, 128) 256 conv2d_51[0][0] \n",
"add_14 (Add) (None, 16, 16, 128) 0 add_13[0][0] \n",
" instance_normalization2d_50[0][0]\n",
"conv2d_52 (Conv2D) (None, 16, 16, 128) 147584 add_14[0][0] \n",
"instance_normalization2d_51 (Ins (None, 16, 16, 128) 256 conv2d_52[0][0] \n",
"activation_30 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_51[0][0]\n",
"conv2d_53 (Conv2D) (None, 16, 16, 128) 147584 activation_30[0][0] \n",
"instance_normalization2d_52 (Ins (None, 16, 16, 128) 256 conv2d_53[0][0] \n",
"add_15 (Add) (None, 16, 16, 128) 0 add_14[0][0] \n",
" instance_normalization2d_52[0][0]\n",
"conv2d_54 (Conv2D) (None, 16, 16, 128) 147584 add_15[0][0] \n",
"instance_normalization2d_53 (Ins (None, 16, 16, 128) 256 conv2d_54[0][0] \n",
"activation_31 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_53[0][0]\n",
"conv2d_55 (Conv2D) (None, 16, 16, 128) 147584 activation_31[0][0] \n",
"instance_normalization2d_54 (Ins (None, 16, 16, 128) 256 conv2d_55[0][0] \n",
"add_16 (Add) (None, 16, 16, 128) 0 add_15[0][0] \n",
" instance_normalization2d_54[0][0]\n",
"conv2d_56 (Conv2D) (None, 16, 16, 128) 147584 add_16[0][0] \n",
"instance_normalization2d_55 (Ins (None, 16, 16, 128) 256 conv2d_56[0][0] \n",
"activation_32 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_55[0][0]\n",
"conv2d_57 (Conv2D) (None, 16, 16, 128) 147584 activation_32[0][0] \n",
"instance_normalization2d_56 (Ins (None, 16, 16, 128) 256 conv2d_57[0][0] \n",
"add_17 (Add) (None, 16, 16, 128) 0 add_16[0][0] \n",
" instance_normalization2d_56[0][0]\n",
"conv2d_58 (Conv2D) (None, 16, 16, 128) 147584 add_17[0][0] \n",
"instance_normalization2d_57 (Ins (None, 16, 16, 128) 256 conv2d_58[0][0] \n",
"activation_33 (Activation) (None, 16, 16, 128) 0 instance_normalization2d_57[0][0]\n",
"conv2d_59 (Conv2D) (None, 16, 16, 128) 147584 activation_33[0][0] \n",
"instance_normalization2d_58 (Ins (None, 16, 16, 128) 256 conv2d_59[0][0] \n",
"add_18 (Add) (None, 16, 16, 128) 0 add_17[0][0] \n",
" instance_normalization2d_58[0][0]\n",
"conv2d_transpose_5 (Conv2DTransp (None, 32, 32, 64) 73792 add_18[0][0] \n",
"instance_normalization2d_59 (Ins (None, 32, 32, 64) 128 conv2d_transpose_5[0][0] \n",
"activation_34 (Activation) (None, 32, 32, 64) 0 instance_normalization2d_59[0][0]\n",
"conv2d_transpose_6 (Conv2DTransp (None, 64, 64, 32) 18464 activation_34[0][0] \n",
"name": "stdout",
"output_type": "stream",
"text": [
"instance_normalization2d_60 (Ins (None, 64, 64, 32) 64 conv2d_transpose_6[0][0] \n",
"activation_35 (Activation) (None, 64, 64, 32) 0 instance_normalization2d_60[0][0]\n",
"conv2d_transpose_7 (Conv2DTransp (None, 128, 128, 16) 4624 activation_35[0][0] \n",
"instance_normalization2d_61 (Ins (None, 128, 128, 16) 32 conv2d_transpose_7[0][0] \n",
"activation_36 (Activation) (None, 128, 128, 16) 0 instance_normalization2d_61[0][0]\n",
"conv2d_transpose_8 (Conv2DTransp (None, 256, 256, 8) 1160 activation_36[0][0] \n",
"instance_normalization2d_62 (Ins (None, 256, 256, 8) 16 conv2d_transpose_8[0][0] \n",
"activation_37 (Activation) (None, 256, 256, 8) 0 instance_normalization2d_62[0][0]\n",
"zero_padding2d_8 (ZeroPadding2D) (None, 262, 262, 8) 0 activation_37[0][0] \n",
"conv2d_60 (Conv2D) (None, 256, 256, 3) 1179 zero_padding2d_8[0][0] \n",
"activation_38 (Activation) (None, 256, 256, 3) 0 conv2d_60[0][0] \n",
"Total params: 2,860,419\n",
"Trainable params: 2,860,419\n",
"Non-trainable params: 0\n",
"source": [
"# One generator per translation direction, built with identical settings.\n",
"G_b, G_a = (generator(IMG_SHAPE, 8, 9) for _ in range(2))\n",
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import keras.optimizers as op"
"cell_type": "markdown",
"metadata": {},
"source": [
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def mse_fn(o, t):\n",
"    \"\"\"Mean squared difference between tensors ``o`` and ``t``.\"\"\"\n",
"    # A square is never negative, so no abs() is needed around it.\n",
"    return K.mean(K.square(o - t))\n",
"\n",
"\n",
"def r_f_fn(r, f):\n",
"    \"\"\"Squared-error loss pulling ``r`` toward 1 and ``f`` toward 0.\"\"\"\n",
"    return mse_fn(r, K.ones_like(r)) + mse_fn(f, K.zeros_like(f))\n",
"def cycle_val(g1, g2):\n",
"    \"\"\"Chain two generator models into one cycle.\n",
"\n",
"    Returns:\n",
"        ``(r_in, f_out, c_out)``: the first input tensor of ``g1``, the first\n",
"        output tensor of ``g1``, and the result of calling ``g2`` on that\n",
"        output.\n",
"    \"\"\"\n",
"    translated = g1.outputs[0]\n",
"    return g1.inputs[0], translated, g2([translated])\n",
"\n",
"\n",
"# Cycle through G_b first, then G_a.\n",
"r_a, f_b, c_a = cycle_val(G_b, G_a)\n",
"# Cycle through G_a first, then G_b.\n",
"r_b, f_a, c_b = cycle_val(G_a, G_b)"
"cell_type": "markdown",
"metadata": {},
"source": [
"通常の GAN の誤差に加え、生成器 2 つを一巡して戻ってきた $G_{b \\to a}(G_{a \\to b}(a))$ と元の画像 $a$ との誤差（cycle consistency loss）を加える。"
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Weight applied to the cycle term inside the generator objective.\n",
"LAMBDA = 10.0\n",
"\n",
"\n",
"def cycle_loss(D, r, f, cyc):\n",
"    \"\"\"Build the three loss tensors tied to one discriminator.\n",
"\n",
"    Args:\n",
"        D: discriminator model, applied to both ``r`` and ``f``.\n",
"        r: real image tensor.\n",
"        f: generated image tensor.\n",
"        cyc: reconstructed image tensor, compared against ``r``.\n",
"\n",
"    Returns:\n",
"        ``(loss_D, loss_G, loss_cyc)``: the discriminator loss, the generator\n",
"        loss (fake scores pulled toward 1), and the mean absolute difference\n",
"        between ``cyc`` and ``r``.\n",
"    \"\"\"\n",
"    score_real = D([r])\n",
"    score_fake = D([f])\n",
"    d_term = r_f_fn(score_real, score_fake)\n",
"    g_term = mse_fn(score_fake, K.ones_like(score_fake))\n",
"    cyc_term = K.mean(K.abs(cyc - r))\n",
"    return d_term, g_term, cyc_term\n",
"\n",
"\n",
"# Per-domain terms.\n",
"loss_DA, loss_GA, loss_CA = cycle_loss(D_a, r_a, f_a, c_a)\n",
"loss_DB, loss_GB, loss_CB = cycle_loss(D_b, r_b, f_b, c_b)\n",
"\n",
"# Combined objectives.\n",
"loss_cyc = loss_CA + loss_CB\n",
"loss_G = loss_GA + loss_GB + LAMBDA * loss_cyc\n",
"loss_D = loss_DA + loss_DB"
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Weights updated by the discriminator step and by the generator step.\n",
"train_param_D = D_a.trainable_weights + D_b.trainable_weights\n",
"train_param_G = G_a.trainable_weights + G_b.trainable_weights"
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Discriminator step: minimizes loss_D over train_param_D and reports the\n",
"# two per-domain losses halved.\n",
"training_updates = op.Adam(0.0002, beta_1=0.5).get_updates(\n",
"    train_param_D, [], loss_D)\n",
"D_train = K.function(\n",
"    [r_a, r_b], [loss_DA / 2, loss_DB / 2], training_updates)\n",
"\n",
"# Generator step: minimizes loss_G over train_param_G and reports the two\n",
"# adversarial terms plus the unweighted cycle term.\n",
"training_updates = op.Adam(0.0002, beta_1=0.5).get_updates(\n",
"    train_param_G, [], loss_G)\n",
"G_train = K.function(\n",
"    [r_a, r_b], [loss_GA, loss_GB, loss_cyc], training_updates)"
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def sample_images(data, gen, nrow, ncol):\n",
"    \"\"\"Draw up to ``nrow * ncol`` images in a grid and show the figure.\n",
"\n",
"    Args:\n",
"        data: batch of images; drawn as-is when ``gen`` is None.\n",
"        gen: optional model; when given, ``gen.predict(data)`` is drawn.\n",
"        nrow: number of grid rows.\n",
"        ncol: number of grid columns.\n",
"    \"\"\"\n",
"    if gen is not None:\n",
"        images = gen.predict(data)\n",
"    else:\n",
"        images = data\n",
"    # Values are clipped to the [0, 1] range before display.\n",
"    images = images.clip(0.0, 1.0)\n",
"    # Never index past the batch when it holds fewer images than grid cells.\n",
"    for i in range(min(len(images), nrow * ncol)):\n",
"        plt.subplot(nrow, ncol, i + 1)\n",
"        plt.imshow(images[i].reshape(IMG_SHAPE), cmap='gray', interpolation=\"none\")\n",
"    # Flush the figure now so back-to-back calls do not draw into the same axes.\n",
"    plt.show()\n",
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"cell_type": "markdown",
"metadata": {},
"source": [
"## 学習\n",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython import display\n",
"EPOCHS = 90\n",
"UI_UPDATE = 50  # refresh the report every UI_UPDATE training steps\n",
"inv = 0  # training steps completed so far, across epochs\n",
"errs_DA = errs_DB = 0\n",
"errs_GA = errs_GB = 0\n",
"errs_C = 0\n",
"for epoch in range(EPOCHS):\n",
"    # One step per image of the smaller domain.\n",
"    l = min(img_trainA.shape[0], img_trainB.shape[0])\n",
"    # Independent random visiting order for each domain, redrawn every epoch.\n",
"    batch_A = np.random.permutation(img_trainA.shape[0])[:l]\n",
"    batch_B = np.random.permutation(img_trainB.shape[0])[:l]\n",
"    for i in range(l):\n",
"        idx_A, idx_B = batch_A[i], batch_B[i]\n",
"        # Slice instead of index so the leading batch axis (size 1) is kept.\n",
"        in_A = img_trainA[idx_A:idx_A + 1]\n",
"        in_B = img_trainB[idx_B:idx_B + 1]\n",
"        assert in_A.shape == in_B.shape\n",
"        err_DA, err_DB = D_train([in_A, in_B])\n",
"        err_GA, err_GB, err_C = G_train([in_A, in_B])\n",
"        errs_DA += err_DA\n",
"        errs_DB += err_DB\n",
"        errs_GA += err_GA\n",
"        errs_GB += err_GB\n",
"        errs_C += err_C\n",
"        # Count the step before testing, so every report averages exactly\n",
"        # UI_UPDATE accumulated steps (no report after a single step).\n",
"        inv += 1\n",
"        if inv % UI_UPDATE == 0:\n",
"            display.clear_output(wait=True)\n",
"            select = np.random.permutation(l)\n",
"            print('Epoch: %d/%d\\n batch: (%d/%d)' % (epoch, EPOCHS, i, l))\n",
"            print('\\t| A\\t| B\\t|')\n",
"            print('Dis\\t|%.4f\\t|%.4f\\t|' % (errs_DA / UI_UPDATE, errs_DB / UI_UPDATE))\n",
"            print('Gen\\t|%.4f\\t|%.4f\\t|' % (errs_GA / UI_UPDATE, errs_GB / UI_UPDATE))\n",
"            print('Cyc\\t|%.4f\\t|-------|\\n' % (LAMBDA * errs_C / UI_UPDATE))\n",
"            print('A')\n",
"            sample_images(img_trainA[select[:3]], None, 1, 3)\n",
"            print('A -> B')\n",
"            sample_images(img_trainA[select[:3]], G_b, 1, 3)\n",
"            print('B')\n",
"            sample_images(img_trainB[select[:3]], None, 1, 3)\n",
"            print('B -> A')\n",
"            sample_images(img_trainB[select[:3]], G_a, 1, 3)\n",
"            # Start a fresh averaging window.\n",
"            errs_DA = errs_DB = 0\n",
"            errs_GA = errs_GB = 0\n",
"            errs_C = 0"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
