Skip to content

Instantly share code, notes, and snippets.

@Elfsong
Last active January 25, 2019 09:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Elfsong/aa38a85fcb5c1d20ca781fe408d8c4ac to your computer and use it in GitHub Desktop.
Save Elfsong/aa38a85fcb5c1d20ca781fe408d8c4ac to your computer and use it in GitHub Desktop.
MNIST_GAN #GAN #python
import os # Get system method
import shutil # Recursive traversing file path
import tensorflow as tf # Neural Network2
import numpy as np # Matrix Computing
from skimage.io import imsave # Image operation
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST through TensorFlow's tutorial helper.
# NOTE(review): `data` is not referenced anywhere else in this file; load_data()
# opens the raw IDX files under MNIST_data/ directly instead.
data = input_data.read_data_sets('MNIST_data/')

# MNIST images are 28 x 28 single-channel; the networks work on them flattened.
image_height, image_width = 28, 28
image_size = image_height * image_width

# Run mode and output location.
# NOTE: `train` is rebound by `def train()` later in this module, so this flag
# is overwritten before the __main__ guard ever reads it.
train = True
restore = False  # True: resume from the latest checkpoint in output_path
output_path = "./output/"

# Hyper-parameters
max_epoch, batch_size = 500, 256
z_size = 128  # length of the generator's input noise vector
h1_size, h2_size = 256, 512  # widths of the two hidden layers
def load_data(data_path):
    """
    Load the MNIST training images and labels from raw (uncompressed) IDX files.

    Reads 'train-images.idx3-ubyte' and 'train-labels.idx1-ubyte' from data_path.

    :param data_path: directory containing the MNIST IDX files
    :return train_data: float64 array of shape (n_images, 784), pixel values 0-255
    :return train_label: float64 array of shape (n_images,)
    """
    # Binary mode + context manager: the files are raw bytes and must be closed.
    with open(os.path.join(data_path, 'train-images.idx3-ubyte'), 'rb') as f_data:
        loaded_data = np.fromfile(file=f_data, dtype=np.uint8)
    # Skip the 16-byte header of the image file; the rest is 784 bytes per image.
    # np.float64 replaces the np.float alias removed in NumPy 1.24 (same dtype).
    train_data = loaded_data[16:].reshape((-1, 784)).astype(np.float64)

    with open(os.path.join(data_path, 'train-labels.idx1-ubyte'), 'rb') as f_label:
        loaded_label = np.fromfile(file=f_label, dtype=np.uint8)
    # Skip the 8-byte header of the label file; the rest is one byte per label.
    train_label = loaded_label[8:].reshape((-1)).astype(np.float64)
    return train_data, train_label
def generator(z_prior):
    """
    Build the generator: a three-layer MLP mapping noise to flattened images.

    Layer widths are z_size -> h1_size -> h2_size -> image_size. The two hidden
    layers use ReLU and the output layer uses tanh, so pixels lie in (-1, 1).

    :param z_prior: random noise input, shape (batch_size, z_size)
    :return x_generate: generated images, shape (batch_size, image_size)
    :return g_params: [w1, b1, w2, b2, w3, b3], the generator's trainable variables
    """
    # (fan_in, fan_out, weight name, bias name, activation) for each layer, in order.
    layer_specs = (
        (z_size, h1_size, "g_w1", "g_b1", tf.nn.relu),
        (h1_size, h2_size, "g_w2", "g_b2", tf.nn.relu),
        (h2_size, image_size, "g_w3", "g_b3", tf.nn.tanh),
    )
    g_params = []
    layer_out = z_prior
    for fan_in, fan_out, w_name, b_name, activation in layer_specs:
        weight = tf.Variable(tf.truncated_normal([fan_in, fan_out], stddev=0.1), name=w_name, dtype=tf.float32)
        bias = tf.Variable(tf.zeros([fan_out]), name=b_name, dtype=tf.float32)
        layer_out = activation(tf.matmul(layer_out, weight) + bias)
        g_params.extend([weight, bias])
    return layer_out, g_params
