@nb312
Last active May 5, 2023 00:42
A simple GAN example
import numpy as np
from keras.models import Sequential, Model
from keras.layers import Dense, Input, LeakyReLU, Dropout, Reshape, Flatten
from keras.optimizers import Adam

# Hyperparameters
latent_dim = 100
img_shape = (28, 28, 1)
# Build the generator: maps a latent vector to a 28x28x1 image
def build_generator():
    model = Sequential()
    model.add(Dense(256, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))  # reshape the flat output back to image dimensions
    return model
# Build the discriminator: classifies images as real or fake
def build_discriminator():
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))  # flatten the image before the dense layers
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Dense(1, activation='sigmoid'))
    return model
# Create the generator and discriminator, and compile the discriminator
generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
# Build the combined model (generator followed by a frozen discriminator)
z = Input(shape=(latent_dim,))
img = generator(z)
discriminator.trainable = False  # freeze the discriminator when training the combined model
valid = discriminator(img)
combined = Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
# Prepare the training data (MNIST), scaled to [-1, 1] to match the generator's tanh output
from keras.datasets import mnist
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
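# The training loop below calls save_imgs(), which the gist never defines. The sketch here
# is an assumption: a minimal helper (name kept from the call site) that samples the
# generator and writes a grid of images to disk with matplotlib, which is not imported or
# mentioned in the original gist.
def save_imgs(epoch, rows=5, cols=5):
    import matplotlib.pyplot as plt
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    gen_imgs = generator.predict(noise)
    gen_imgs = 0.5 * gen_imgs + 0.5  # rescale from [-1, 1] back to [0, 1] for display
    fig, axs = plt.subplots(rows, cols)
    for i in range(rows):
        for j in range(cols):
            axs[i, j].imshow(gen_imgs[i * cols + j, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
    fig.savefig("gan_mnist_%d.png" % epoch)
    plt.close(fig)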
# Train the GAN
batch_size = 128
epochs = 10000
for epoch in range(epochs):
    # Sample a random batch of real images
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]
    # Generate a batch of fake images
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    gen_imgs = generator.predict(noise)
    # Train the discriminator (real images labelled 1, generated images labelled 0)
    d_loss_real = discriminator.train_on_batch(imgs, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((batch_size, 1)))
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    # Sample fresh noise for the generator update
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    # Train the generator (it tries to make the discriminator label fakes as real)
    g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))
    # Report training progress
    print("Epoch %d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
    # Save generated image samples at a fixed interval
    if epoch % 1000 == 0:
        save_imgs(epoch)