Created
November 19, 2017 06:16
-
-
Save robertmaxwilliams/e1dfe63fa871afbb336ff423367b3f20 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/python | |
# -*- coding: utf8 -*- | |
""" GAN-CLS """ | |
import tensorflow as tf | |
import tensorlayer as tl | |
from tensorlayer.layers import * | |
from tensorlayer.prepro import * | |
from tensorlayer.cost import * | |
import numpy as np | |
import scipy | |
from scipy.io import loadmat | |
import time, os, re, nltk | |
from utils import * | |
from model import * | |
import model | |
import os | |
import re | |
import time | |
import nltk | |
import string | |
import tensorlayer as tl | |
from utils import * | |
dataset = '102flowers'
need_256 = True  # set to True for stackGAN

if dataset == '102flowers':
    # Loads the Oxford-102 flowers captions and images from the current
    # working directory and leaves the train/test splits in module globals.
    # Original author's note on the expected result:
    #   images.shape = [8000, 64, 64, 3]
    #   captions_ids = [80000, any]
    cwd = os.getcwd()
    img_dir = os.path.join(cwd, '102flowers')
    caption_dir = os.path.join(cwd, 'text_c10')
    VOC_FIR = cwd + '/vocab.txt'

    ## load captions
    caption_sub_dir = load_folder_list(caption_dir)
    captions_dict = {}      # image id (int) -> list of preprocessed captions
    processed_capts = []    # tokenized captions used only to build the vocab
    for sub_dir in caption_sub_dir:  # get caption file list
        with tl.ops.suppress_stdout():
            files = tl.files.load_file_list(path=sub_dir, regx=r'^image_[0-9]+\.txt')
            for f in files:
                file_dir = os.path.join(sub_dir, f)
                # the first run of digits in the file name is the image id
                key = int(re.findall(r'\d+', f)[0])
                lines = []
                # context manager so the caption file is closed again
                with open(file_dir, 'r') as t:
                    for line in t:
                        line = preprocess_caption(line)
                        lines.append(line)
                        processed_capts.append(tl.nlp.process_sentence(line, start_word="<S>", end_word="</S>"))
                assert len(lines) == 10, "Every flower image have 10 captions"
                captions_dict[key] = lines
    print(" * %d x %d captions found " % (len(captions_dict), len(lines)))

    ## build vocab
    # Only created when missing; an existing vocab.txt is reused as-is.
    if not os.path.isfile(VOC_FIR):
        _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
    else:
        print("WARNING: vocab.txt already exists")
    vocab = tl.nlp.Vocabulary(VOC_FIR, start_word="<S>", end_word="</S>", unk_word="<UNK>")

    ## store all captions ids in list
    # Images are loaded in sorted file-name order below and the training loop
    # maps caption index // n_captions_per_image to an image index, so the
    # captions have to be emitted in ascending image-id order. Plain dict
    # iteration order does not guarantee that (assumes zero-padded image file
    # names so that name order equals id order -- TODO confirm).
    captions_ids = []
    for key, value in sorted(captions_dict.items()):
        for v in value:
            captions_ids.append([vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(v)] + [vocab.end_id])  # add END_ID
    # Captions have different lengths; newer NumPy refuses to build a ragged
    # array unless dtype=object is given explicitly.
    captions_ids = np.asarray(captions_ids, dtype=object)
    print(" * tokenized %d captions" % len(captions_ids))

    ## check
    img_capt = captions_dict[1][1]
    print("img_capt: %s" % img_capt)
    print("nltk.tokenize.word_tokenize(img_capt): %s" % nltk.tokenize.word_tokenize(img_capt))
    img_capt_ids = [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(img_capt)]
    print("img_capt_ids: %s" % img_capt_ids)
    print("id_to_word: %s" % [vocab.id_to_word(id) for id in img_capt_ids])

    ## load images
    with tl.ops.suppress_stdout():  # get image files list
        imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx=r'^image_[0-9]+\.jpg'))
    print(" * %d images found, start loading and resizing ..." % len(imgs_title_list))
    s = time.time()

    images = []
    images_256 = []
    for name in imgs_title_list:
        # NOTE(review): scipy.misc.imread is not available in recent SciPy
        # releases -- confirm the pinned SciPy version before running.
        img_raw = scipy.misc.imread(os.path.join(img_dir, name))
        img = tl.prepro.imresize(img_raw, size=[64, 64])  # (64, 64, 3)
        img = img.astype(np.float32)
        images.append(img)
        if need_256:
            img = tl.prepro.imresize(img_raw, size=[256, 256])  # (256, 256, 3)
            img = img.astype(np.float32)
            images_256.append(img)
    print(" * loading and resizing took %ss" % (time.time()-s))

    n_images = len(captions_dict)
    n_captions = len(captions_ids)
    n_captions_per_image = len(lines)  # 10 (length of the last caption file read)
    print("n_captions: %d n_images: %d n_captions_per_image: %d" % (n_captions, n_images, n_captions_per_image))

    # First 8000 images (and their captions) are the training split.
    captions_ids_train, captions_ids_test = captions_ids[: 8000*n_captions_per_image], captions_ids[8000*n_captions_per_image :]
    images_train, images_test = images[:8000], images[8000:]
    if need_256:
        images_train_256, images_test_256 = images_256[:8000], images_256[8000:]
    n_images_train = len(images_train)
    n_images_test = len(images_test)
    n_captions_train = len(captions_ids_train)
    n_captions_test = len(captions_ids_test)
    print("n_images_train:%d n_captions_train:%d" % (n_images_train, n_captions_train))
    print("n_images_test:%d n_captions_test:%d" % (n_images_test, n_captions_test))