# Discriminator of the GAN
def discriminator(x_data, x_generated, keep_prob):
    """
    Score real and generated images with one shared three-layer MLP.

    Both batches are stacked and passed through the same fully connected layers
    (image_size -> h2_size -> h1_size -> 1, ReLU + dropout on the hidden layers).
    The output is then split back into the real part and the generated part.

    :param x_data: real images, shape (n_real, image_size)
    :param x_generated: generated images, shape (n_fake, image_size)
    :param keep_prob: probability of KEEPING a hidden unit, passed to
        tf.nn.dropout (not the drop rate)
    :return y_data: sigmoid scores for the real images, shape (n_real, 1)
    :return y_generated: sigmoid scores for the generated images, shape (n_fake, 1)
    :return d_params: [w1, b1, w2, b2, w3, b3], the discriminator's trainable variables
    """
    # Stack real on top of generated so both share one forward pass.
    x_in = tf.concat([x_data, x_generated], 0)
    # The first hidden layer
    w1 = tf.Variable(tf.truncated_normal([image_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
    h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)
    # The second hidden layer
    w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
    h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)
    # Output layer: one logit per image
    w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
    h3 = tf.matmul(h2, w3) + b3
    # Split at the actual number of real rows, not the global batch_size, so a
    # real batch of any size is separated from the generated rows correctly.
    n_real = tf.shape(x_data)[0]
    # First n_real rows are the real images
    y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], tf.stack([n_real, -1]), name=None))
    # Remaining rows are the generated images
    y_generated = tf.nn.sigmoid(tf.slice(h3, tf.stack([n_real, 0]), [-1, -1], name=None))
    # All parameters of the discriminator
    d_params = [w1, b1, w2, b2, w3, b3]
    return y_data, y_generated, d_params
