initial cost too small
""" | |
ch 6.4 | |
https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_generate_text.py | |
https://github.com/tensorlayer/tensorlayer/tree/master/example/data/trump | |
""" | |
import tensorflow as tf | |
import tensorlayer as tl | |
from tensorlayer.layers import * | |
import nltk | |
import numpy as np | |
import re | |
print('tf version:', tf.VERSION) | |
print('tl version:', tl.__version__) | |
def customized_clean_str(string):
    """Tokenization/string cleaning for a dataset.
    """
    string = re.sub(r"\n", " ", string)        # '\n' --> ' '
    string = re.sub(r"\'s", " \'s", string)    # it's --> it 's
    string = re.sub(r"\’s", " \'s", string)
    string = re.sub(r"\'ve", " have", string)  # they've --> they have
    string = re.sub(r"\’ve", " have", string)
    string = re.sub(r"\'t", " not", string)    # can't --> can not
    string = re.sub(r"\’t", " not", string)
    string = re.sub(r"\'re", " are", string)   # they're --> they are
    string = re.sub(r"\’re", " are", string)
    string = re.sub(r"\'d", "", string)        # I'd (I had, I would) --> I
    string = re.sub(r"\’d", "", string)
    string = re.sub(r"\'ll", " will", string)  # I'll --> I will
    string = re.sub(r"\’ll", " will", string)
    string = re.sub(r"\“", " “ ", string)      # “a” --> “ a ”
    string = re.sub(r"\”", " ” ", string)
    string = re.sub(r"\"", " “ ", string)      # "a" --> “ a “
    string = re.sub(r"\'", " ' ", string)      # they' --> they '
    string = re.sub(r"\’", " ' ", string)      # they’ --> they '
    string = re.sub(r"\.", " . ", string)      # they. --> they .
    string = re.sub(r"\,", " , ", string)      # they, --> they ,
    string = re.sub(r"\-", " ", string)        # "low-cost" --> "low cost"
    string = re.sub(r"\(", " ( ", string)      # (they) --> ( they)
    string = re.sub(r"\)", " ) ", string)      # ( they) --> ( they )
    string = re.sub(r"\!", " ! ", string)      # they! --> they !
    string = re.sub(r"\]", " ] ", string)      # they] --> they ]
    string = re.sub(r"\[", " [ ", string)      # they[ --> they [
    string = re.sub(r"\?", " ? ", string)      # they? --> they ?
    string = re.sub(r"\>", " > ", string)      # they> --> they >
    string = re.sub(r"\<", " < ", string)      # they< --> they <
    string = re.sub(r"\=", " = ", string)      # easier= --> easier =
    string = re.sub(r"\;", " ; ", string)      # easier; --> easier ;
    string = re.sub(r"\:", " : ", string)      # easier: --> easier :
    string = re.sub(r"\"", " \" ", string)     # easier" --> easier " (no-op: all '"' were already replaced above)
    string = re.sub(r"\$", " $ ", string)      # $380 --> $ 380
    string = re.sub(r"\_", " _ ", string)      # _100 --> _ 100
    string = re.sub(r"\s{2,}", " ", string)    # collapse runs of whitespace into a single space
    return string.strip().lower()              # lowercase
def customized_read_words(input_fpath):
    with open(input_fpath, "r") as f:
        words = f.read()
    # Clean the data
    words = customized_clean_str(words)
    # Split into words
    return words.split()
# Model and training hyperparameters
init_scale = 0.1
learning_rate = 1.0
max_grad_norm = 5
sequence_length = 20
hidden_size = 200
max_epoch = 4
max_max_epoch = 100
lr_decay = 0.9
batch_size = 20

# Word-embedding layer parameters
vocabulary_size = 50000

# Sampling/output parameters
top_k_list = [1, 3, 5, 10]
print_length = 30

# Filename for saving the model
model_file_name = 'model_generate_text.npz'

##===== Prepare Data
# Data location; download manually from GitHub (see the URLs in the docstring).
words = customized_read_words(input_fpath="data/trump/trump_text.txt")
vocab = tl.nlp.create_vocab([words], word_counts_output_file='vocab.txt', min_word_count=1)
vocab = tl.nlp.Vocabulary('vocab.txt', unk_word="<UNK>")
vocab_size = vocab.unk_id + 1
train_data = [vocab.word_to_id(word) for word in words]

# TODO: what exactly does sequence_length control in the model? Inputs and
# targets have the same length; targets are the inputs shifted one word ahead
# (see the sketch below).
input_data = tf.placeholder(tf.int32, shape=[batch_size, sequence_length])
targets = tf.placeholder(tf.int32, shape=[batch_size, sequence_length])
# Used when generating sentences (at test time): sequence length is 1 so
# words can be fed in one at a time.
input_data_test = tf.placeholder(tf.int32, shape=[1, 1])
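
# Toy sketch (our addition) of what tl.iterate.ptb_iterator feeds these
# placeholders, answering the TODO above: x and y have the same shape, and
# y is x shifted one position ahead. With batch_size=2, sequence_length=3:
#   data = list(range(13))
#   x, y = next(tl.iterate.ptb_iterator(data, 2, 3))
#   # x == [[0, 1, 2], [6, 7, 8]]
#   # y == [[1, 2, 3], [7, 8, 9]]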
# Network definition
def inference(x, is_train, sequence_length, reuse=None):
    print("\nsequence_length: %d, is_train: %s, reuse: %s" % (sequence_length, is_train, reuse))
    # is_train appears unused? The reference code online also only uses it for printing.
    # If reuse is True, previously declared parameters are reused; see Chapter 2 for details.
    rnn_init = tf.random_uniform_initializer(-init_scale, init_scale)
    with tf.variable_scope('model', reuse=reuse):
        # tl.layers.set_name_reuse(reuse)
        # Word-embedding layer
        # TODO: a pre-trained embedding matrix could be loaded here
        network = EmbeddingInputlayer(inputs=x,
                                      vocabulary_size=vocab_size,
                                      embedding_size=hidden_size,
                                      E_init=rnn_init,
                                      name='embedding')
        # LSTM layer
        # TODO: does hidden_size have to equal the embedding size?
        # This is the fixed-length RNN layer:
        # https://tensorlayer.readthedocs.io/en/latest/modules/layers.html#tensorlayer.layers.RNNLayer
        network = RNNLayer(network,
                           cell_fn=tf.contrib.rnn.BasicLSTMCell,
                           cell_init_args={'forget_bias': 0.0, 'state_is_tuple': True},
                           n_hidden=hidden_size,
                           initializer=rnn_init,
                           n_steps=sequence_length,
                           return_last=False,
                           return_seq_2d=True,
                           name='lstm1')
        # Keep a handle on the LSTM layer, since its cell state and hidden
        # state are needed later
        lstm1 = network
        # Output a logit for every word in the vocabulary
        network = DenseLayer(network,
                             n_units=vocab_size,  # one output per vocabulary word
                             W_init=rnn_init,
                             b_init=rnn_init,
                             act=tf.identity,
                             name='output')
    return network, lstm1
# Model used for training
network, lstm1 = inference(input_data, is_train=True, sequence_length=sequence_length, reuse=None)
# Model used for generating sentences (testing); sequence length is 1
network_test, lstm1_test = inference(input_data_test, is_train=False, sequence_length=1, reuse=True)
y_linear = network_test.outputs
y_soft = tf.nn.softmax(y_linear)
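
# Optional sanity check (our addition, not part of the book's code): because
# the test network is built with reuse=True, it shares every parameter with
# the training network instead of creating new ones, so the two counts below
# should match.
print('trainable variables:', len(tf.trainable_variables()),
      '| network.all_params:', len(network.all_params))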
def loss_fn(outputs, targets, batch_size, sequence_length):
    # TODO: confusing loss definition
    # sequence_loss_by_example doc:
    # https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
    # seq2seq tutorial:
    # https://www.tensorflow.org/versions/master/tutorials/seq2seq
    # TODO: update loss function to non-legacy? (see the sketch below)
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/seq2seq/python/ops/loss.py
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([outputs],
                                                              [tf.reshape(targets, [-1])],
                                                              [tf.ones([batch_size * sequence_length])])
    # Sum the per-token losses before dividing by batch_size, as the upstream
    # tutorial does. Using tf.reduce_mean here makes the cost roughly
    # batch_size * sequence_length times smaller -- the "initial cost too
    # small" issue this gist is about.
    cost = tf.reduce_sum(loss) / batch_size
    return cost
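
# A possible non-legacy alternative (untested sketch, our addition, answering
# the TODO above): tf.contrib.seq2seq.sequence_loss expects 3-D logits of
# shape [batch_size, sequence_length, vocab_size], so the 2-D outputs from
# return_seq_2d=True must be reshaped first. Relies on the global vocab_size.
def loss_fn_v2(outputs, targets, batch_size, sequence_length):
    logits = tf.reshape(outputs, [batch_size, sequence_length, vocab_size])
    weights = tf.ones([batch_size, sequence_length])
    # Average over the batch but keep per-timestep losses, then sum over
    # timesteps: this mirrors `tf.reduce_sum(loss) / batch_size` above.
    loss = tf.contrib.seq2seq.sequence_loss(
        logits, targets, weights,
        average_across_timesteps=False, average_across_batch=True)
    return tf.reduce_sum(loss)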
print('output:', network.outputs.shape)
print('targets:', targets.shape)
print('batch_size:', batch_size)
print('sequence_length:', sequence_length)
cost = loss_fn(network.outputs, targets, batch_size, sequence_length)
print('cost:', cost)
# Learning-rate variable
with tf.variable_scope('learning_rate'):
    lr = tf.Variable(0.0, trainable=False)

# Optimizer with gradient clipping
tvars = network.all_params
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), max_grad_norm)
optimizer = tf.train.GradientDescentOptimizer(lr)
train_op = optimizer.apply_gradients(zip(grads, tvars))
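
# Illustration (our addition): tf.clip_by_global_norm rescales all gradients
# jointly so their combined L2 norm is at most max_grad_norm. For example, a
# single gradient [3.0, 4.0] has global norm 5.0; clipping to 2.5 halves it:
#   clipped, gn = tf.clip_by_global_norm([tf.constant([3.0, 4.0])], 2.5)
#   # sess.run(clipped) -> [array([1.5, 2.0], dtype=float32)], gn == 5.0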
# Initialize parameters
sess = tf.InteractiveSession()
tl.layers.initialize_global_variables(sess)
network.print_layers()
network.print_params()

# Prepare the seed sentence
seed = 'it is a'
seed = nltk.tokenize.word_tokenize(seed)
print('seed : %s' % seed)
# Start training
start_time = time.time()
for i in range(max_max_epoch):
    # Decay the learning rate
    new_lr_decay = lr_decay ** max(i - max_epoch, 0.0)
    sess.run(tf.assign(lr, learning_rate * new_lr_decay))
    # Print per-epoch information
    print('Epoch: %d/%d learning rate: %.8f' % (i + 1, max_max_epoch, sess.run(lr)))
    # https://github.com/tensorlayer/chinese-book/issues/20
    epoch_size = ((len(train_data) // batch_size) - 1) // sequence_length
    start_time_epoch = time.time()
    costs = 0.0
    iters = 0
    # At the start of each epoch, reset the cell state and hidden state to zero
    state1 = tl.layers.initialize_rnn_state(lstm1.initial_state)
    for step, (x, y) in enumerate(tl.iterate.ptb_iterator(train_data,
                                                          batch_size,
                                                          sequence_length)):
        # Debug prints
        print('x:', x.shape, x[:10])
        print('y:', y.shape, y[:10])
        print('s0:', state1[0].shape, state1[0][:10])
        print('s1:', state1[1].shape, state1[1][:10])
        # After each update, feed the final cell/hidden state back in as the
        # initial state of the next step
        _cost, state1, _ = sess.run([cost, lstm1.final_state, train_op],
                                    feed_dict={input_data: x,
                                               targets: y,
                                               lstm1.initial_state: state1})
        print('step: {0}, cost: {1}, iters: {2}'.format(step, _cost, iters))
        costs += _cost
        iters += sequence_length
        # Periodically print the loss
        if step % (epoch_size // 10) == 1:
            print('%.3f perplexity: %.3f speed: %.0f wps' %
                  (step * 1.0 / epoch_size, np.exp(costs / iters),
                   iters * batch_size / (time.time() - start_time_epoch)))
            input("Press Enter to continue...")  # debug pause
    # Print the perplexity for the whole epoch
    train_perplexity = np.exp(costs / iters)
    print('Epoch: %d/%d Train Perplexity: %.3f' %
          (i + 1, max_max_epoch, train_perplexity))
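
    # Note (our addition): each _cost sums the per-token loss over
    # sequence_length steps (averaged over the batch), so costs / iters is the
    # mean cross-entropy per token and exp(costs / iters) is the perplexity;
    # right after random initialization it should be close to vocab_size.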
    # Generate sentences with different top-k sampling settings
    for top_k in top_k_list:
        # Reset the cell state and hidden state to zero
        state1 = tl.layers.initialize_rnn_state(lstm1_test.initial_state)
        # Convert the seed sentence to word IDs
        outs_id = [vocab.word_to_id(w) for w in seed]
        # print('outs_id:', outs_id)
        # Feed the seed sentence into the LSTM to warm up the cell and hidden
        # states used for generation
        for ids in outs_id[:-1]:
            a_id = np.asarray(ids).reshape(1, 1)
            # Fetch final_state directly (not wrapped in a list), so state1
            # keeps the LSTMStateTuple structure that feed_dict expects
            state1 = sess.run(lstm1_test.final_state,
                              feed_dict={input_data_test: a_id, lstm1_test.initial_state: state1})
        # Feed the last word of the seed sentence and start generating
        a_id = outs_id[-1]
        for _ in range(print_length):
            a_id = np.asarray(a_id).reshape(1, 1)
            out, state1 = sess.run([y_soft, lstm1_test.final_state],
                                   feed_dict={input_data_test: a_id,
                                              lstm1_test.initial_state: state1})
            # print('out:', out.shape)
            # print('state1:', state1[0].shape, state1[1].shape)
            # Top-k sampling: keep the top_k most probable words, renormalize,
            # and sample one (top_k=1 is greedy argmax decoding)
            a_id = tl.nlp.sample_top(out[0], top_k=top_k)
            # print('a_id:', a_id)
            outs_id.append(a_id)
        # Print the generated sentence as a string
        sentence = [vocab.id_to_word(w) for w in outs_id]
        sentence = ' '.join(sentence)
        print('top ', top_k, ':', sentence)
print("Save model") | |
tl.files.save_npz(network_test.all_params, name=model_file_name) | |
# TODO: ch 6.4.4 |
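
# Possible follow-up (untested sketch, our addition): reload the saved weights
# into the generation network in a fresh session before sampling more text,
# instead of retraining from scratch.
# tl.files.load_and_assign_npz(sess=sess, name=model_file_name, network=network_test)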