if dataset == 'kickstarter':
    # Loads the kickstarter data set: every "<id>.txt" file in ./kickstarter
    # carries one caption on its third line; the matching image is looked up
    # by the same id. Leaves the same module globals behind as the 102flowers
    # branch (vocab, captions_ids_*, images_*, n_*), which main_train() reads.
    cwd = os.getcwd()
    data_dir = os.path.join(cwd, 'kickstarter')
    VOC_FIR = cwd + '/vocab.txt'

    ## load captions
    captions = []           # one normalized caption per sample
    ids = []                # file name minus the ".txt" suffix, same order as captions
    processed_capts = []    # tokenized captions used only to build the vocab
    # sorted() keeps the sample order (and so the test/train split) stable
    # across runs; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith('.txt'):
            continue
        with open(os.path.join(data_dir, filename)) as myfile:
            caption = myfile.readlines()[2].strip()
        # Same normalization as the 102flowers branch. Lower-cased up front so
        # vocabulary building and id lookup see identical tokens no matter how
        # tl.nlp.process_sentence treats case -- TODO confirm.
        caption = preprocess_caption(caption).lower()
        captions.append(caption)
        ids.append(filename[:-4])
        processed_capts.append(tl.nlp.process_sentence(caption, start_word="<S>", end_word="</S>"))

    ## build vocab (only when vocab.txt does not exist yet)
    # NOTE: a vocab.txt left over from another data set is reused as-is;
    # delete it when switching data sets.
    if not os.path.isfile(VOC_FIR):
        _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
    else:
        print("WARNING: vocab.txt already exists")
    vocab = tl.nlp.Vocabulary(VOC_FIR, start_word="<S>", end_word="</S>", unk_word="<UNK>")

    ## use the vocab to convert every caption into a list of word ids
    captions_ids = []
    for caption in captions:
        captions_ids.append([vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(caption)] + [vocab.end_id])  # add END_ID
    # captions have different lengths -> explicit object array
    captions_ids = np.asarray(captions_ids, dtype=object)

    ## use the file ids to open all images
    images = []
    images_256 = []
    for idd in ids:
        # NOTE(review): the image is opened as <data_dir>/<id> with no file
        # extension appended, as in the original code -- confirm how the
        # kickstarter images are actually named on disk.
        img_raw = scipy.misc.imread(os.path.join(data_dir, idd))
        img = tl.prepro.imresize(img_raw, size=[64, 64])  # (64, 64, 3)
        img = img.astype(np.float32)
        images.append(img)
        # also get them in 256 size
        img = tl.prepro.imresize(img_raw, size=[256, 256])  # (256, 256, 3)
        img = img.astype(np.float32)
        images_256.append(img)

    n_images = len(images)
    n_captions = len(captions)
    print("Number of images and captions: ", n_images, n_captions)
    n_captions_per_image = 1

    # split into test and train (test will be first 4000 samples)
    images_test, images_train = images[0:4000], images[4000:]
    images_test_256, images_train_256 = images_256[0:4000], images_256[4000:]
    captions_ids_test, captions_ids_train = captions_ids[0:4000], captions_ids[4000:]
    n_images_train = len(images_train)
    n_images_test = len(images_test)
    n_captions_train = len(captions_ids_train)
    n_captions_test = len(captions_ids_test)