def show_result(batch_result, fname, grid_size=(8, 8), grid_pad=5):
    """
    Tile a batch of flattened images into a single grid image and save it.

    :param batch_result: array of flattened images with values in [-1, 1],
        shape (n, image_height * image_width)
    :param fname: output file path for the grid image
    :param grid_size: (rows, cols) of the grid (default 8*8); images beyond
        rows * cols are ignored and unused cells stay black
    :param grid_pad: gap in pixels between neighbouring images (default 5)
    :return: None
    """
    n_rows, n_cols = grid_size
    # Map [-1, 1] back to [0, 1] and restore the (n, height, width) image shape.
    batch_res = 0.5 * batch_result.reshape((batch_result.shape[0], image_height, image_width)) + 0.5
    img_h, img_w = batch_res.shape[1], batch_res.shape[2]
    grid_h = img_h * n_rows + grid_pad * (n_rows - 1)
    grid_w = img_w * n_cols + grid_pad * (n_cols - 1)
    img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
    for i, res in enumerate(batch_res[:n_rows * n_cols]):
        # Clip so out-of-range values saturate instead of wrapping around in uint8.
        img = (np.clip(res, 0.0, 1.0) * 255.).astype(np.uint8)
        # Row-major placement: the row advances every n_cols images (the old
        # code divided by the row count, which breaks non-square grids).
        row = (i // n_cols) * (img_h + grid_pad)
        col = (i % n_cols) * (img_w + grid_pad)
        img_grid[row:row + img_h, col:col + img_w] = img
    imsave(fname, img_grid)
# Training procedure
def train():
    """
    Train the GAN on MNIST and write sample grids and checkpoints to output_path.

    Builds the generator/discriminator graph. For each of max_epoch epochs, it
    runs one discriminator step and one generator step per full batch of the
    training images. After every epoch it saves two sample grids (one from a
    fixed noise batch, one from fresh noise) and a checkpoint named "model".

    Reads the module-level settings batch_size, z_size, image_size, max_epoch,
    restore and output_path.

    :return: None
    :raises ValueError: if restore is set and output_path holds no checkpoint
    """
    # Load the training images; the labels are not used by the GAN.
    train_data, _ = load_data("MNIST_data")
    size = train_data.shape[0]

    # ---- Build the model -------------------------------------------------
    # Real images (batch_size, image_size) and generator noise (batch_size, z_size).
    x_data = tf.placeholder(tf.float32, [batch_size, image_size], name="x_data")
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
    # Dropout keep probability for the discriminator
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    global_step = tf.Variable(0, name="global_step", trainable=False)
    # Build the step-assignment op once; creating tf.assign inside the epoch
    # loop would add a new node to the graph every epoch.
    step_value = tf.placeholder(tf.int32, shape=[], name="step_value")
    set_global_step = tf.assign(global_step, step_value)

    x_generated, g_params = generator(z_prior)
    y_data, y_generated, d_params = discriminator(x_data, x_generated, keep_prob)

    # float32 sigmoid can saturate to exactly 0 or 1; the epsilon keeps log()
    # finite so the losses do not become -inf/NaN when D is very confident.
    eps = 1e-8
    d_loss = - (tf.log(y_data + eps) + tf.log(1 - y_generated + eps))
    g_loss = - tf.log(y_generated + eps)
    # Adam with learning rate 0.0001; each trainer updates only its own network.
    optimizer = tf.train.AdamOptimizer(0.0001)
    d_trainer = optimizer.minimize(d_loss, var_list=d_params)
    g_trainer = optimizer.minimize(g_loss, var_list=g_params)
    # ---- Model built -----------------------------------------------------

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        if restore:
            # Resume from the most recent checkpoint in output_path.
            chkpt_fname = tf.train.latest_checkpoint(output_path)
            if chkpt_fname is None:
                raise ValueError("no checkpoint found in %s" % output_path)
            saver.restore(sess, chkpt_fname)
        else:
            # Fresh run: remove any previous output directory and recreate it.
            if os.path.exists(output_path):
                shutil.rmtree(output_path)
            os.makedirs(output_path)

        # Fixed noise batch, reused every epoch so sample grids are comparable.
        z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
        train_keep_prob = np.float32(0.7)
        n_batches = size // batch_size  # a trailing partial batch is skipped

        for i in range(max_epoch):
            for j in range(n_batches):
                if j % 50 == 0:
                    print("epoch:%s, iter:%s" % (i, j))
                # Always a full batch, since j < n_batches; the placeholders
                # need exactly batch_size rows. (The old clamp to size - 1
                # produced a short batch whenever size % batch_size == 0.)
                x_value = train_data[j * batch_size:(j + 1) * batch_size]
                # Scale pixels from [0, 255] to [-1, 1] to match the tanh output.
                x_value = 2 * (x_value / 255.) - 1
                z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
                feed = {x_data: x_value, z_prior: z_value, keep_prob: train_keep_prob}
                # One discriminator step, then one generator step, on the same batch.
                sess.run(d_trainer, feed_dict=feed)
                sess.run(g_trainer, feed_dict=feed)

            # After the epoch: samples from the fixed noise track progress.
            x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
            show_result(x_gen_val, os.path.join(output_path, "sample%s.jpg" % i))
            print(x_gen_val)
            # Samples from fresh noise show variety beyond the fixed batch.
            z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
            x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
            show_result(x_gen_val, os.path.join(output_path, "random_sample%s.jpg" % i))
            # Record the number of completed epochs and write a checkpoint.
            sess.run(set_global_step, feed_dict={step_value: i + 1})
            saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)
if __name__ == '__main__':
    # The module-level `train = True` flag is rebound by `def train()`, so a
    # guard like `if train:` would test the function object and always be true.
    # Call the training entry point directly instead of a guard that cannot be
    # switched off.
    train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment