Exercise: generating text with an RNN, based on Shakespeare's plays
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 使用RNN生成文本-shakespeare"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"这是学习tensorflow官网资料:https://tensorflow.google.cn/tutorials/sequences/text_generation 的笔记,通过RNN喂入莎士比亚的戏剧文本,尝试让电脑自己写出莎士比亚风格的文章。运行这个简单的例子需要强大的GPU,在我的笔记本上(MX 150只有2G显存)无法运行,如果只使用CPU需要较长的时间,需要有心理准备。可以在google colab上面运行测试,速度10x以上的提升。\n",
"\n",
"这是一个many to many的示例。实际上,RNN可能有下图所示的几种模式(参见:http://karpathy.github.io/2015/05/21/rnn-effectiveness/):\n",
"![diags](http://softlab.sdut.edu.cn/blog/subaochen/wp-content/uploads/sites/4/2019/05/diags.jpeg)\n",
"\n",
"@TODO\n",
"\n",
"* 加入LSTM重新测试"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 启用eager execution"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"tensorflow 1.x默认没有启用eager execution,因此需要明确执行`enable_eager_execution()`打开这个开关。只有1.11以上版本才支持eager execution。"
]
},
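{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch (assuming a TensorFlow 1.x install) of verifying the installed version and confirming that eager execution has actually taken effect; `tf.executing_eagerly()` returns `True` once the switch is on:\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"print(tf.__version__)          # should be 1.11 or later\n",
"tf.enable_eager_execution()\n",
"print(tf.executing_eagerly())  # True once eager mode is active\n",
"```"
]
},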
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"tf.enable_eager_execution()\n",
"\n",
"import numpy as np\n",
"import os\n",
"import time"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 下载和观察数据"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"只要使用`tf.keras`中的方法下载的数据,默认都存放到了\\$HOME/.keras/datasets目录下。下面是我的.keras/datasets目录的内容:\n",
"```shell\n",
"~/.keras/datasets$ ls\n",
"auto-mpg.data cifar-10-batches-py.tar.gz iris_test.csv\n",
"cifar-100-python fashion-mnist iris_training.csv\n",
"cifar-100-python.tar.gz imdb.npz mnist.npz\n",
"cifar-10-batches-py imdb_word_index.json shakespeare.txt\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/subaochen/.keras/datasets/shakespeare.txt\n"
]
}
],
"source": [
"path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')\n",
"print(path_to_file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"这里不使用`tf.data.Dataset.TextlineDataset`?也许是因为需要进一步对文本进行分拆处理的缘故?\n",
"\n",
"也没有使用`pandas`提供的方法?\n",
"\n",
"有机会尝试使用`Dataset`或`pandas`改写这个部分。"
]
},
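{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged sketch of what loading the same file with `tf.data.TextLineDataset` might look like; it yields the text line by line, which is one reason the tutorial instead reads the whole file into a single string. Purely illustrative, not part of the tutorial:\n",
"```python\n",
"lines_ds = tf.data.TextLineDataset(path_to_file)\n",
"for line in lines_ds.take(3):\n",
"    print(line.numpy().decode('utf-8'))  # first three lines of the play\n",
"```"
]
},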
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Length of text: 1115394 characters\n"
]
}
],
"source": [
"text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n",
"# length of text is the number of characters in it\n",
"print ('Length of text: {} characters'.format(len(text)))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First Citizen:\n",
"Before we proceed any further, hear me speak.\n",
"\n",
"All:\n",
"Speak, speak.\n",
"\n",
"First Citizen:\n",
"You are all resolved rather to die than to famish?\n",
"\n",
"All:\n",
"Resolved. resolved.\n",
"\n",
"First Citizen:\n",
"First, you know Caius Marcius is chief enemy to the people.\n",
"\n"
]
}
],
"source": [
"# Take a look at the first 1000 characters in text\n",
"print(text[:250])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 文本向量化\n",
"文本向量化才能喂入RNN学习,需要三个步骤:\n",
"1. 构造文本字典vocab\n",
"1. 建立字典索引char2idx,将字典的每一个字符映射为数字\n",
"1. 使用char2idx将文本数字化(向量化)\n",
"\n",
"<div class=\"alert alert-block alert-info\">\n",
"<b>Tip:</b> 使用tf.data.Dataset.map方法可以更方便的处理文本向量化?不过就无法观察向量化文本的过程了。\n",
"</div>"
]
},
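{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the three steps concrete before applying them to the full text, here is a small self-contained sketch on a toy string (the names are illustrative only, not used by the rest of the notebook):\n",
"```python\n",
"import numpy as np\n",
"\n",
"toy_text = 'hello world'\n",
"toy_vocab = sorted(set(toy_text))                           # step 1: the vocabulary\n",
"toy_char2idx = {u: i for i, u in enumerate(toy_vocab)}      # step 2: char -> index\n",
"toy_idx2char = np.array(toy_vocab)                          # reverse lookup\n",
"toy_as_int = np.array([toy_char2idx[c] for c in toy_text])  # step 3: vectorize\n",
"\n",
"print(toy_as_int)                         # the integer id of each character\n",
"print(''.join(toy_idx2char[toy_as_int]))  # round-trips back to 'hello world'\n",
"```"
]
},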
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"65 unique characters\n"
]
}
],
"source": [
"# The unique characters in the file\n",
"vocab = sorted(set(text)) # sorted保证了集合的顺序\n",
"print ('{} unique characters'.format(len(vocab)))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([18, 47, 56, 57, 58, 1, 15, 47, 58, 47])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Creating a mapping from unique characters to indices\n",
"char2idx = {u:i for i, u in enumerate(vocab)}\n",
"# vocab是有序集合,转化为数组后其下标自然就是序号,但是不如char2idx结构直观\n",
"# 如果模仿char2idx也很简单:idx2char = {i:u for i,u in enumerate(vocab)}\n",
"idx2char = np.array(vocab)\n",
"\n",
"text_as_int = np.array([char2idx[c] for c in text])\n",
"text_as_int[:10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"各种方式观察一下向量化后的文本。这里没有使用matplotlib,没有太大意义。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"'\\n' ---> 0\n",
"' ' ---> 1\n",
"'!' ---> 2\n",
"'$' ---> 3\n",
"'&' ---> 4\n",
"\"'\" ---> 5\n",
"',' ---> 6\n",
"'-' ---> 7\n",
"'.' ---> 8\n",
"'3' ---> 9\n",
"':' ---> 10\n",
"';' ---> 11\n",
"'?' ---> 12\n",
"'A' ---> 13\n",
"'B' ---> 14\n",
"'C' ---> 15\n",
"'D' ---> 16\n",
"'E' ---> 17\n",
"'F' ---> 18\n",
"'G' ---> 19\n"
]
}
],
"source": [
"# 取出char2idx前20个元素的奇怪写法。zip方法返回成对的元组,range(20)提供了序号。\n",
"for char,_ in zip(char2idx, range(20)):\n",
" print('{:6s} ---> {:4d}'.format(repr(char), char2idx[char]))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First Citizen ---- characters mapped to int ---- > [18 47 56 57 58 1 15 47 58 47 64 43 52]\n"
]
}
],
"source": [
"# Show how the first 13 characters from the text are mapped to integers\n",
"print ('{} ---- characters mapped to int ---- > {}'.format(text[:13], text_as_int[:13]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 构造训练数据(样本数据)\n",
"把数据喂给RNN之前,需要构造/划分好训练数据和验证数据。在这里,无需验证和测试数据,因此只需要划分好训练数据即可。下面的代码中,每次喂给RNN的训练数据是seq_length个字符。\n",
"\n",
"但是,实际内部处理时,RNN还是要一个一个字符消化,即RNN的输入维度是len(vocab),参见下图(出处:http://karpathy.github.io/2015/05/21/rnn-effectiveness/ ):\n",
"![charseq](http://softlab.sdut.edu.cn/blog/subaochen/wp-content/uploads/sites/4/2019/05/charseq.jpeg)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<DatasetV1Adapter shapes: (), types: tf.int64>\n",
"WARNING:tensorflow:From /home/subaochen/anaconda3/envs/tf1-cpu/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py:532: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Colocations handled automatically by placer.\n",
"F\n",
"i\n",
"r\n",
"s\n",
"t\n"
]
}
],
"source": [
"# The maximum length sentence we want for a single input in characters\n",
"# 每次喂入RNN的字符数。注意和后面的BATCH_SIZE的区别以及匹配\n",
"# 为了更好的观察数据,初始的时候seq_length可以设置为10,但是执行时要恢复为100或者\n",
"# 更大的数。当然,也可以测试不同的seq_length下的结果\n",
"seq_length = 100\n",
"examples_per_epoch = len(text)//seq_length\n",
"\n",
"char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)\n",
"print(char_dataset)\n",
"for i in char_dataset.take(5):\n",
" print(idx2char[i.numpy()])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n",
"'are all resolved rather to die than to famish?\\n\\nAll:\\nResolved. resolved.\\n\\nFirst Citizen:\\nFirst, you k'\n",
"\"now Caius Marcius is chief enemy to the people.\\n\\nAll:\\nWe know't, we know't.\\n\\nFirst Citizen:\\nLet us ki\"\n",
"\"ll him, and we'll have corn at our own price.\\nIs't a verdict?\\n\\nAll:\\nNo more talking on't; let it be d\"\n",
"'one: away, away!\\n\\nSecond Citizen:\\nOne word, good citizens.\\n\\nFirst Citizen:\\nWe are accounted poor citi'\n"
]
}
],
"source": [
"# sequences也是一个Dataset对象,但是经过了batch操作进行数据分组,每一个batch的数据\n",
"# 长度是seq_length+1(101).sequences用来创建输入文本和目标文本(长度为seq_length)\n",
"# 注意:这里的batch操作和训练模型时的BATCH_SIZE没有关系,这里的batch操作纯粹\n",
"# 为了按照指定的尺寸切分数据\n",
"sequences = char_dataset.batch(seq_length+1, drop_remainder=True)\n",
"# repl函数的意义相当于Java的toString方法\n",
"# 注意,这里的item已经是tensor了,通过numpy()方法转化为numpy矩阵(向量)\n",
"# numpy数组(List)的强大之处:允许接受一个list作为索引参数,因此idx2char[item.numpy()]即为根据item\n",
"# 的数字为索引获得字符构造出一个字符串\n",
"for item in sequences.take(5):\n",
" print(repr(''.join(idx2char[item.numpy()])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 创建输入文本和目标文本\n",
"输入文本即参数,目标文本相当于“标签”,预测文本将和目标文本比较以计算误差。\n",
"目标文本(target)和输入(input)文本的关系:目标文本和输入文本正好错开一个字符,即目标文本的第一个字符恰好是输入文本的第二个字符,以此类推。\n",
"\n",
"注意下面的代码中,dataset的shape变化。"
]
},
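{
"cell_type": "markdown",
"metadata": {},
"source": [
"A toy illustration of the one-character shift between input and target (purely illustrative, not part of the pipeline):\n",
"```python\n",
"chunk = 'Hello'\n",
"input_text, target_text = chunk[:-1], chunk[1:]\n",
"print(input_text)   # 'Hell'\n",
"print(target_text)  # 'ello'\n",
"```"
]
},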
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<DatasetV1Adapter shapes: ((100,), (100,)), types: (tf.int64, tf.int64)>\n"
]
}
],
"source": [
"def split_input_target(chunk):\n",
" input_text = chunk[:-1] # 不包括-1即最后一个字符,总共100个字符。这就是为什么chunk的长度是101的原因\n",
" target_text = chunk[1:]\n",
" return input_text, target_text\n",
" \n",
"# 注意到,sequences已经是被batch过的了,因此这里的map是针对每个batch的数据来进行的\n",
"# 此时dataset的结果已经比较复杂了,所谓的nested structure of tensors\n",
"# print(dateset)的结果显示其shape为:shapes: ((10,), (10,))\n",
"# 即,dataset是一个tuple,tuple的每个数据又包含两个tuple,每个tuple是seq_length\n",
"# 长度的向量。其中第一个tuple是input_example,第二个tuple是target_example\n",
"dataset = sequences.map(split_input_target)\n",
"print(dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"input_example就是输入样本,target_example就是目标样本\n",
"可以看出,这里的输入样本和目标样本的尺寸都是seq_length,整个文本被batch_size\n",
"分割成了len(text_as_int)/seq_length组输入样本和目标样本\n",
"\n",
"训练的时候是成对喂入输入样本和目标样本的:但是,其实内部还是一个字符一个字符来计算的,即先取输入样本的第一个字符作为x和目标样本的第一个字符作为y,然后依次处理完输入样本和目标样本的每一个字符,这个batch计算完毕。"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input data: 'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou'\n",
"Target data: 'irst Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n"
]
}
],
"source": [
"# 将take的参数设为2能看的更清楚\n",
"for input_example, target_example in dataset.take(1):\n",
" print ('Input data: ', repr(''.join(idx2char[input_example.numpy()])))\n",
" print ('Target data:', repr(''.join(idx2char[target_example.numpy()])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"在训练之前,先简单模拟一下预测First这个单词的过程:比如第一步(step 0),获得输入是19(F),预测值应该是47(i),以此类推。当然,这不是RNN。"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 0\n",
" input: 18 ('F')\n",
" expected output: 47 ('i')\n",
"Step 1\n",
" input: 47 ('i')\n",
" expected output: 56 ('r')\n",
"Step 2\n",
" input: 56 ('r')\n",
" expected output: 57 ('s')\n",
"Step 3\n",
" input: 57 ('s')\n",
" expected output: 58 ('t')\n",
"Step 4\n",
" input: 58 ('t')\n",
" expected output: 1 (' ')\n"
]
}
],
"source": [
"for i, (input_idx, target_idx) in enumerate(zip(input_example[:5], target_example[:5])):\n",
" print(\"Step {:4d}\".format(i))\n",
" print(\" input: {} ({:s})\".format(input_idx, repr(idx2char[input_idx])))\n",
" print(\" expected output: {} ({:s})\".format(target_idx, repr(idx2char[target_idx])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 使用批次重新构造训练数据\n",
"\n",
"到目前为止,使用了如下的变量来表示文本的不同形态:\n",
"* text: 原始的文本\n",
"* text_as_int:向量化(数字化)的字符串\n",
"* sequences:按照seq_length+1切分的Dataset\n",
"* dataset:将每一个seqences划分为input_text和target_text的Dataset,此时的dataset其实比sequences大了一倍\n",
"\n",
"到这个阶段,我们还需要将dataset中的(input_text,target_text)对进行shuffle处理。注意,这里的shuffle是以seq_length长度的input_text/target_text对为单位的,不是字符级别的shuffle。想一下dataset的数据结构。\n",
"\n",
"另外,还需要进一步对数据进行batch处理以便迭代训练。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"有点奇怪的是,fit方法为什么不能通过已经定义的batch_size自动确定步长?为什么一定要通过一个steps_per_epoch参数呢?steps_per_epoch也是通过batch_size计算出来的啊,按说应该都能够达到目的的。查阅了一下2.0.0-alpha0的文档,**这个限制已经取消了**,参见:https://www.tensorflow.org/alpha/tutorials/text/text_generation"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<DatasetV1Adapter shapes: ((32, 100), (32, 100)), types: (tf.int64, tf.int64)>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Batch size\n",
"# 这里的BATCH_SIZE的单位不是字符,因为此时的dataset是按照\n",
"# ((seq_length,),(seq_length))组织的\n",
"# 这里的32意味着,经过32次迭代,就需要遍历整个dataset,因此每次迭代需要喂入\n",
"# 的数据尺寸如steps_per_epoch所示。\n",
"BATCH_SIZE = 32\n",
"\n",
"# steps_per_epoch说明每次喂入RNN的(input_example,target_example)的个数\n",
"# 使用model.fit时,如果传入的数据集是Dataset对象,必须显式声明steps_per_epoch参数\n",
"# 道理很简单,否则tensorflow不知道以多大的步长循环迭代给定的Dataset。因为传入fit函数\n",
"# 的Dataset只是经过了seq_length分组的input_text和target_text,并没有指定训练时\n",
"# 使用多大的步长来迭代整个Dataset。\n",
"# steps_per_epoch = len(text)//seq_length//BATCH_SIZE\n",
"steps_per_epoch = examples_per_epoch//BATCH_SIZE\n",
"\n",
"# Buffer size to shuffle the dataset\n",
"# (TF data is designed to work with possibly infinite sequences,\n",
"# so it doesn't attempt to shuffle the entire sequence in memory. Instead,\n",
"# it maintains a buffer in which it shuffles elements).\n",
"BUFFER_SIZE = 10000\n",
"\n",
"# \n",
"dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)\n",
"dataset"
]
},
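{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, here are the sizes worked out from the values above (a rough sketch: the 348 matches the `348/348` shown in the training log below, while the number of 101-character chunks actually produced by `batch(seq_length+1, drop_remainder=True)` is slightly smaller, because `examples_per_epoch` divides by `seq_length` rather than `seq_length + 1`):\n",
"```python\n",
"text_len = 1115394                      # characters in shakespeare.txt\n",
"seq_length = 100\n",
"BATCH_SIZE = 32\n",
"\n",
"examples_per_epoch = text_len // seq_length          # 11153\n",
"steps_per_epoch = examples_per_epoch // BATCH_SIZE   # 348\n",
"chunks = text_len // (seq_length + 1)                # 11043 (input, target) pairs\n",
"full_batches = chunks // BATCH_SIZE                  # 345 full batches in the Dataset\n",
"```"
]
},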
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 创建模型\n",
"\n",
"模型分为三层:\n",
"1. 嵌入层(layers.Embedding)。关于嵌入的概念可参考:https://tensorflow.google.cn/guide/embedding 。简单的说,嵌入层的作用是将输入(本例是输入字符的索引)映射为一个高维度向量(dense vector),其好处是可以借助于向量的方法,比如欧氏距离或者角度来度量两个向量的相似性。对于文本而言,就是两个词的相似度。\n",
"2. GRU层(Gated Recurrent Unit)\n",
"3. 全链接层"
]
},
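{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny hedged illustration of what the embedding layer does: it maps each integer index to a row of a learned dense matrix of shape `(vocab_size, embedding_dim)`. The small dimensions here are for illustration only:\n",
"```python\n",
"emb = tf.keras.layers.Embedding(input_dim=65, output_dim=4)\n",
"out = emb(tf.constant([[18, 47, 56]]))  # indices for 'F', 'i', 'r'\n",
"print(out.shape)                        # (1, 3, 4): one dense vector per character\n",
"```"
]
},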
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 设置模型参数,实例化模型\n",
"为了能够在笔记本电脑上运行,特意调小了embedding_dim和rnn_units两个参数"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Length of the vocabulary in chars\n",
"# 这是输入层和输出层的维度。\n",
"# 每一个字符都需要进行one-hot编码,因此每一个输入都是vocab_size维度的向量\n",
"# 同样的,每一个预测的输出也是vocab_size维度的向量\n",
"vocab_size = len(vocab)\n",
"\n",
"# The embedding dimension \n",
"#embedding_dim = 256\n",
"embedding_dim = 256\n",
"\n",
"# Number of RNN units\n",
"#rnn_units = 1024\n",
"rnn_units = 1024"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"if tf.test.is_gpu_available():\n",
" rnn = tf.keras.layers.CuDNNGRU\n",
"else:\n",
" import functools\n",
" rnn = functools.partial(\n",
" tf.keras.layers.GRU, recurrent_activation='sigmoid')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def build_model(vocab_size, embedding_dim, rnn_units, batch_size):\n",
" model = tf.keras.Sequential([\n",
" tf.keras.layers.Embedding(vocab_size, embedding_dim, \n",
" batch_input_shape=[batch_size, None]),\n",
" # 替换rnn为LSTM\n",
" tf.keras.layers.LSTM(rnn_units,\n",
" return_sequences=True, \n",
" recurrent_initializer='glorot_uniform',\n",
" stateful=True),\n",
" tf.keras.layers.Dense(vocab_size) # 这里不需要激活函数?softmax?\n",
" ])\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"embedding (Embedding) (32, None, 256) 16640 \n",
"_________________________________________________________________\n",
"lstm (LSTM) (32, None, 1024) 5246976 \n",
"_________________________________________________________________\n",
"dense (Dense) (32, None, 65) 66625 \n",
"=================================================================\n",
"Total params: 5,330,241\n",
"Trainable params: 5,330,241\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = build_model(\n",
" vocab_size = len(vocab), \n",
" embedding_dim=embedding_dim, \n",
" rnn_units=rnn_units, \n",
" batch_size=BATCH_SIZE)\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 先测试一下模型\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(32, 100, 65) # (batch_size, sequence_length, vocab_size)\n"
]
}
],
"source": [
"# input_example_batch是dataset的一个batch,这里是32个seq_length的input_text\n",
"# 由于喂入的数据的shape是(32,seq_length),输出example_batch_prediction的\n",
"# shape自然就是(32,seq_length,65)\n",
"for input_example_batch, target_example_batch in dataset.take(1): \n",
" example_batch_predictions = model(input_example_batch)\n",
" print(example_batch_predictions.shape, \"# (batch_size, sequence_length, vocab_size)\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<font color=\"red\">这是为什么使用random.categorical抽取数据?</font>"
]
},
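{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged sketch of the difference: `tf.random.categorical` draws a sample from the distribution defined by the logits, whereas `tf.argmax` would always pick the single most likely character, which tends to make the generated text repetitive:\n",
"```python\n",
"logits = example_batch_predictions[0]                   # shape (seq_length, vocab_size)\n",
"sampled = tf.random.categorical(logits, num_samples=1)  # stochastic choice per position\n",
"greedy = tf.argmax(logits, axis=-1)                     # deterministic choice per position\n",
"```"
]
},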
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([47, 19, 13, 46, 12, 19, 15, 13, 22, 18, 31, 17, 12, 6, 38, 50, 17,\n",
" 28, 5, 30, 60, 29, 50, 21, 2, 27, 36, 16, 13, 28, 1, 44, 25, 43,\n",
" 62, 59, 15, 29, 51, 25, 54, 10, 11, 58, 44, 48, 63, 48, 55, 26, 4,\n",
" 46, 46, 60, 56, 40, 29, 27, 5, 48, 63, 59, 8, 16, 58, 5, 46, 19,\n",
" 63, 38, 32, 53, 23, 57, 35, 41, 53, 48, 12, 34, 49, 43, 38, 29, 46,\n",
" 42, 8, 1, 21, 39, 27, 25, 41, 18, 41, 6, 4, 51, 54, 44])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 检查第0批数据?\n",
"sampled_indices = tf.random.categorical(example_batch_predictions[0], \n",
" num_samples=1)\n",
"sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()\n",
"sampled_indices"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input: \n",
" \"Provided that, when he's removed, your highness\\nWill take again your queen as yours at first,\\nEven f\"\n",
"Next Char Predictions: \n",
" \"iGAh?GCAJFSE?,ZlEP'RvQlI!OXDAP fMexuCQmMp:;tfjyjqN&hhvrbQO'jyu.Dt'hGyZToKsWcoj?VkeZQhd. IaOMcFc,&mpf\"\n"
]
}
],
"source": [
"print(\"Input: \\n\", repr(\"\".join(idx2char[input_example_batch[0].numpy()])))\n",
"print(\"Next Char Predictions: \\n\", repr(\"\".join(idx2char[sampled_indices])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 定义优化器和损失函数"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction shape: (32, 100, 65) # (batch_size, sequence_length, vocab_size)\n",
"scalar_loss: 4.1746464\n"
]
}
],
"source": [
"def loss(labels, logits):\n",
" return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)\n",
"\n",
"example_batch_loss = loss(target_example_batch, example_batch_predictions)\n",
"print(\"Prediction shape: \", example_batch_predictions.shape, \" # (batch_size, sequence_length, vocab_size)\") \n",
"print(\"scalar_loss: \", example_batch_loss.numpy().mean())"
]
},
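{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the final `Dense` layer has no activation, its outputs are raw logits, which is why `from_logits=True` is needed. A hedged sanity check; the two losses below should be numerically close:\n",
"```python\n",
"probs = tf.nn.softmax(example_batch_predictions)\n",
"loss_from_logits = tf.keras.losses.sparse_categorical_crossentropy(\n",
"    target_example_batch, example_batch_predictions, from_logits=True)\n",
"loss_from_probs = tf.keras.losses.sparse_categorical_crossentropy(\n",
"    target_example_batch, probs, from_logits=False)\n",
"```"
]
},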
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# Directory where the checkpoints will be saved\n",
"checkpoint_dir = './training_checkpoints'\n",
"# Name of the checkpoint files\n",
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n",
"\n",
"checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(\n",
" filepath=checkpoint_prefix,\n",
" save_weights_only=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"恢复checkpoint\n",
"如何检测checkoutpoint是否存在?"
]
},
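{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tf.train.latest_checkpoint` returns `None` when the directory contains no checkpoint, so the existence test can be written as this simple sketch:\n",
"```python\n",
"ckpt = tf.train.latest_checkpoint(checkpoint_dir)\n",
"if ckpt is not None:\n",
"    print('found checkpoint:', ckpt)\n",
"else:\n",
"    print('no checkpoint yet, training from scratch')\n",
"```"
]
},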
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"ckpt = tf.train.latest_checkpoint(checkpoint_dir)\n",
"if ckpt != None:\n",
" print(\"load model from checkpoint\")\n",
" model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=BATCH_SIZE)\n",
" model.load_weights(ckpt)\n",
" model.build(tf.TensorShape([1, None]))\n",
" model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"model.compile(\n",
" optimizer = 'adam',\n",
" loss = loss)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"EPOCHS = 3"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/subaochen/anaconda3/envs/tf1-cpu/lib/python3.7/site-packages/tensorflow/python/ops/gradients_impl.py:110: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
" \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"347/348 [============================>.] - ETA: 4s - loss: 2.4359WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7f9988376828>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.\n",
"\n",
"Consider using a TensorFlow optimizer from `tf.train`.\n",
"WARNING:tensorflow:From /home/subaochen/anaconda3/envs/tf1-cpu/lib/python3.7/site-packages/tensorflow/python/keras/engine/network.py:1436: update_checkpoint_state (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.train.CheckpointManager to manage checkpoints rather than manually editing the Checkpoint proto.\n",
"348/348 [==============================] - 1700s 5s/step - loss: 2.4344\n",
"Epoch 2/3\n",
"347/348 [============================>.] - ETA: 4s - loss: 1.7632WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7f9988376828>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.\n",
"\n",
"Consider using a TensorFlow optimizer from `tf.train`.\n",
"348/348 [==============================] - 1555s 4s/step - loss: 1.7627\n",
"Epoch 3/3\n",
"347/348 [============================>.] - ETA: 4s - loss: 1.5562WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7f9988376828>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.\n",
"\n",
"Consider using a TensorFlow optimizer from `tf.train`.\n",
"348/348 [==============================] - 1552s 4s/step - loss: 1.5561\n"
]
}
],
"source": [
"history = model.fit(dataset.repeat(), epochs=EPOCHS, \n",
" steps_per_epoch=steps_per_epoch, \n",
" callbacks=[checkpoint_callback])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 绘制训练图表"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEWCAYAAACJ0YulAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGi5JREFUeJzt3X2YHWWZ5/HvTRIJkUAgieIkhKA4DiQSDC2CZJeIu8rL4LsrEMBh8Yq4rMZddckCvuHkWpUdxYAOZhUYlxZkFR1l1chIHAYZE5OQEEgmhgESe4lDJ44kEFE6ufePqi6a0C+n033O6XR/P9fVV1c99dSpuyuV/nXVU6dOZCaSJAEc0OwCJElDh6EgSaoYCpKkiqEgSaoYCpKkiqEgSaoYChrxImJURDwVEdMGs+8+1PGXEXHzYL+u1B+jm12A1F8R8VSX2XHAH4Dd5fz7M7O1P6+XmbuBgwe7r7Q/MhS038nM6pdyRDwGvC8z/66n/hExOjM7GlGbtL/z8pGGnfIyzLci4taI2AlcEBGnRMQvIuJ3EbE1IhZHxJiy/+iIyIiYXs7fUi7/UUTsjIh/jIij+9u3XH5mRPwqIp6MiOsi4ucR8Rc1/hxvi4iHyprvjohXdVl2RUQ8HhE7IuKfImJu2X5yRKwu2/8lIq4ZhF2qEcRQ0HD1duCbwKHAt4AOYAEwCTgVOAN4fy/rnw98HDgc2AJ8pr99I+IlwO3Ax8rtPgqcVEvxEXEscAvwQWAy8HfADyJiTETMKGufnZmHAGeW2wW4DrimbD8G+HYt25M6GQoaru7NzB9k5p7M/H1m/jIzl2dmR2Y+AiwBTutl/W9n5srMfBZoBU7Yh75/DqzJzL8tl30R2FZj/ecC38/Mu8t1PwscAryOIuDGAjPKS2OPlj8TwLPAKyNiYmbuzMzlNW5PAgwFDV+/7joTEX8WEf83In4TETuAqyn+eu/Jb7pM76L3weWe+v5J1zqyePpkWw21d667ucu6e8p1p2TmRuAjFD/DE+VlsiPKrhcDxwEbI2JFRJxV4/YkwFDQ8LX343+/CjwIHFNeWvkEEHWuYSswtXMmIgKYUuO6jwNHdVn3gPK1/h9AZt6SmacCRwOjgP9Rtm/MzHOBlwB/BXwnIsYO/EfRSGEoaKQYDzwJPF1er+9tPGGw3AnMjohzImI0xZjG5BrXvR14S0TMLQfEPwbsBJZHxLER8YaIOBD4ffm1GyAiLoyISeWZxZMU4bhncH8sDWeGgkaKjwDvpfjF+lWKwee6ysx/Ad4DfAHYDrwCuJ/ifRV9rfsQRb1/DbRTDIy/pRxfOBD4PMX4xG+Aw4CrylXPAjaUd139T+A9mfnHQfyxNMyFH7IjNUZEjKK4LPSuzPyHZtcjdcczBamOIuKMiDi0vNTzcYo7h1Y0uSypR4aCVF9zgEcoLvWcAbwtM/u8fCQ1i5ePJEkVzxQkSZX97oF4kyZNyunTpze7DEnar6xatWpbZvZ5S/R+FwrTp09n5cqVzS5DkvYrEbG5715ePpIkdWEoSJIqhoIkqbLfjSlIGnqeffZZ2traeOaZZ5pdyog3duxYpk6dypgxY/ZpfUNB0oC1tbUxfvx4pk+fTvEwWDVDZrJ9+3ba2to4+uij+16hGyPi8lFrK0yfDgccUHxv7dfHukvqyzPPPMPEiRMNhCaLCCZOnDigM7Zhf6bQ2grz58OuXcX85s3FPMC8ec2rSxpuDIShYaD/DsP+TOHKK58LhE67dhXtkqTnG/ahsGVL/9ol7X+2b9/OCSecwAknnMARRxzBlClTqvk//rG2j5O4+OKL2bhxY699vvzlL9M6SNef58yZw5o1awbltQbTsL98NG1accmou3ZJzdHaWpytb9lS/F9ctGhgl3MnTpxY/YL91Kc+xcEHH8xHP/rR5/XJTDKTAw7o/m/hm266qc/tXHbZZfte5H5i2J8pLFoE48Y9v23cuKJdUuN1jvNt3gyZz43z1eMGkIcffpiZM2dy6aWXMnv2bLZu3cr8+fNpaWlhxowZXH311VXfzr/cOzo6mDBhAgsXLmTWrFmccsopPPHEEwBcddVVXHvttVX/hQsXctJJJ/GqV72K++67D4Cnn36ad77zncyaNYvzzjuPlpaWPs8IbrnlFl796lczc+ZMrrjiCgA6Ojq48MILq/bFixcD8MUvfpHjjjuOWbNmccEFFwz6Phv2oTBvHixZAkcdBRHF9yVLHGSWmqXR43zr16/nkksu4f7772fKlCl89rOfZeXKlaxdu5a77rqL9evXv2CdJ598ktNOO421a9dyyimncOONN3b72pnJihUruOaaa6qAue666zjiiCNYu3YtCxcu5P777++1vra2Nq666iqWLVvG/fffz89//nPuvPNOVq1axbZt21i3bh0PPvggF110EQCf//znWbNmDWvXruX6668f4N55oWEfClAEwGOPwZ49xXcDQWqeRo/zveIVr+C1r31tNX/rrbcye/ZsZs+ezYYNG7oNhYMOOogzzzwTgBNPPJHHHnus29d+xzve8YI+9957L+eeey4As2bNYsaMGb3Wt3z5ck4//XQmTZrEmDFjOP/887nnnns45phj2LhxIwsWLGDp0qUceuihAMyYMYMLLriA1tbWfX6DWm/qFgoRcWRELIuIDRHxUEQs6KXvayNid0S8q171SBoaehrPq9c434tf/OJqetOmTXzpS1/i7rvv5oEHHuCMM87o9p7+F73oRdX0qFGj6Ojo6Pa1DzzwwBf06e8Hl/XUf+LEiTzwwAPMmTOHxYsX8/73vx+ApUuXcumll7JixQpaWlrYvXt3v7bXl3qeKXQAH8nMY4GTgcsi4ri9O5UfZv45YGkda5E0RDRznG/Hjh2MHz+eQw45hK1bt7J06eD/2pkzZw633347AOvWrev2TKSrk08+mWXLlrF9+3Y6Ojq47bbbOO2002hvbyczefe7382nP/1pVq9eze7du2lra+P000/nmmuuob29nV17X4sboLrdfZSZW4Gt5fTOiNgATAH23kMfBL4DvBZJw17n5dvBvPuoVrNnz+a4445j5syZvPzlL+fUU08d9G188IMf5KKLLuL4449n9uzZzJw5s7r0052pU6dy9dVXM3fuXDKTc845h7PPPpvVq1dzySWXkJlEBJ/73Ofo6Ojg/PPPZ+fOnezZs4fLL7+c8ePHD2r9DfmM5oiYDtwDzMzMHV3apwDfBE4Hvg7cmZnf7mb9+cB8gGnTpp24ubt7TCU1zYYNGzj22GObXcaQ0NHRQUdHB2PHjmXTpk286U1vYtOmTYwe3bh3AHT37xERqzKzpa91615lRBxMcSbw4a6BULoWuDwzd/f21uzMXAIsAWhpaal/iknSPnrqqad44xvfSEdHB5nJV7/61YYGwkDVtdKIGEMRCK2ZeUc3XVqA28pAmAScFREdmfm9etYlSfUyYcIEVq1a1ewy9lndQiGK3/RfBzZk5he665OZR3fpfzPF5SMDQdoPdV77VnMNdEignmcKpwIXA
usiovPtfFcA0wAy84Y6bltSA40dO5bt27f7+Owm6/w8hbFjx+7za9Tz7qN7gZqPjsz8i3rVIqm+pk6dSltbG+3t7c0uZcTr/OS1fbX/jH5IGrLGjBmzz5/0paFlRDzmQpJUG0NBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJlbqFQkQcGRHLImJDRDwUEQu66TMvIh4ov+6LiFn1qkeS1LfRdXztDuAjmbk6IsYDqyLirsxc36XPo8BpmfmvEXEmsAR4XR1rkiT1om6hkJlbga3l9M6I2ABMAdZ36XNfl1V+AUytVz2SpL41ZEwhIqYDrwGW99LtEuBHPaw/PyJWRsTK9vb2wS9QkgQ0IBQi4mDgO8CHM3NHD33eQBEKl3e3PDOXZGZLZrZMnjy5fsVK0ghXzzEFImIMRSC0ZuYdPfQ5HvgacGZmbq9nPZKk3tXz7qMAvg5syMwv9NBnGnAHcGFm/qpetUiSalPPM4VTgQuBdRGxpmy7ApgGkJk3AJ8AJgJfKTKEjsxsqWNNkqRe1PPuo3uB6KPP+4D31asGSVL/+I5mSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVQwFSVLFUJAkVWoKhYh4RUQcWE7PjYgPRcSE+pYmSWq0Ws8UvgPsjohjgK8DRwPfrFtVkqSmqDUU9mRmB/B24NrM/C/Ay+pXliSpGWoNhWcj4jzgvcCdZduY+pQkSWqWWkPhYuAUYFFmPhoRRwO39LZCRBwZEcsiYkNEPBQRC7rpExGxOCIejogHImJ2/38ESdJgGV1Lp8xcD3wIICIOA8Zn5mf7WK0D+Ehmro6I8cCqiLirfK1OZwKvLL9eB/x1+V2S1AS13n30s4g4JCIOB9YCN0XEF3pbJzO3ZubqcnonsAGYsle3twLfyMIvgAkR4ViFJDVJrZePDs3MHcA7gJsy80Tg39W6kYiYDrwGWL7XoinAr7vMt/HC4CAi5kfEyohY2d7eXutmJUn9VGsojC7/gv8PPDfQXJOIOJjiltYPl8HyvMXdrJIvaMhckpktmdkyefLk/mxektQPtYbC1cBS4J8z85cR8XJgU18rRcQYikBozcw7uunSBhzZZX4q8HiNNUmSBllNoZCZ/yczj8/MD5Tzj2TmO3tbJyKC4o1uGzKzp/GH7wMXlXchnQw8mZlb+1G/JGkQ1XT3UURMBa4DTqW4vHMvsCAz23pZ7VTgQmBdRKwp264ApgFk5g3AD4GzgIeBXRS3vkqSmqSmUABuonisxbvL+QvKtn/f0wqZeS/djxl07ZPAZTXWIEmqs1rHFCZn5k2Z2VF+3Qw44itJw0ytobAtIi6IiFHl1wXA9noWJklqvFpD4T9S3I76G2Ar8C68/i9Jw06tdx9tycy3ZObkzHxJZr6N4o1skqRhZCCfvPZfB60KSdKQMJBQ6PXOIknS/mcgofCCx1FIkvZvvb5PISJ20v0v/wAOqktFkqSm6TUUMnN8owqRJDXfQC4fSZKGGUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJFUNBklQxFCRJlbqFQkTcGBFPRMSDPSw/NCJ+EBFrI+KhiLi4XrVIkmpTzzOFm4Ezell+GbA+M2cBc4G/iogX1bEeSVIf6hYKmXkP8NveugDjIyKAg8u+HfWqR5LUt2aOKVwPHAs8DqwDFmTmnu46RsT8iFgZESvb29sbWaMkjSjNDIU3A2uAPwFOAK6PiEO665iZSzKzJTNbJk+e3MgaJWlEaWYoXAzckYWHgUeBP2tiPZI04jUzFLYAbwSIiJcCrwIeaWI9kjTija7XC0fErRR3FU2KiDbgk8AYgMy8AfgMcHNErAMCuDwzt9WrHklS3+oWCpl5Xh/LHwfeVK/tS43U2gpXXglbtsC0abBoEcyb1+yqpP6rWyhII0VrK8yfD7t2FfObNxfzYDBo/+NjLqQBuvLK5wKh065dRbu0vzEUpAHasqV/7dJQZihIAzRtWv/apaHMUJAGaNEiGDfu+W3jxhXt0v7GUJAGaN48WLIEjjoKIorvS5Y4yKz9k3cfSYNg3jxDQMODZwqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpIqhIEmqGAqSpErdQiEiboyIJyLiwV76zI2INRHxUET8fb1qkSTVpp5nCjcDZ/S0MCImAF8B3pKZM4B317EWSVIN6hYKmXkP8NteupwP3JGZW8r+T9SrFklSbZo5pvCnwGER8bOIWBURF/XUMSLmR8TKiFjZ3t7ewBIlaWRpZiiMBk4EzgbeDHw8Iv60u46ZuSQzWzKzZfLkyY2sUZJGlNFN3HYbsC0znwaejoh7gFnAr5pYkySNaM08U/hb4N9ExOiIGAe8DtjQxHokacSr25lCRNwKzAUmRUQb8ElgDEBm3pCZGyLix8ADwB7ga5nZ4+2rkqT6q1soZOZ5NfS5BrimXjVIkvrHdzRLkiqGgiQNca2tMH06HHBA8b21tX7baubdR5KkPrS2wvz5sGtXMb95czEPMG/e4G/PMwVJGsKuvPK5QOi0a1fRXg+GgiQNYVu29K99oAwFSRrCpk3rX/tAGQqSNIQtWgTjxj2/bdy4or0eDAVJGsLmzYMlS+CooyCi+L5kSX0GmcG7jyRpyJs3r34hsDfPFCRJFUNBklQxFCRJFUNBklQxFCRJlcjMZtfQLxHRDmzex9UnAdsGsZzBMlTrgqFbm3X1j3X1z3Cs66jM7PPzjPe7UBiIiFiZmS3NrmNvQ7UuGLq1WVf/WFf/jOS6vHwkSaoYCpKkykgLhSXNLqAHQ7UuGLq1WVf/WFf/jNi6RtSYgiSpdyPtTEGS1AtDQZJUGRahEBE3RsQTEfFgD8sjIhZHxMMR8UBEzO6y7L0Rsan8em+D65pX1vNARNwXEbO6LHssItZFxJqIWDmYddVY29yIeLLc/pqI+ESXZWdExMZyfy5sYE0f61LPgxGxOyIOL5fVbX9FxJERsSwiNkTEQxGxoJs+DT/Gaqyr4cdYjXU14/iqpa5mHWNj
I2JFRKwta/t0N30OjIhvlftleURM77Lsv5ftGyPizQMqJjP3+y/g3wKzgQd7WH4W8CMggJOB5WX74cAj5ffDyunDGljX6zu3B5zZWVc5/xgwqYn7bC5wZzfto4B/Bl4OvAhYCxzXiJr26nsOcHcj9hfwMmB2OT0e+NXeP3MzjrEa62r4MVZjXc04vvqsq4nHWAAHl9NjgOXAyXv1+U/ADeX0ucC3yunjyv10IHB0uf9G7Wstw+JMITPvAX7bS5e3At/Iwi+ACRHxMuDNwF2Z+dvM/FfgLuCMRtWVmfeV2wX4BTB1sLbdlxr2WU9OAh7OzEcy84/AbRT7t9E1nQfcOhjb7Utmbs3M1eX0TmADMGWvbg0/xmqpqxnHWI37qyf1PL76W1cjj7HMzKfK2THl1953Ab0V+Jty+tvAGyMiyvbbMvMPmfko8DDFftwnwyIUajAF+HWX+bayraf2ZriE4i/NTgn8JCJWRcT8JtV0Snk6+6OImFG2NX2fRcQ4il+s3+nS3JD9VZ6yv4biL7mumnqM9VJXVw0/xvqoq2nHV1/7qxnHWESMiog1wBMUf0j0eIxlZgfwJDCRQd5nI+WT16KbtuylvaEi4g0U/2HndGk+NTMfj4iXAHdFxD+Vf0k3ymqKZ6U8FRFnAd8DXsnQ2GfnAD/PzK5nFXXfXxFxMMUviQ9n5o69F3ezSkOOsT7q6uzT8GOsj7qadnzVsr9owjGWmbuBEyJiAvDdiJiZmV3H1xpyjI2UM4U24Mgu81OBx3tpb5iIOB74GvDWzNze2Z6Zj5ffnwC+ywBOB/dFZu7oPJ3NzB8CYyJiEkNgn1FcT33eaX2991dEjKH4RdKamXd006Upx1gNdTXlGOurrmYdX7Xsr1LDj7Eu2/kd8DNeeJmx2jcRMRo4lOJy6+Dus8EeMGnWFzCdngdNz+b5g4AryvbDgUcpBgAPK6cPb2Bd0yiu/71+r/YXA+O7TN8HnNHgfXYEz7258SRgS7n/RlMMlh7NcwOBMxpRU7m88z/Cixu1v8qf+xvAtb30afgxVmNdDT/Gaqyr4cdXLXU18RibDEwopw8C/gH48736XMbzB5pvL6dn8PyB5kcYwEDzsLh8FBG3UtzNMCki2oBPUgzUkJk3AD+kuDvkYWAXcHG57LcR8Rngl+VLXZ3PP12sd12foLgm+JVivIiOLJ6A+FKK00co/pN8MzN/PFh11Vjbu4APREQH8Hvg3CyOwI6I+M/AUoo7RW7MzIcaVBPA24GfZObTXVat9/46FbgQWFde8wW4guIXbjOPsVrqasYxVktdDT++aqwLmnOMvQz4m4gYRXEF5/bMvDMirgZWZub3ga8D/zsiHqYIrXPLuh+KiNuB9UAHcFkWl6L2iY+5kCRVRsqYgiSpBoaCJKliKEiSKoaCJKliKEiSKoaCVCqfiLmmy9dgPqFzevTw9FdpKBkW71OQBsnvM/OEZhchNZNnClIfyufof6583v2KiDimbD8qIn4axWcV/DQippXtL42I75YPe1sbEa8vX2pURPyv8nn5P4mIg8r+H4qI9eXr3NakH1MCDAWpq4P2unz0ni7LdmTmScD1wLVl2/UUj8s+HmgFFpfti4G/z8xZFJ8P0fmO3FcCX87MGcDvgHeW7QuB15Svc2m9fjipFr6jWSpFxFOZeXA37Y8Bp2fmI+UD1X6TmRMjYhvwssx8tmzfmpmTIqIdmJqZf+jyGtMpHof8ynL+cmBMZv5lRPwYeIriSaHfy+eeqy81nGcKUm2yh+me+nTnD12md/PcmN7ZwJeBE4FV5RMwpaYwFKTavKfL938sp++jfCgZMA+4t5z+KfABqD445ZCeXjQiDgCOzMxlwH8DJgAvOFuRGsW/SKTnHNTl6ZkAP87MzttSD4yI5RR/SJ1Xtn0IuDEiPga0Uz4ZFVgALImISyjOCD4AbO1hm6OAWyLiUIpHO38xi+fpS03hmILUh3JMoSUztzW7FqnevHwkSap4piBJqnimIEmqGAqSpIqhIEmqGAqSpIqhIEmq/H+EPtsixmt7dgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"history_dict = history.history\n",
"history_dict.keys()\n",
"loss=history_dict['loss']\n",
"epochs = range(1, len(loss) + 1)\n",
"# \"bo\" is for \"blue dot\"\n",
"plt.plot(epochs, loss, 'bo', label='Training loss')\n",
"plt.title('Training loss')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Loss')\n",
"plt.legend()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 产生文本"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 恢复到最新的checkpoint\n",
"\n",
"这个步骤是不是应该放在训练之前,以便积累训练的成果?"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"embedding_1 (Embedding) (1, None, 256) 16640 \n",
"_________________________________________________________________\n",
"lstm_1 (LSTM) (1, None, 1024) 5246976 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (1, None, 65) 66625 \n",
"=================================================================\n",
"Total params: 5,330,241\n",
"Trainable params: 5,330,241\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"ckpt = tf.train.latest_checkpoint(checkpoint_dir)\n",
"model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)\n",
"model.load_weights(ckpt)\n",
"model.build(tf.TensorShape([1, None]))\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 进行预测\n",
"\n",
"model可以接受任意长度的字符串作为参数。实际上,无论多长的字符串,model都是需要一个一个进行处理的,最终给出的是每个输入字符对应的预测字符。参考下图了解shape在各个过程的变化(出处:https://www.tensorflow.org/tutorials/sequences/text_generation):\n",
"![](http://softlab.sdut.edu.cn/blog/subaochen/wp-content/uploads/sites/4/2019/05/text_generation_training.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 如何观察预测结果?\n",
"可以设置num_generate为一个**小的数字**,比如3,然后在后面的三个循环中,逐步打印出input_eval, prediction_id等的值,注意观察在不同的阶段各个向量的**维度**和**数值**的变化。"
]
},
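{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged walk-through of the shapes in a single prediction step (batch size 1), mirroring the loop in the cell below:\n",
"```python\n",
"input_eval = tf.expand_dims([char2idx[s] for s in 'First:'], 0)  # (1, 6)\n",
"predictions = model(input_eval)                                  # (1, 6, 65)\n",
"predictions = tf.squeeze(predictions, 0)                         # (6, 65)\n",
"predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()  # a single index\n",
"print(idx2char[predicted_id])                                    # the sampled next character\n",
"```"
]
},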
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First: tooke to\n",
"fafour every heady in reveck'd, done,--\n",
"Me then you arserven that daughter bloward hit belleved,\n",
"Res in the mat on off Yor what me, I'th conrece\n",
"That for in more -lary\n",
"Good-dary was face-time every and abscair?\n",
"\n",
"Shew!:\n",
"Which I:\n",
"And death, beat me Manishave: go figh,\n",
"Ereformive our very ententy their him;\n",
"Nor poward's\n",
"So this murderous bogh, becomes him his turn is this law here,\n",
"You hast thourd broughful by us, stilp, to our master? Here't thoubled the veors are\n",
"I receedn his own most, but hear it mouth\n",
"The founty bane caute to at id!\n",
"My chave if a present have queence now or sweet\n",
"Engies by you may such alisme, to Sear'd at\n",
"the capest, or buh of our purpil' their cape,\n",
"This dasger'd upon my ty, see to mean your maughter. I wisl, a get it is:\n",
"This sorrewainury; what he's fain of York,\n",
"Unlevis wibliently with his countiss:\n",
"Master than my heaven own gown where an dook. When queen by\n",
"cipy to true; thou to my Anglainors.\n",
"\n",
"VOLUMNIA:\n",
"Offort; sir; if you should the gruely smiles:\n",
"I \n"
]
}
],
"source": [
"def generate_text(model, start_string):\n",
" # Evaluation step (generating text using the learned model)\n",
"\n",
" # Number of characters to generate\n",
" num_generate = 1000\n",
"\n",
" # Converting our start string to numbers (vectorizing) \n",
" input_eval = [char2idx[s] for s in start_string]\n",
" # 构造维度合适的输入数据:[batch_size,seq_length]\n",
" # 这里batch_size=1,因此只需要将start_string扩展一维即可\n",
" input_eval = tf.expand_dims(input_eval, 0)\n",
"\n",
" # Empty string to store our results\n",
" text_generated = []\n",
"\n",
" # Low temperatures results in more predictable text.\n",
" # Higher temperatures results in more surprising text.\n",
" # Experiment to find the best setting.\n",
" temperature = 1.\n",
"\n",
" # Here batch size == 1\n",
" model.reset_states()\n",
" for i in range(num_generate):\n",
" # why not call model.predict()?\n",
" predictions = model(input_eval)\n",
" #print(\"predictions.shape:\",predictions.shape)\n",
" \n",
" # remove the batch dimension\n",
" predictions = tf.squeeze(predictions, 0)\n",
"\n",
" # using a multinomial distribution to predict the word returned by the model\n",
" predictions = predictions / temperature\n",
" predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()\n",
" #print(\"predicted_id:\",predicted_id)\n",
" \n",
" # We pass the predicted word as the next input to the model\n",
" # along with the previous hidden state\n",
" input_eval = tf.expand_dims([predicted_id], 0)\n",
" \n",
" text_generated.append(idx2char[predicted_id])\n",
"\n",
" return (start_string + ''.join(text_generated))\n",
"\n",
"print(generate_text(model, start_string=u\"First:\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 后记\n",
"\n",
"从这个简单的RNN的示例中,可以学到:\n",
"* RNN很强大,只是简单的一层LSTM(RNN)和为数不多的几次迭代训练,RNN就能够学会简单的语法结构甚至比较短的单词构成。馈入更多的训练样本和更多的迭代次数,RNN应该能够学会更多的语法特征和更多的词汇。\n",
"* RNN的训练很消耗资源。\n",
"* RNN能够解决很多问题,这里演示的是many to many类型的。\n",
"* RNN并不限制输入数据的大小,也就是说,RNN的input_example可以是任意长度的。\n",
"* 无论一次馈入多少数据,显然RNN也是一个一个进行计算和处理的;只是,在定义了`batch_size`的前提下,RNN会按照指定的`batch_size`一次汇总给出计算结果,这就是为什么输入数据要组织为`(batch_size, seq_length)`的原因:只有在输入的时候确定了`batch_size`,输出的时候才能够按照这个batch_size汇总结果为`(batch_size, vcab_size)`。\n",
"* 使用`tf.random.categorical`获得预测结果:预测结果显然是按照batch_size给出的概率分布,可以通过`tf.random.categorical`函数方便的获得最大概率项的索引,即预测值。"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}