# Turn the per-image lists into arrays; the training loop indexes
# images_train with a whole batch of indices at once.
images_train, images_test = np.array(images_train), np.array(images_test)

# Side length of the square grid passed to save_images ([ni, ni]).
ni = int(np.ceil(np.sqrt(batch_size)))

# Make sure the sample and checkpoint directories exist.
for out_dir in ("samples/step1_gan-cls", "samples/step_pretrain_encoder", "checkpoint"):
    tl.files.exists_or_mkdir(out_dir)
save_dir = "checkpoint"
def main_train():
    """Build the GAN-CLS graph and run the full training loop.

    Reads the module-level data prepared above (``images_train``,
    ``captions_ids_train``, ``vocab``, ``n_images_train``,
    ``n_captions_train``, ``n_captions_per_image``, ``ni``, ``save_dir``) and
    the hyper-parameters ``batch_size``, ``image_size`` and ``z_dim``, which
    are not defined in this file (presumably supplied by the star imports).

    Side effects: prints progress, writes one sample grid per epoch to
    ``samples/step1_gan-cls`` and saves ``.npz`` checkpoints to ``save_dir``
    every 10 epochs (plus epoch-suffixed copies every 100 epochs).
    """
    ###======================== DEFINE MODEL ===================================###
    t_real_image = tf.placeholder('float32', [batch_size, image_size, image_size, 3], name='real_image')
    t_wrong_image = tf.placeholder('float32', [batch_size, image_size, image_size, 3], name='wrong_image')
    t_real_caption = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='real_caption_input')
    t_wrong_caption = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='wrong_caption_input')
    t_z = tf.placeholder(tf.float32, [batch_size, z_dim], name='z_noise')

    ## training inference for text-to-image mapping
    net_cnn = cnn_encoder(t_real_image, is_train=True, reuse=False)
    x = net_cnn.outputs
    v = rnn_embed(t_real_caption, is_train=True, reuse=False).outputs
    x_w = cnn_encoder(t_wrong_image, is_train=True, reuse=True).outputs
    v_w = rnn_embed(t_wrong_caption, is_train=True, reuse=True).outputs

    # Margin loss: a matching image/caption pair must be at least ``alpha``
    # more similar than a pair with the wrong caption or the wrong image.
    alpha = 0.2  # margin alpha
    rnn_loss = tf.reduce_mean(tf.maximum(0., alpha - cosine_similarity(x, v) + cosine_similarity(x, v_w))) + \
        tf.reduce_mean(tf.maximum(0., alpha - cosine_similarity(x, v) + cosine_similarity(x_w, v)))

    ## training inference for txt2img
    generator_txt2img = model.generator_txt2img_resnet
    discriminator_txt2img = model.discriminator_txt2img_resnet

    net_rnn = rnn_embed(t_real_caption, is_train=False, reuse=True)
    net_fake_image, _ = generator_txt2img(t_z,
                                          net_rnn.outputs,
                                          is_train=True, reuse=False, batch_size=batch_size)
    net_d, disc_fake_image_logits = discriminator_txt2img(
        net_fake_image.outputs, net_rnn.outputs, is_train=True, reuse=False)
    _, disc_real_image_logits = discriminator_txt2img(
        t_real_image, net_rnn.outputs, is_train=True, reuse=True)
    _, disc_mismatch_logits = discriminator_txt2img(
        t_real_image,
        rnn_embed(t_wrong_caption, is_train=False, reuse=True).outputs,
        is_train=True, reuse=True)

    ## testing inference for txt2img
    net_g, _ = generator_txt2img(t_z,
                                 rnn_embed(t_real_caption, is_train=False, reuse=True).outputs,
                                 is_train=False, reuse=True, batch_size=batch_size)

    # Discriminator targets: real image + right caption -> 1,
    # real image + wrong caption -> 0, generated image + right caption -> 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(disc_real_image_logits, tf.ones_like(disc_real_image_logits), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(disc_mismatch_logits, tf.zeros_like(disc_mismatch_logits), name='d2')
    d_loss3 = tl.cost.sigmoid_cross_entropy(disc_fake_image_logits, tf.zeros_like(disc_fake_image_logits), name='d3')
    d_loss = d_loss1 + (d_loss2 + d_loss3) * 0.5
    g_loss = tl.cost.sigmoid_cross_entropy(disc_fake_image_logits, tf.ones_like(disc_fake_image_logits), name='g')

    ####======================== DEFINE TRAIN OPTS ==============================###
    lr = 0.0002
    lr_decay = 0.5  # decay factor for adam, https://github.com/reedscot/icml2016/blob/master/main_cls_int.lua https://github.com/reedscot/icml2016/blob/master/scripts/train_flowers.sh
    decay_every = 100  # https://github.com/reedscot/icml2016/blob/master/main_cls.lua
    beta1 = 0.5

    cnn_vars = tl.layers.get_variables_with_name('cnn', True, True)
    rnn_vars = tl.layers.get_variables_with_name('rnn', True, True)
    d_vars = tl.layers.get_variables_with_name('discriminator', True, True)
    g_vars = tl.layers.get_variables_with_name('generator', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr, trainable=False)
        # Build the assign op once and feed the new value at run time;
        # calling tf.assign inside the epoch loop would add a new node to the
        # graph on every decay step.
        lr_in = tf.placeholder(tf.float32, shape=[], name='lr_in')
        lr_update = tf.assign(lr_v, lr_in)

    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    # The text/image embedding is trained with gradients clipped to a global norm of 10.
    grads, _ = tf.clip_by_global_norm(tf.gradients(rnn_loss, rnn_vars + cnn_vars), 10)
    optimizer = tf.train.AdamOptimizer(lr_v, beta1=beta1)
    rnn_optim = optimizer.apply_gradients(zip(grads, rnn_vars + cnn_vars))

    ###============================ TRAINING ====================================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    tl.layers.initialize_global_variables(sess)

    # load the latest checkpoints
    net_rnn_name = os.path.join(save_dir, 'net_rnn.npz')
    net_cnn_name = os.path.join(save_dir, 'net_cnn.npz')
    net_g_name = os.path.join(save_dir, 'net_g.npz')
    net_d_name = os.path.join(save_dir, 'net_d.npz')
    load_and_assign_npz(sess=sess, name=net_rnn_name, model=net_rnn)
    load_and_assign_npz(sess=sess, name=net_cnn_name, model=net_cnn)
    load_and_assign_npz(sess=sess, name=net_g_name, model=net_g)
    load_and_assign_npz(sess=sess, name=net_d_name, model=net_d)
    # (network, checkpoint path) pairs in the order they are saved
    checkpoints = [(net_cnn, net_cnn_name), (net_rnn, net_rnn_name),
                   (net_g, net_g_name), (net_d, net_d_name)]

    ## seed for generation, z and sentence ids
    sample_size = batch_size
    sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)
    seed_captions = [
        "the flower shown has yellow anther red pistil and bright red petals.",
        "this flower has petals that are yellow, white and purple and has dark lines",
        "the petals on this flower are white with a yellow center",
        "this flower has a lot of small round pink petals.",
        "this flower is orange in color, and has petals that are ruffled and rounded.",
        "the flower has yellow petals and the center of it is brown.",
        "this flower has petals that are blue and white.",
        "these white flowers have petals that start off white in color and end in a white towards the tips.",
    ]
    # Each seed caption is repeated sample_size/ni times, so the sample batch
    # only fills the [batch_size, None] caption placeholder when
    # 8 * int(sample_size / ni) == batch_size (e.g. batch_size == 64).
    repeats = int(sample_size / ni)
    sample_sentence = []
    for caption in seed_captions:
        sample_sentence.extend([caption] * repeats)
    for i, sentence in enumerate(sample_sentence):
        print("seed: %s" % sentence)
        sentence = preprocess_caption(sentence)
        sample_sentence[i] = [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(sentence)] + [vocab.end_id]  # add END_ID
    sample_sentence = tl.prepro.pad_sequences(sample_sentence, padding='post')

    n_epoch = 600
    print_freq = 1
    n_batch_epoch = n_images_train // batch_size

    for epoch in range(0, n_epoch+1):
        start_time = time.time()

        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay ** (epoch // decay_every)
            sess.run(lr_update, feed_dict={lr_in: lr * new_lr_decay})
            log = " ** new learning rate: %f" % (lr * new_lr_decay)
            print(log)
        elif epoch == 0:
            log = " ** init lr: %f decay_every_epoch: %d, lr_decay: %f" % (lr, decay_every, lr_decay)
            print(log)

        for step in range(n_batch_epoch):
            step_time = time.time()
            ## get matched text
            idexs = get_random_int(min=0, max=n_captions_train-1, number=batch_size)
            b_real_caption = captions_ids_train[idexs]
            b_real_caption = tl.prepro.pad_sequences(b_real_caption, padding='post')
            ## get real image: caption k belongs to image k // n_captions_per_image
            b_real_images = images_train[np.asarray(idexs) // n_captions_per_image]
            ## get wrong caption
            idexs = get_random_int(min=0, max=n_captions_train-1, number=batch_size)
            b_wrong_caption = captions_ids_train[idexs]
            b_wrong_caption = tl.prepro.pad_sequences(b_wrong_caption, padding='post')
            ## get wrong image
            idexs2 = get_random_int(min=0, max=n_images_train-1, number=batch_size)
            b_wrong_images = images_train[idexs2]
            ## get noise
            b_z = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)

            b_real_images = threading_data(b_real_images, prepro_img, mode='train')  # [0, 255] --> [-1, 1] + augmentation
            b_wrong_images = threading_data(b_wrong_images, prepro_img, mode='train')

            ## updates text-to-image mapping (only during the first 50 epochs)
            if epoch < 50:
                errRNN, _ = sess.run([rnn_loss, rnn_optim], feed_dict={
                    t_real_image: b_real_images,
                    t_wrong_image: b_wrong_images,
                    t_real_caption: b_real_caption,
                    t_wrong_caption: b_wrong_caption})
            else:
                errRNN = 0

            ## updates D
            errD, _ = sess.run([d_loss, d_optim], feed_dict={
                t_real_image: b_real_images,
                t_wrong_caption: b_wrong_caption,
                t_real_caption: b_real_caption,
                t_z: b_z})
            ## updates G
            errG, _ = sess.run([g_loss, g_optim], feed_dict={
                t_real_caption: b_real_caption,
                t_z: b_z})

            print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4fs, d_loss: %.8f, g_loss: %.8f, rnn_loss: %.8f" \
                  % (epoch, n_epoch, step, n_batch_epoch, time.time() - step_time, errD, errG, errRNN))

        if (epoch + 1) % print_freq == 0:
            print(" ** Epoch %d took %fs" % (epoch, time.time()-start_time))
            # Generate from the fixed seed noise/captions so that the sample
            # grids of different epochs are comparable.
            img_gen = sess.run(net_g.outputs, feed_dict={
                t_real_caption: sample_sentence,
                t_z: sample_seed})
            save_images(img_gen, [ni, ni], 'samples/step1_gan-cls/train_{:02d}.png'.format(epoch))

        ## save model (overwrites the "latest" checkpoints)
        if (epoch != 0) and (epoch % 10) == 0:
            for net, path in checkpoints:
                tl.files.save_npz(net.all_params, name=path, sess=sess)
            print("[*] Save checkpoints SUCCESS!")

        ## additionally keep an epoch-suffixed copy every 100 epochs
        if (epoch != 0) and (epoch % 100) == 0:
            for net, path in checkpoints:
                tl.files.save_npz(net.all_params, name=path+str(epoch), sess=sess)
if __name__ == '__main__':
    import argparse

    # Command-line entry point: --mode picks what to run (default "train").
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--mode', type=str, default="train",
                            help='train, train_encoder, translation')
    cli_args = arg_parser.parse_args()

    # Only "train" is wired up; any other mode value is silently ignored.
    mode_handlers = {"train": main_train}
    handler = mode_handlers.get(cli_args.mode)
    if handler is not None:
        handler()

    ## you would not use this part, unless you want to try style transfer on GAN-CLS paper
    # Disabled modes, kept for reference:
    #   "train_encoder" -> main_train_encoder()
    #   "translation"   -> main_transaltion()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment