#! /usr/bin/python
# -*- coding: utf8 -*-
""" GAN-CLS """
import os
import re
import time
import string
import nltk
import numpy as np
import scipy
import scipy.misc
from scipy.io import loadmat
from collections import Counter

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
from tensorlayer.prepro import *
from tensorlayer.cost import *

from utils import *
from model import *
import model
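# Note: batch_size, image_size and z_dim are used below but never defined in this file;
# they are expected to arrive via `from model import *` (the original text-to-image repo
# defines them at the top of model.py, e.g. batch_size = 64 and image_size = 64).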
dataset = '102flowers'
need_256 = True # set to True for stackGAN
if dataset == '102flowers':
    """
    images.shape = [8000, 64, 64, 3]
    captions_ids = [80000, any]
    """
    cwd = os.getcwd()
    img_dir = os.path.join(cwd, '102flowers')
    caption_dir = os.path.join(cwd, 'text_c10')
    VOC_FIR = cwd + '/vocab.txt'

    ## load captions
    caption_sub_dir = load_folder_list( caption_dir )
    captions_dict = {}
    processed_capts = []
    for sub_dir in caption_sub_dir: # get caption file list
        with tl.ops.suppress_stdout():
            files = tl.files.load_file_list(path=sub_dir, regx=r'^image_[0-9]+\.txt')
            for i, f in enumerate(files):
                file_dir = os.path.join(sub_dir, f)
                key = int(re.findall(r'\d+', f)[0])
                t = open(file_dir, 'r')
                lines = []
                for line in t:
                    line = preprocess_caption(line)
                    lines.append(line)
                    processed_capts.append(tl.nlp.process_sentence(line, start_word="<S>", end_word="</S>"))
                assert len(lines) == 10, "Every flower image has 10 captions"
                captions_dict[key] = lines
    print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
    ## build vocab
    if not os.path.isfile('vocab.txt'):
        _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
    else:
        print("WARNING: vocab.txt already exists")
    vocab = tl.nlp.Vocabulary(VOC_FIR, start_word="<S>", end_word="</S>", unk_word="<UNK>")
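    # vocab.txt stores one "word count" pair per line; tl.nlp.Vocabulary rebuilds the
    # word <-> id mapping from it, with <S>, </S> and <UNK> as the start/end/unknown tokens.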
    ## store all captions ids in list
    captions_ids = []
    try: # python3
        tmp = captions_dict.items()
    except: # python2
        tmp = captions_dict.iteritems()
    for key, value in tmp:
        for v in value:
            captions_ids.append([vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(v)] + [vocab.end_id])  # add END_ID
    captions_ids = np.asarray(captions_ids)
    print(" * tokenized %d captions" % len(captions_ids))
    ## check
    img_capt = captions_dict[1][1]
    print("img_capt: %s" % img_capt)
    print("nltk.tokenize.word_tokenize(img_capt): %s" % nltk.tokenize.word_tokenize(img_capt))
    img_capt_ids = [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(img_capt)]  # img_capt.split(' ')
    print("img_capt_ids: %s" % img_capt_ids)
    print("id_to_word: %s" % [vocab.id_to_word(id) for id in img_capt_ids])
    ## load images
    with tl.ops.suppress_stdout():  # get image files list
        imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx=r'^image_[0-9]+\.jpg'))
    print(" * %d images found, start loading and resizing ..." % len(imgs_title_list))
    s = time.time()
    images = []
    images_256 = []
    for name in imgs_title_list:
        img_raw = scipy.misc.imread(os.path.join(img_dir, name))
        img = tl.prepro.imresize(img_raw, size=[64, 64])    # (64, 64, 3)
        img = img.astype(np.float32)
        images.append(img)
        if need_256:
            img = tl.prepro.imresize(img_raw, size=[256, 256])  # (256, 256, 3)
            img = img.astype(np.float32)
            images_256.append(img)
    print(" * loading and resizing took %ss" % (time.time()-s))
    n_images = len(captions_dict)
    n_captions = len(captions_ids)
    n_captions_per_image = len(lines)  # 10
    print("n_captions: %d n_images: %d n_captions_per_image: %d" % (n_captions, n_images, n_captions_per_image))

    captions_ids_train, captions_ids_test = captions_ids[: 8000*n_captions_per_image], captions_ids[8000*n_captions_per_image :]
    images_train, images_test = images[:8000], images[8000:]
    if need_256:
        images_train_256, images_test_256 = images_256[:8000], images_256[8000:]
    n_images_train = len(images_train)
    n_images_test = len(images_test)
    n_captions_train = len(captions_ids_train)
    n_captions_test = len(captions_ids_test)
    print("n_images_train:%d n_captions_train:%d" % (n_images_train, n_captions_train))
    print("n_images_test:%d n_captions_test:%d" % (n_images_test, n_captions_test))
if dataset == 'kickstarter':
    cwd = os.getcwd()
    data_dir = os.path.join(cwd, 'kickstarter')
    VOC_FIR = cwd + '/vocab.txt'
    # parallel lists: captions[i] describes the image whose file id is ids[i]
    captions = list()
    ids = list()

    ## load captions
    ## one caption per .txt file; the file id is the filename minus '.txt'
    for filename in os.listdir(data_dir):
        if filename.endswith('.txt'):
            with open(os.path.join(data_dir, filename)) as myfile:
                captions.append(myfile.readlines()[2].strip())
            ids.append(filename[:-4])

    # create vocab.txt if it does not exist: one "word count" pair per line, most frequent first
    if not os.path.isfile(VOC_FIR):
        pairs = sorted(Counter(' '.join(captions).split()).items(), key=lambda pair: pair[1], reverse=True)
        with open(VOC_FIR, 'w') as f:
            f.write('<S> 1\n</S> 1\n<UNK> 1\n')  # special tokens the Vocabulary loader may expect to find
            for word, count in pairs:
                f.write(word + ' ' + str(count) + '\n')

    # use vocab.txt to tokenize captions and convert to an array of tokenized sentences
    captions_ids = list()
    vocab = tl.nlp.Vocabulary(VOC_FIR, start_word="<S>", end_word="</S>", unk_word="<UNK>")
    for caption in captions:
        captions_ids.append([vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(caption)] + [vocab.end_id])  # add END_ID
    captions_ids = np.asarray(captions_ids)

    # use the file ids (text filenames without extensions) to open all images
    images = list()
    images_256 = list()
    for idd in ids:
        img_raw = scipy.misc.imread(os.path.join(data_dir, idd + '.jpg'))  # assumes images are stored as <id>.jpg
        img = tl.prepro.imresize(img_raw, size=[64, 64])    # (64, 64, 3)
        img = img.astype(np.float32)
        images.append(img)
        # also keep a 256x256 copy
        img = tl.prepro.imresize(img_raw, size=[256, 256])  # (256, 256, 3)
        img = img.astype(np.float32)
        images_256.append(img)

    n_images = len(images)
    n_captions = len(captions)
    print("Number of images and captions: ", n_images, n_captions)
    n_captions_per_image = 1

    # split into test and train (test will be the first 4000 samples)
    images_test, images_train = images[0:4000], images[4000:]
    captions_ids_test, captions_ids_train = captions_ids[0:4000], captions_ids[4000:]
    # counts used by main_train() below
    n_images_train = len(images_train)
    n_images_test = len(images_test)
    n_captions_train = len(captions_ids_train)
    n_captions_test = len(captions_ids_test)
images_train = np.array(images_train)
images_test = np.array(images_test)
ni = int(np.ceil(np.sqrt(batch_size)))
tl.files.exists_or_mkdir("samples/step1_gan-cls")
tl.files.exists_or_mkdir("samples/step_pretrain_encoder")
tl.files.exists_or_mkdir("checkpoint")
save_dir = "checkpoint"
def main_train():
    ###======================== DEFINE MODEL ===================================###
    t_real_image = tf.placeholder('float32', [batch_size, image_size, image_size, 3], name='real_image')
    t_wrong_image = tf.placeholder('float32', [batch_size, image_size, image_size, 3], name='wrong_image')
    t_real_caption = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='real_caption_input')
    t_wrong_caption = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='wrong_caption_input')
    t_z = tf.placeholder(tf.float32, [batch_size, z_dim], name='z_noise')

    ## training inference for text-to-image mapping
    net_cnn = cnn_encoder(t_real_image, is_train=True, reuse=False)
    x = net_cnn.outputs
    v = rnn_embed(t_real_caption, is_train=True, reuse=False).outputs
    x_w = cnn_encoder(t_wrong_image, is_train=True, reuse=True).outputs
    v_w = rnn_embed(t_wrong_caption, is_train=True, reuse=True).outputs

    alpha = 0.2  # margin alpha
    rnn_loss = tf.reduce_mean(tf.maximum(0., alpha - cosine_similarity(x, v) + cosine_similarity(x, v_w))) + \
               tf.reduce_mean(tf.maximum(0., alpha - cosine_similarity(x, v) + cosine_similarity(x_w, v)))
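    # The two hinge terms above form a symmetric pairwise ranking loss with margin alpha:
    # a matching (image, caption) pair must score a higher cosine similarity than the
    # mismatched pairs (same image / wrong caption and wrong image / same caption) by at
    # least alpha, which pulls the CNN image encoder and the RNN text encoder into a
    # shared embedding space.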
    ## training inference for txt2img
    generator_txt2img = model.generator_txt2img_resnet
    discriminator_txt2img = model.discriminator_txt2img_resnet

    net_rnn = rnn_embed(t_real_caption, is_train=False, reuse=True)
    net_fake_image, _ = generator_txt2img(t_z,
                    net_rnn.outputs,
                    is_train=True, reuse=False, batch_size=batch_size)
    net_d, disc_fake_image_logits = discriminator_txt2img(
                    net_fake_image.outputs, net_rnn.outputs, is_train=True, reuse=False)
    _, disc_real_image_logits = discriminator_txt2img(
                    t_real_image, net_rnn.outputs, is_train=True, reuse=True)
    _, disc_mismatch_logits = discriminator_txt2img(
                    t_real_image,
                    rnn_embed(t_wrong_caption, is_train=False, reuse=True).outputs,
                    is_train=True, reuse=True)

    ## testing inference for txt2img
    net_g, _ = generator_txt2img(t_z,
                    rnn_embed(t_real_caption, is_train=False, reuse=True).outputs,
                    is_train=False, reuse=True, batch_size=batch_size)

    d_loss1 = tl.cost.sigmoid_cross_entropy(disc_real_image_logits, tf.ones_like(disc_real_image_logits), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(disc_mismatch_logits, tf.zeros_like(disc_mismatch_logits), name='d2')
    d_loss3 = tl.cost.sigmoid_cross_entropy(disc_fake_image_logits, tf.zeros_like(disc_fake_image_logits), name='d3')
    d_loss = d_loss1 + (d_loss2 + d_loss3) * 0.5
    g_loss = tl.cost.sigmoid_cross_entropy(disc_fake_image_logits, tf.ones_like(disc_fake_image_logits), name='g')
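    # GAN-CLS "matching-aware" discriminator losses (Reed et al., 2016):
    #   d_loss1: real image + matching caption      -> labelled real
    #   d_loss2: real image + mismatched caption    -> labelled fake
    #   d_loss3: generated image + matching caption -> labelled fake
    # so D must score image/text consistency, not just image realism, while g_loss
    # rewards G when (fake image, matching caption) is judged real.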
    ####======================== DEFINE TRAIN OPTS ==============================###
    lr = 0.0002
    lr_decay = 0.5      # decay factor for adam, https://github.com/reedscot/icml2016/blob/master/main_cls_int.lua https://github.com/reedscot/icml2016/blob/master/scripts/train_flowers.sh
    decay_every = 100   # https://github.com/reedscot/icml2016/blob/master/main_cls.lua
    beta1 = 0.5

    cnn_vars = tl.layers.get_variables_with_name('cnn', True, True)
    rnn_vars = tl.layers.get_variables_with_name('rnn', True, True)
    d_vars = tl.layers.get_variables_with_name('discriminator', True, True)
    g_vars = tl.layers.get_variables_with_name('generator', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr, trainable=False)

    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    grads, _ = tf.clip_by_global_norm(tf.gradients(rnn_loss, rnn_vars + cnn_vars), 10)
    optimizer = tf.train.AdamOptimizer(lr_v, beta1=beta1)  # optimizer = tf.train.GradientDescentOptimizer(lre)
    rnn_optim = optimizer.apply_gradients(zip(grads, rnn_vars + cnn_vars))
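    # D and G each get their own Adam optimizer over their own variable scopes, while the
    # text/image encoders (rnn_vars + cnn_vars) are updated with gradients clipped to a
    # global norm of 10 before apply_gradients, which helps keep the embedding training stable.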
    ###============================ TRAINING ====================================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    tl.layers.initialize_global_variables(sess)

    # load the latest checkpoints
    net_rnn_name = os.path.join(save_dir, 'net_rnn.npz')
    net_cnn_name = os.path.join(save_dir, 'net_cnn.npz')
    net_g_name = os.path.join(save_dir, 'net_g.npz')
    net_d_name = os.path.join(save_dir, 'net_d.npz')

    load_and_assign_npz(sess=sess, name=net_rnn_name, model=net_rnn)
    load_and_assign_npz(sess=sess, name=net_cnn_name, model=net_cnn)
    load_and_assign_npz(sess=sess, name=net_g_name, model=net_g)
    load_and_assign_npz(sess=sess, name=net_d_name, model=net_d)
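    # load_and_assign_npz comes from utils.py; it is expected to simply skip files that do
    # not exist yet, so a fresh run starts from scratch and later runs resume from the
    # checkpoint/ directory (an assumption about the helper, not verified here).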
    ## seed for generation, z and sentence ids
    sample_size = batch_size
    sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)
    sample_sentence = ["the flower shown has yellow anther red pistil and bright red petals."] * int(sample_size/ni) + \
                      ["this flower has petals that are yellow, white and purple and has dark lines"] * int(sample_size/ni) + \
                      ["the petals on this flower are white with a yellow center"] * int(sample_size/ni) + \
                      ["this flower has a lot of small round pink petals."] * int(sample_size/ni) + \
                      ["this flower is orange in color, and has petals that are ruffled and rounded."] * int(sample_size/ni) + \
                      ["the flower has yellow petals and the center of it is brown."] * int(sample_size/ni) + \
                      ["this flower has petals that are blue and white."] * int(sample_size/ni) + \
                      ["these white flowers have petals that start off white in color and end in a white towards the tips."] * int(sample_size/ni)

    for i, sentence in enumerate(sample_sentence):
        print("seed: %s" % sentence)
        sentence = preprocess_caption(sentence)
        sample_sentence[i] = [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(sentence)] + [vocab.end_id]  # add END_ID

    sample_sentence = tl.prepro.pad_sequences(sample_sentence, padding='post')
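    # Each of the 8 seed sentences is repeated sample_size/ni times so that the saved
    # sample grid (ni x ni images) shows one row per sentence. This assumes ni == 8,
    # i.e. batch_size == 64 as in the original repo; other batch sizes would leave the
    # sample batch only partially filled.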
    n_epoch = 600
    print_freq = 1
    n_batch_epoch = int(n_images_train / batch_size)
    for epoch in range(0, n_epoch+1):
        start_time = time.time()

        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay ** (epoch // decay_every)
            sess.run(tf.assign(lr_v, lr * new_lr_decay))
            log = " ** new learning rate: %f" % (lr * new_lr_decay)
            print(log)
        elif epoch == 0:
            log = " ** init lr: %f decay_every_epoch: %d, lr_decay: %f" % (lr, decay_every, lr_decay)
            print(log)
        for step in range(n_batch_epoch):
            step_time = time.time()
            ## get matched text
            idexs = get_random_int(min=0, max=n_captions_train-1, number=batch_size)
            b_real_caption = captions_ids_train[idexs]
            b_real_caption = tl.prepro.pad_sequences(b_real_caption, padding='post')
            ## get real image
            b_real_images = images_train[np.floor(np.asarray(idexs).astype('float')/n_captions_per_image).astype('int')]
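            # captions_ids_train is stored in image order with n_captions_per_image
            # captions per image, so caption index // n_captions_per_image recovers the
            # index of the image that the caption describes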
            ## get wrong caption
            idexs = get_random_int(min=0, max=n_captions_train-1, number=batch_size)
            b_wrong_caption = captions_ids_train[idexs]
            b_wrong_caption = tl.prepro.pad_sequences(b_wrong_caption, padding='post')
            ## get wrong image
            idexs2 = get_random_int(min=0, max=n_images_train-1, number=batch_size)
            b_wrong_images = images_train[idexs2]
            ## get noise
            b_z = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)

            b_real_images = threading_data(b_real_images, prepro_img, mode='train')   # [0, 255] --> [-1, 1] + augmentation
            b_wrong_images = threading_data(b_wrong_images, prepro_img, mode='train')
            ## updates text-to-image mapping
            if epoch < 50:
                errRNN, _ = sess.run([rnn_loss, rnn_optim], feed_dict={
                                t_real_image : b_real_images,
                                t_wrong_image : b_wrong_images,
                                t_real_caption : b_real_caption,
                                t_wrong_caption : b_wrong_caption})
            else:
                errRNN = 0
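            # the CNN/RNN encoders are only updated during the first 50 epochs; afterwards
            # the text embedding is effectively frozen and only D and G keep training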
            ## updates D
            errD, _ = sess.run([d_loss, d_optim], feed_dict={
                            t_real_image : b_real_images,
                            # t_wrong_image : b_wrong_images,
                            t_wrong_caption : b_wrong_caption,
                            t_real_caption : b_real_caption,
                            t_z : b_z})
            ## updates G
            errG, _ = sess.run([g_loss, g_optim], feed_dict={
                            t_real_caption : b_real_caption,
                            t_z : b_z})

            print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4fs, d_loss: %.8f, g_loss: %.8f, rnn_loss: %.8f" \
                        % (epoch, n_epoch, step, n_batch_epoch, time.time() - step_time, errD, errG, errRNN))
        if (epoch + 1) % print_freq == 0:
            print(" ** Epoch %d took %fs" % (epoch, time.time()-start_time))
            img_gen, rnn_out = sess.run([net_g.outputs, net_rnn.outputs], feed_dict={
                                        t_real_caption : sample_sentence,
                                        t_z : sample_seed})
            save_images(img_gen, [ni, ni], 'samples/step1_gan-cls/train_{:02d}.png'.format(epoch))

        ## save model
        if (epoch != 0) and (epoch % 10) == 0:
            tl.files.save_npz(net_cnn.all_params, name=net_cnn_name, sess=sess)
            tl.files.save_npz(net_rnn.all_params, name=net_rnn_name, sess=sess)
            tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
            tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
            print("[*] Save checkpoints SUCCESS!")

        if (epoch != 0) and (epoch % 100) == 0:
            tl.files.save_npz(net_cnn.all_params, name=net_cnn_name+str(epoch), sess=sess)
            tl.files.save_npz(net_rnn.all_params, name=net_rnn_name+str(epoch), sess=sess)
            tl.files.save_npz(net_g.all_params, name=net_g_name+str(epoch), sess=sess)
            tl.files.save_npz(net_d.all_params, name=net_d_name+str(epoch), sess=sess)
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default="train",
                        help='train, train_encoder, translation')
    args = parser.parse_args()

    if args.mode == "train":
        main_train()
    ## you would not use this part, unless you want to try style transfer from the GAN-CLS paper
    # elif args.mode == "train_encoder":
    #     main_train_encoder()
    #
    # elif args.mode == "translation":
    #     main_translation()