@NMZivkovic
Created February 10, 2019 18:17
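from keras.layers import Input
from keras.models import Model


def _build_and_compile_gan(self):
    # Real images from domain X and domain Y
    imageX = Input(shape=self.img_shape)
    imageY = Input(shape=self.img_shape)

    # Translate each real image into the other domain
    fakeY = self._generatorXY(imageX)
    fakeX = self._generatorYX(imageY)

    # Translate back to the original domain (cycle consistency)
    reconstructedX = self._generatorYX(fakeY)
    reconstructedY = self._generatorXY(fakeX)

    # Identity mapping: a generator fed an image that is already in its
    # target domain should return it unchanged
    imageX_id = self._generatorYX(imageX)
    imageY_id = self._generatorXY(imageY)

    # Freeze the discriminators so that training the combined model
    # updates only the generators
    self._discriminatorX.trainable = False
    self._discriminatorY.trainable = False

    # The discriminators judge the translated (fake) images
    validX = self._discriminatorX(fakeX)
    validY = self._discriminatorY(fakeY)

    self.gan = Model(inputs=[imageX, imageY],
                     outputs=[validX, validY,
                              reconstructedX, reconstructedY,
                              imageX_id, imageY_id])

    # MSE (least-squares) adversarial losses; MAE for the cycle and identity losses
    self.gan.compile(loss=['mse', 'mse',
                           'mae', 'mae',
                           'mae', 'mae'],
                     loss_weights=[1, 1,
                                   self.cycle_lambda, self.cycle_lambda,
                                   self.id_lambda, self.id_lambda],
                     optimizer=self.optimizer)
    self.gan.summary()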
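The compiled model trains the generators on a weighted sum of six per-output losses. The NumPy sketch below is not part of the original gist. It recomputes that sum outside Keras to make each output's training target explicit. The helper names `mse`, `mae` and `combined_generator_loss`, the default weights `cycle_lambda=10.0` and `id_lambda=1.0`, and the patch-shaped discriminator outputs used in the example are all hypothetical assumptions. The adversarial targets are arrays of ones, because the generators try to make the discriminators output "real". The cycle and identity targets are the original input images.

```python
import numpy as np


def mse(target, pred):
    # Mean squared error, averaged over every element
    return float(np.mean((target - pred) ** 2))


def mae(target, pred):
    # Mean absolute error, averaged over every element
    return float(np.mean(np.abs(target - pred)))


def combined_generator_loss(real_x, real_y, outputs, cycle_lambda=10.0, id_lambda=1.0):
    """Hypothetical helper mirroring the compile() call above.

    outputs = [validX, validY, reconstructedX, reconstructedY, imageX_id, imageY_id]
    """
    valid_x, valid_y = outputs[0], outputs[1]
    # Adversarial targets: the generators want both discriminators to say "real" (1)
    targets = [np.ones_like(valid_x), np.ones_like(valid_y),
               real_x, real_y,   # cycle consistency: X -> Y -> X should give X back
               real_x, real_y]   # identity: G_YX(X) should be X, G_XY(Y) should be Y
    losses = [mse, mse, mae, mae, mae, mae]
    weights = [1.0, 1.0, cycle_lambda, cycle_lambda, id_lambda, id_lambda]
    return sum(w * f(t, o) for w, f, t, o in zip(weights, losses, targets, outputs))


# Usage: perfect reconstructions and fully fooled discriminators give zero loss
rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(2, 8, 8, 3))
y = rng.uniform(-1, 1, size=(2, 8, 8, 3))
perfect = [np.ones((2, 4, 4, 1)), np.ones((2, 4, 4, 1)), x, y, x, y]
print(combined_generator_loss(x, y, perfect))  # 0.0
```

Keras averages each loss over the elements of a sample and then over the batch. For equally sized samples, this should match the global mean used here. The scalar returned by `self.gan.train_on_batch` would also include any regularization terms, and this sketch ignores them.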