Using AdagradOptimizer raises an error
""" | |
ch 5.3 | |
https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_word2vec_basic.py | |
""" | |
import collections, math, os, random, time | |
import numpy as np | |
from six.moves import xrange | |
import tensorflow as tf | |
import tensorlayer as tl | |
# 1. Download the data
words = tl.files.load_matt_mahoney_text8_dataset()
data_size = len(words)
print('data size: {0}'.format(data_size))
print(words[0:10])
# 2. Training configuration (parameters)
# Words not in the vocabulary are represented as _UNK
_UNK = "_UNK"
# Vocabulary size, embedding vector length, batch size
vocabulary_size = 50000
batch_size = 128
embedding_size = 128
# Skip-gram parameters
skip_window = 1
num_skips = 2
# Number of words drawn for negative sampling
num_sampled = 64
# Learning rate and total number of epochs
learning_rate = 1.0
n_epoch = 20
# File name for the saved model
model_file_name = './saved/model_word2vec_50k_128'
# Number of training steps per epoch, over all epochs
num_steps = int((data_size / batch_size) * n_epoch)
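# Added note (not in the original gist): text8 contains roughly 17 million
# tokens, so with batch_size=128 and n_epoch=20 this comes to about 2.66
# million training steps.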
# Build the vocabulary; see ch 9 (advanced techniques)
data, count, dictionary, reverse_dictionary = tl.nlp.build_words_dataset(words, vocabulary_size, True, _UNK)
print('5 most common words (+UNK)', count[:5])
print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])
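# Added sketch (not in the original gist): round-trip a few words through the
# dictionaries built above; `dictionary` maps word -> id, `reverse_dictionary`
# maps id -> word, and out-of-vocabulary words fall back to _UNK's id.
sample_ids = [dictionary.get(w, dictionary[_UNK]) for w in words[0:10]]
print('round trip:', [reverse_dictionary[i] for i in sample_ids])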
# 3. Generate batch data
print('display some sample data 1')
batch, labels, data_index = tl.nlp.generate_skip_gram_batch(data=data, batch_size=8, num_skips=2,
                                                            skip_window=1, data_index=0)
for i in range(8):
    print(batch[i], reverse_dictionary[batch[i]], '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
print('display some sample data 2')
batch, labels, data_index = tl.nlp.generate_skip_gram_batch(data=data, batch_size=8, num_skips=4,
                                                            skip_window=2, data_index=0)
for i in range(8):
    print(batch[i], reverse_dictionary[batch[i]], '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
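# Added note: with skip_window=1 and num_skips=2, each centre word produces two
# (input, label) pairs, one per neighbour; e.g. the text8 opening
# "anarchism originated as ..." yields (originated -> anarchism) and (originated -> as).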
# 4. Build the model
# Sample 16 validation word ids from the 100 most frequent words
valid_size = 16
valid_window = 100
valid_examples = np.random.choice(valid_window, valid_size, replace=False)
print_freq = 2000
train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
emb_net = tl.layers.Word2vecEmbeddingInputlayer(inputs=train_inputs,
                                                train_labels=train_labels,
                                                vocabulary_size=vocabulary_size,
                                                embedding_size=embedding_size,
                                                num_sampled=num_sampled,
                                                # nce_loss_args={},
                                                # E_init=tf.random_uniform_initializer(minval=-1.0, maxval=1.0),
                                                # E_init_args={},
                                                # nce_W_init=tf.truncated_normal_initializer(stddev=float(1.0 / np.sqrt(embedding_size))),
                                                # nce_W_init_args={},
                                                # nce_b_init=tf.constant_initializer(value=0.0),
                                                # nce_b_init_args={},
                                                name='word2vec_layer')
cost = emb_net.nce_cost
train_params = emb_net.all_params
learning_rate = 1.0
train_op = tf.train.AdagradOptimizer(learning_rate).minimize(cost, var_list=train_params)
# learning_rate = 0.0001
# train_op = tf.train.AdamOptimizer(learning_rate).minimize(cost, var_list=train_params)
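# Added note (not in the original gist): the gist description reports that
# AdagradOptimizer raises an error here. As an assumed workaround, plain SGD
# (the optimizer TensorFlow's own word2vec_basic example uses) can be swapped in:
# train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost, var_list=train_params)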
normalized_embeddings = emb_net.normalized_embeddings
valid_embed = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset)
similarity = tf.matmul(valid_embed, normalized_embeddings, transpose_b=True)
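# Added note: the rows of `normalized_embeddings` are unit vectors, so this
# matmul gives the cosine similarity between each validation word and every
# word in the vocabulary.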
# 5. Train the model
sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True))
# Initialize parameters
tl.layers.initialize_global_variables(sess)
# Print model information
emb_net.print_params()
emb_net.print_layers()
# Save the vocabulary
vocab_file_name = './data/vocab_text8.txt'
tl.nlp.save_vocab(count, name=vocab_file_name)
# Start training
average_loss = 0
step = 0
start_time = time.time()
# data_index = 0
# print('data index', data_index)
while step < num_steps:
    # Run one training step
    start_time_step = time.time()
    # print('data index before:', data_index)
    batch_inputs, batch_labels, data_index = tl.nlp.generate_skip_gram_batch(data=data,
                                                                             batch_size=batch_size,
                                                                             num_skips=num_skips,
                                                                             skip_window=skip_window,
                                                                             data_index=data_index)
    # data_index is both an input and an output: it tracks the current position
    # in the corpus across calls
    # print('data index after:', data_index)
    feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}
    _, loss_val = sess.run([train_op, cost], feed_dict=feed_dict)
    average_loss += loss_val
    # print('loss %f at step %d' % (loss_val, step))
    # Every 2000 steps, print the average loss
    if step % print_freq == 0:
        # print('step %d/%d. total loss: %.2f. took: %.3fs / %.0fs' %
        #       (step, num_steps, average_loss, time.time() - start_time_step, time.time() - start_time))
        if step > 0:
            average_loss = average_loss / print_freq
        print('step %d/%d. average loss: %.2f. took: %.3fs / %.0fs' %
              (step, num_steps, average_loss, time.time() - start_time_step, time.time() - start_time))
        average_loss = 0
    # Every 10000 steps, print the 8 words closest to each validation word
    if step % (print_freq * 5) == 0:
        sim = similarity.eval()
        for i in xrange(valid_size):
            valid_word = reverse_dictionary[valid_examples[i]]
            top_k = 8
            # skip index 0 of the sorted list, which is the validation word itself
            nearest = (-sim[i, :]).argsort()[1:top_k + 1]
            log_str = 'Nearest to %s:' % valid_word
            for k in xrange(top_k):
                close_word = reverse_dictionary[nearest[k]]
                log_str = '%s %s,' % (log_str, close_word)
            print(log_str)
    # Every 40000 steps, save the model and the dictionaries
    if step % (print_freq * 20) == 0:  # and (step != 0):
        print('Save model, data and dictionaries' + '!' * 10)
        tl.files.save_npz(emb_net.all_params, name=model_file_name + '.npz')
        tl.files.save_any_to_npy(save_dict={'data': data,
                                            'count': count,
                                            'dictionary': dictionary,
                                            'reverse_dictionary': reverse_dictionary},
                                 name=model_file_name + '.npy')
    step += 1
sess.close()
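# Added sketch (not in the original gist): reload the weights and dictionaries
# saved above in a fresh session, assuming TensorLayer 1.x's file API
# (tl.files.load_and_assign_npz / tl.files.load_npy_to_any), mirroring the
# save calls in the training loop.
sess = tf.InteractiveSession()
tl.files.load_and_assign_npz(sess=sess, name=model_file_name + '.npz', network=emb_net)
saved = tl.files.load_npy_to_any(name=model_file_name + '.npy')
data, count = saved['data'], saved['count']
dictionary, reverse_dictionary = saved['dictionary'], saved['reverse_dictionary']
sess.close()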