Skip to content

Instantly share code, notes, and snippets.

@truongthanhdat
Last active August 25, 2018 04:38
Show Gist options
  • Save truongthanhdat/4c52f340044e3f4f773bae6c5124c750 to your computer and use it in GitHub Desktop.
Save truongthanhdat/4c52f340044e3f4f773bae6c5124c750 to your computer and use it in GitHub Desktop.

Generative Discriminative Networks - GANs

GANs là gì?

GANs là một thuật toán học không giám sát (Unsupersived Learning) được Ian Goodfellow �giới thiệu vào năm 2014 tại hội nghị NIPS, trong đó bao gồm hai thành phần chính là GeneratorDiscriminator:

  • Generator (ký hiệu $G$) nhận nhiệm vụ học ra cách áp xạ từ một không gian tìm ẩn $Z$ (a latent space) vào một không gian với phân phối từ dữ liệu �cho trước.

  • Discriminator (ký hiệu $D$) nhận nhiệm vụ phân biệt dữ liệu được tạo ra từ $G$ và dữ liệu cho trước.

Một cách toán học: Giả sử ta có $z \in Z$$z \sim p_Z(z)$, dữ liệu cho trước $x$$x \sim p_{data}(x)$ ($x$ gọi là real data). Ta có $G$ sẽ ánh xạ $z$ không gian dữ liệu cho trước $\hat{x}=G(z)$ ($\hat{x}$ gọi là fake data). $D(x)$ là xác suất mà $x$real data hay fake data. Mục tiêu của GANs là làm sao cho $G$ cố gắng tạo ra được $\hat{x}$ sao cho $D$ không còn thể phân biệt được là fake data. Tối ưu $G$$D$ giống như trò chơi $minimax$ với hàm mục tiêu $V(D, G)$, trong đó $G$ cố gắng làm tăng xác suất mà $\hat{x}$ được tạo ra là real data$D$ thì cố gắng làm điều ngược lại.

$$ \min_{G} \max_{D} V(D, G) = \mathbb{E}{x \sim p{data}(x)}[log D(x)] + \mathbb{E}_{z \sim p_Z(z)}[log(1 - D(G(z)))]$$

Tối ưu GANs

Qúa tình tối ưu GANs cũng khá đơn giản:

  1. Lấy ngẫu nhiên $m$ mẫu$z \in Z$ và $m$ mẫu $x$ từ dữ liệu cho trước.
  2. Tối ưu $D$ dựa trên $z$$x$.
  3. Lấy ngẫu nhiên $m$ mẫu $z \in Z$ (có thể dùng lại $z$ ở bước 1).
  4. Tối ưu $G$ dựa trên $z$.
  5. Quay lại bước $1$.

Minh hoạ bằng Python với Tensorflow

Sau đây mình sẽ minh hoạ GAN với tập dữ liệu MNIST. Toàn bộ mã nguồn có thể tìm được tại đây

Trước tiên mình cần cài đặt thư viện Tensorflow cho Python

[sudo] pip install tensorflow #Hoặc tensorflow-gpu đối với các bạn sử dụng Tensorflow với GPU

Trước tiên ta cần khai bái các thư viện cần thiết:

import tensorflow as tf
import tensorflow.contrib.slim as slim #slim cho phép khai báo nhanh các lớp thông dụng
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data #Tensorfow cung cấp sẵn API để lấy dữ liệu từ tập MNIST

Ta cần định nghĩa $G$$D$, ở đây mình sử dụng một mạng truyền thẳng (feed-forward) đơn giản với một lớp ẩn:

def generator(inputs):
	with tf.variable_scope("generator"):
		net = slim.fully_connected(inputs, 256, scope = "fc1")
		net = slim.fully_connected(net, 784, scope = "fake_images", activation_fn = tf.nn.sigmoid)
	return net

def discriminator(inputs):
	with tf.variable_scope("discriminator"):
		net = slim.fully_connected(inputs, 256, scope = "fc1")
		net = slim.fully_connected(net, 1, scope = "predictions", activation_fn = tf.nn.sigmoid)
	return net

Ta đinh nghĩa các hyperparameters cần thiết:

mnist_loader = input_data.read_data_sets('MNIST_data')
batch_size = 32
z_dim = 100 #Số chiều của Z
learning_rate = 0.0002
num_iters = 100000

Sau đó ta khởi tạo mạng:

random_z = tf.placeholder(shape = [batch_size, z_dim], dtype = tf.float32, name = "random_vector") #vector z
real_images = tf.placeholder(shape = [batch_size, 784], dtype = tf.float32, name = "real_images") #real data
fake_images = generator(random_z) #fake data

predictions = discriminator(tf.concat([real_images, fake_images], axis = 0)) #fake và real data được đưa qua Discriminator
real_preds = tf.slice(predictions, [0, 0], [batch_size, -1])
fake_preds = tf.slice(predictions, [batch_size, 0], [batch_size, -1])

Sau đó ta định nghĩa hàm mất cho GeneratorDiscriminator. Ở đây mình sử dụng Adam Optimizer để tối ưu $G$$D$.

gen_loss = -tf.reduce_mean(tf.log(fake_preds))
dis_loss = -tf.reduce_mean(tf.log(real_preds) + tf.log(1. - fake_preds))

gen_vars = slim.get_variables(scope = "generator")
dis_vars = slim.get_variables(scope = "discriminator")

optimizer = tf.train.AdamOptimizer(learning_rate)
gen_train_op = optimizer.minimize(gen_loss, var_list = gen_vars)
dis_train_op = optimizer.minimize(dis_loss, var_list = dis_vars)

Sau đó ta tiến hành tối ưu $G$$D$

sess = tf.Session()
sess.run(tf.global_variables_initializer())
for iter in xrange(1, num_iters + 1):
	## Vector ngẫu nhiên z được lấy từ phân phối đều trên [-1, 1]
    feed_dict = {
            random_z: np.random.uniform(-1., 1., size=[batch_size, z_dim]),
            real_images: mnist_loader.train.next_batch(batch_size=batch_size)[0]
            }
    _, _, _gen_loss, _dis_loss = sess.run(
		    [gen_train_op, dis_train_op, gen_loss, dis_loss],
			feed_dict = feed_dict
			)
			
    if (iter % 50) == 0:
        print("Iteration [{:06d}/{:06d}]".format(iter, num_iters))
        print("\t>> Generator Loss: {}".format(_gen_loss))
        print("\t>> Discriminator Loss: {}".format(_dis_loss))

Ứng dụng của GANs và Những điều lưu ý

Trong những năm gần đây, GAN đã có những ứng dụng mạnh mẽ trong nhiều bài toán như Image Super Resolution, Image Translation, Domain Adaptaion. Tuy nhiên để tối ưu GANs là điều không phải dễ, điều này đòi hỏi về phần cứng cũng như sự phân tích bài toán:

  • GAN đòi hỏi chi phí phần cứng cao với các bài toán xử lý ảnh có kích thước lớn.

  • Việc điều chỉnh Learning Rate cũng không phải là điểu dễ dàng, phải cân bằng làm sao cho DiscriminatorGenerator cân bằng lẫn nhau, nếu không dễ dẫn đến một bên lấn áp phần còn lại, dẫn đến Generator không cho ra kết quả tốt.

Bài viết này chỉ cung cấp một khái niệm và ví dụ cơ bản nhất về GANs, nếu có gì sai sót mong các bạn có thể đóng góp kiến.

Tham khảo

  1. Generative Adversarial Networks Wiki
  2. Generative Adversarial Nets. Ian Goodfellow. NIPS 2014
  3. Mã nguồn tham khảo
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
def generator(inputs):
with tf.variable_scope("generator"):
net = slim.fully_connected(inputs, 256, scope = "fc1")
net = slim.fully_connected(net, 784, scope = "fake_images", activation_fn = tf.nn.sigmoid)
return net
def discriminator(inputs):
with tf.variable_scope("discriminator"):
net = slim.fully_connected(inputs, 256, scope = "fc1")
net = slim.fully_connected(net, 1, scope = "predictions", activation_fn = tf.nn.sigmoid)
return net
if __name__ == "__main__":
mnist_loader = input_data.read_data_sets('MNIST_data')
batch_size = 32
z_dim = 100
learning_rate = 0.0002
num_iters = 100000
random_z = tf.placeholder(shape = [batch_size, z_dim], dtype = tf.float32, name = "random_vector")
real_images = tf.placeholder(shape = [batch_size, 784], dtype = tf.float32, name = "real_images")
fake_images = generator(random_z)
predictions = discriminator(tf.concat([real_images, fake_images], axis = 0))
real_preds = tf.slice(predictions, [0, 0], [batch_size, -1])
fake_preds = tf.slice(predictions, [batch_size, 0], [batch_size, -1])
gen_loss = -tf.reduce_mean(tf.log(fake_preds))
dis_loss = -tf.reduce_mean(tf.log(real_preds) + tf.log(1. - fake_preds))
gen_vars = slim.get_variables(scope = "generator")
dis_vars = slim.get_variables(scope = "discriminator")
optimizer = tf.train.AdamOptimizer(learning_rate)
gen_train_op = optimizer.minimize(gen_loss, var_list = gen_vars)
dis_train_op = optimizer.minimize(dis_loss, var_list = dis_vars)
summaries = [
tf.summary.scalar("gen_loss", gen_loss),
tf.summary.scalar("dis_loss", dis_loss),
tf.summary.image("real_images", tf.reshape(real_images, [batch_size, 28, 28, 1])),
tf.summary.image("fake_images", tf.reshape(fake_images, [batch_size, 28, 28, 1]))
]
summary_op = tf.summary.merge(summaries)
summary_writer = tf.summary.FileWriter("log", graph=tf.get_default_graph())
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for iter in xrange(1, num_iters + 1):
feed_dict = {
random_z: np.random.uniform(-1., 1., size=[batch_size, z_dim]),
real_images: mnist_loader.train.next_batch(batch_size=batch_size)[0]
}
_, _, _gen_loss, _dis_loss, summary = sess.run([gen_train_op, dis_train_op, gen_loss, dis_loss, summary_op],
feed_dict = feed_dict)
summary_writer.add_summary(summary, iter)
if (iter % 50) == 0:
print("Iteration [{:06d}/{:06d}]".format(iter, num_iters))
print("\t>> Generator Loss: {}".format(_gen_loss))
print("\t>> Discriminator Loss: {}".format(_dis_loss))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment