Skip to content

Instantly share code, notes, and snippets.

@p-baleine
Last active September 8, 2017 22:53
Show Gist options
  • Save p-baleine/6a710a591549e66b1146d182e3baeef9 to your computer and use it in GitHub Desktop.
Save p-baleine/6a710a591549e66b1146d182e3baeef9 to your computer and use it in GitHub Desktop.
seq2seq new API
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import os\n",
"import six\n",
"import tensorflow as tf\n",
"import time\n",
"\n",
"from tensorflow.python.layers import core as core_layers\n",
"from tensorflow.python.util import nest"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"SPECIAL_SYMBOLS = [\"<PAD>\", \"<GO>\", \"<EOS>\", \"<UNK>\"]\n",
"\n",
"def tokenizer(raw_sentences):\n",
" for sentence in raw_sentences:\n",
" yield sentence.split()\n",
"\n",
"class Vocabulary(tf.contrib.learn.preprocessing.CategoricalVocabulary):\n",
" def __init__(self):\n",
" super(Vocabulary, self).__init__(unknown_token=\"<UNK>\")\n",
" self._mapping = dict((w, id) for id, w in enumerate(SPECIAL_SYMBOLS))\n",
" self._reverse_mapping = SPECIAL_SYMBOLS[:]\n",
" \n",
" def get(self, category):\n",
" if category not in self._mapping and self._freeze:\n",
" return SPECIAL_SYMBOLS.index(\"<UNK>\")\n",
" else:\n",
" return super(Vocabulary, self).get(category)\n",
"\n",
" def trim(self, min_frequency, max_frequency=-1):\n",
" self._freq = sorted(\n",
" sorted(\n",
" six.iteritems(self._freq),\n",
" key=lambda x: (isinstance(x[0], str), x[0])),\n",
" key=lambda x: x[1],\n",
" reverse=True)\n",
" self._mapping = dict((w, id) for id, w in enumerate(SPECIAL_SYMBOLS))\n",
" self._reverse_mapping = SPECIAL_SYMBOLS[:]\n",
" idx = 4 # after special symbols\n",
" for category, count in self._freq:\n",
" if max_frequency > 0 and count >= max_frequency:\n",
" continue\n",
" if count <= min_frequency:\n",
" break\n",
" self._mapping[category] = idx\n",
" idx += 1\n",
" self._reverse_mapping.append(category)\n",
" self._freq = dict(self._freq[:idx - 1])\n",
" \n",
"class NoPadVocabularyProcessor(tf.contrib.learn.preprocessing.VocabularyProcessor):\n",
" \"\"\"パディングしない版\"\"\"\n",
" def transform(self, raw_documents):\n",
" for tokens in self._tokenizer(raw_documents):\n",
" word_ids = np.zeros(len(tokens), np.int64)\n",
" for idx, token in enumerate(tokens):\n",
" if idx >= self.max_document_length:\n",
" break\n",
" word_ids[idx] = self.vocabulary_.get(token)\n",
" yield word_ids"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"合成文章([ここ](http://www.cems.uwe.ac.uk/~cjwallac/apps/tools/bnf/bnf2sent.cgi)で作った)で試す。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sally madly loves joe \r\n",
"pat loves joe \r\n",
"fred hates sally \r\n"
]
}
],
"source": [
"! head -n 3 truely_madly_deeply.txt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SALLY LOVES BILL \r\n",
"FRED TRULY DEEPLY MADLY DEEPLY LOVES BILL \r\n",
"PAT TRULY LOVES FRED \r\n"
]
}
],
"source": [
"# dstは大文字版、辞書のIDが異なるものに成るようにシャッフルしておく\n",
"! cat truely_madly_deeply.txt | awk '{print toupper($0)}' | shuf > truely_madly_deeply_upper.txt\n",
"! head -n 3 truely_madly_deeply_upper.txt"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"with open(\"truely_madly_deeply.txt\") as f1,\\\n",
"open(\"truely_madly_deeply_upper.txt\") as f2:\n",
" src_processor = NoPadVocabularyProcessor(\n",
" max_document_length=100,\n",
" vocabulary=Vocabulary(),\n",
" tokenizer_fn=tokenizer).fit(f1)\n",
" dst_processor = NoPadVocabularyProcessor(\n",
" max_document_length=100,\n",
" vocabulary=Vocabulary(),\n",
" tokenizer_fn=tokenizer).fit(f2)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['sally madly loves joe', 'pat loves joe']\n",
"['SALLY MADLY LOVES JOE', 'PAT LOVES JOE']\n"
]
}
],
"source": [
"print(list(src_processor.reverse(\n",
" src_processor.transform([\"sally madly loves joe\", \"pat loves joe\"]))))\n",
"print(list(dst_processor.reverse(\n",
" dst_processor.transform([x.upper() for x in [\"sally madly loves joe\", \"pat loves joe\"]]))))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def make_example(src_sentence_ids, dst_sentence_ids):\n",
" \"\"\"seq2seq用のExampleを作って返す\"\"\"\n",
" ex = tf.train.SequenceExample()\n",
" \n",
" target_ids = np.zeros(dst_sentence_ids.size + 1, dtype=np.int64)\n",
" target_ids[:-1] = dst_sentence_ids\n",
" target_ids[-1] = SPECIAL_SYMBOLS.index(\"<EOS>\")\n",
" _dst_sentence_ids = np.zeros(dst_sentence_ids.size + 1, dtype=np.int64)\n",
" _dst_sentence_ids[0] = SPECIAL_SYMBOLS.index(\"<GO>\")\n",
" _dst_sentence_ids[1:] = dst_sentence_ids\n",
" \n",
" ex.context.feature[\"src_sequence_length\"].int64_list.value.append(len(src_sentence_ids))\n",
" ex.context.feature[\"dst_sequence_length\"].int64_list.value.append(len(_dst_sentence_ids))\n",
"\n",
" fl_src_sentence_ids = ex.feature_lists.feature_list[\"src_sentence_ids\"]\n",
" fl_dst_sentence_ids = ex.feature_lists.feature_list[\"dst_sentence_ids\"]\n",
" fl_target_ids = ex.feature_lists.feature_list[\"target_ids\"]\n",
" \n",
" for id in src_sentence_ids:\n",
" fl_src_sentence_ids.feature.add().int64_list.value.append(id)\n",
" for id in _dst_sentence_ids:\n",
" fl_dst_sentence_ids.feature.add().int64_list.value.append(id)\n",
" for id in target_ids:\n",
" fl_target_ids.feature.add().int64_list.value.append(id)\n",
" \n",
" return ex\n",
"\n",
"def example_features():\n",
" \"\"\"seq2seqのfeature\"\"\"\n",
" context_features = {\n",
" \"src_sequence_length\": tf.FixedLenFeature([], dtype=tf.int64),\n",
" \"dst_sequence_length\": tf.FixedLenFeature([], dtype=tf.int64),\n",
" }\n",
"\n",
" sequence_features = {\n",
" \"src_sentence_ids\": tf.FixedLenSequenceFeature([], dtype=tf.int64),\n",
" \"dst_sentence_ids\": tf.FixedLenSequenceFeature([], dtype=tf.int64),\n",
" \"target_ids\": tf.FixedLenSequenceFeature([], dtype=tf.int64),\n",
" }\n",
"\n",
" return context_features, sequence_features"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"with open(\"truely_madly_deeply.txt\") as f1,\\\n",
"open(\"truely_madly_deeply.tfrecords\", \"wb\") as f2:\n",
" src_sentences = f1.readlines()\n",
" dst_sentences = [x.upper() for x in src_sentences] # srcと1対1で対応\n",
" src_sentence_ids_list = list(src_processor.transform(src_sentences))\n",
" dst_sentence_ids_list = list(dst_processor.transform(dst_sentences))\n",
" writer = tf.python_io.TFRecordWriter(f2.name)\n",
" for src_sentence_ids, dst_sentence_ids in zip(src_sentence_ids_list,\\\n",
" dst_sentence_ids_list):\n",
" ex = make_example(src_sentence_ids, dst_sentence_ids)\n",
" writer.write(ex.SerializeToString())\n",
" writer.close()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def inputs(batch_size, filename, num_epochs):\n",
" \"\"\"`batch_size`のインプットを返すinput pipeline\"\"\"\n",
" context_features, sequence_features = example_features()\n",
"\n",
" filename_queue = tf.train.string_input_producer([filename],\n",
" num_epochs=num_epochs)\n",
" reader = tf.TFRecordReader()\n",
" _, serialized_ex = reader.read(filename_queue)\n",
" context_parsed, sequence_parsed = tf.parse_single_sequence_example(\n",
" serialized=serialized_ex,\n",
" context_features=context_features,\n",
" sequence_features=sequence_features)\n",
" \n",
" # tf.train.shuffle_batchがdynamic_padをサポートしていないので\n",
" # 自前でshuffleする\n",
" # https://github.com/tensorflow/tensorflow/issues/5147#issuecomment-271086206\n",
" min_after_dequeue = 10000\n",
" capacity = min_after_dequeue + 3 * batch_size\n",
" inputs = {\n",
" **{\"encoder_input\": sequence_parsed[\"src_sentence_ids\"],\n",
" \"decoder_input\": sequence_parsed[\"dst_sentence_ids\"],\n",
" \"targets\": sequence_parsed[\"target_ids\"]},\n",
" **{\"encoder_sequence_length\": context_parsed[\"src_sequence_length\"],\n",
" \"decoder_sequence_length\": context_parsed[\"dst_sequence_length\"]}}\n",
" dtypes = [x.dtype for _, x in inputs.items()]\n",
" shapes = [x.get_shape() for _, x in inputs.items()]\n",
" queue = tf.RandomShuffleQueue(capacity, min_after_dequeue, dtypes,\n",
" names=list(inputs.keys()))\n",
" enqueue_op = queue.enqueue(inputs)\n",
" qr = tf.train.QueueRunner(queue, [enqueue_op])\n",
" tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, qr)\n",
" inputs = queue.dequeue()\n",
" for (name, tensor), shape in zip(inputs.items(), shapes):\n",
" tensor.set_shape(shape)\n",
" \n",
" batched = tf.train.batch(\n",
" tensors=inputs,\n",
" batch_size=batch_size,\n",
" dynamic_pad=True)\n",
" \n",
" return dict((k, tf.to_int32(t)) for k, t in batched.items())\n",
"\n",
"def inferring_inputs():\n",
" return {\n",
" \"encoder_input\": tf.placeholder(tf.int32, [None, None],\n",
" name=\"encoder_input\"),\n",
" \"encoder_sequence_length\": tf.placeholder(tf.int32, [None],\n",
" name=\"encoder_sequence_length\"),\n",
" \"decoder_input\": tf.placeholder(tf.int32, [None, None],\n",
" name=\"decoder_input\"),\n",
" \"decoder_sequence_length\": tf.placeholder(tf.int32, [None],\n",
" name=\"decoder_sequence_length\"),\n",
" \"targets\": tf.placeholder(tf.int32, [None, None],\n",
" name=\"targets\")\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def encode(cell, encoder_emb_input, sequence_length):\n",
" return tf.nn.dynamic_rnn(\n",
" cell,\n",
" encoder_emb_input,\n",
" sequence_length,\n",
" dtype=tf.float32)\n",
"\n",
"def decode(cell, helper, state, dec_vocab_size):\n",
" output_layer = core_layers.Dense(dec_vocab_size)\n",
" decoder = tf.contrib.seq2seq.BasicDecoder(\n",
" cell=cell,\n",
" helper=helper,\n",
" initial_state=state,\n",
" output_layer=output_layer)\n",
" return tf.contrib.seq2seq.dynamic_decode(\n",
" decoder=decoder,\n",
" impute_finished=True,\n",
" maximum_iterations=150)\n",
"\n",
"def loss_fn(logits, targets, weights):\n",
" return tf.contrib.seq2seq.sequence_loss(logits, targets, weights)\n",
"\n",
"def train(loss, learning_rate=0.01, max_gradients_norm=5.0):\n",
" global_step = tf.train.get_or_create_global_step()\n",
" optimizer = tf.train.AdamOptimizer(learning_rate)\n",
" params = tf.trainable_variables()\n",
" gradients = tf.gradients(loss, params)\n",
" clipped_gradients, _ = tf.clip_by_global_norm(gradients, max_gradients_norm)\n",
" update = optimizer.apply_gradients(\n",
" zip(clipped_gradients, params),\n",
" global_step=global_step)\n",
" return update, global_step\n",
"\n",
"def training_helper(decoder_emb_input, decoder_sequence_length):\n",
" return tf.contrib.seq2seq.TrainingHelper(\n",
" decoder_emb_input,\n",
" decoder_sequence_length)\n",
"\n",
"def infer_helper(dec_embeddings, batch_size, go_id, eos_id):\n",
" return tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
" embedding=dec_embeddings,\n",
" start_tokens=tf.tile([go_id], [batch_size]),\n",
" end_token=eos_id)\n",
"\n",
"def seq2seq(\n",
" encoder_input,\n",
" decoder_input,\n",
" targets,\n",
" encoder_sequence_length,\n",
" decoder_sequence_length,\n",
" keep_prob,\n",
" enc_vocab_size,\n",
" dec_vocab_size,\n",
" embed_dim,\n",
" num_layers,\n",
" num_units,\n",
" go_id=SPECIAL_SYMBOLS.index(\"<GO>\"),\n",
" eos_id=SPECIAL_SYMBOLS.index(\"<EOS>\"),\n",
" is_inferring=False,\n",
" learning_rate=0.01,\n",
" max_gradients_norm=5.0):\n",
" batch_size = tf.shape(encoder_input)[0]\n",
"\n",
" with tf.variable_scope(\"embedding\"):\n",
" enc_embeddings = tf.get_variable(\"enc_embedding\",\n",
" [enc_vocab_size, embed_dim])\n",
" dec_embeddings = tf.get_variable(\"dec_embedding\",\n",
" [dec_vocab_size, embed_dim])\n",
" \n",
" encoder_emb_input = tf.nn.embedding_lookup(\n",
" enc_embeddings,\n",
" encoder_input)\n",
" encoder_emb_input = tf.nn.dropout(\n",
" encoder_emb_input,\n",
" keep_prob)\n",
" decoder_emb_input = tf.nn.embedding_lookup(\n",
" dec_embeddings,\n",
" decoder_input)\n",
" decoder_emb_input = tf.nn.dropout(\n",
" decoder_emb_input,\n",
" keep_prob)\n",
" weights = tf.sign(tf.to_float(targets))\n",
" \n",
" def _encoder_cell():\n",
" cell = tf.nn.rnn_cell.GRUCell(num_units)\n",
" # TODO output_keep_probe, state_keep_prob\n",
" cell = tf.nn.rnn_cell.DropoutWrapper(\n",
" cell,\n",
" input_keep_prob=keep_prob,\n",
" output_keep_prob=keep_prob)\n",
" return cell\n",
"\n",
" encoder_outputs, state = encode(\n",
" cell=tf.contrib.rnn.MultiRNNCell([_encoder_cell() for _ in range(num_layers)]),\n",
" encoder_emb_input=encoder_emb_input,\n",
" sequence_length=encoder_sequence_length)\n",
"\n",
" if not is_inferring:\n",
" helper = training_helper(\n",
" decoder_emb_input,\n",
" decoder_sequence_length)\n",
" else:\n",
" helper = infer_helper(\n",
" dec_embeddings,\n",
" batch_size,\n",
" go_id, eos_id)\n",
" \n",
" def _decoder_cell():\n",
" attention_mechanism = tf.contrib.seq2seq.LuongAttention(\n",
" num_units,\n",
" encoder_outputs,\n",
" encoder_sequence_length)\n",
" cell = tf.nn.rnn_cell.GRUCell(num_units)\n",
" cell = tf.contrib.seq2seq.AttentionWrapper(\n",
" cell,\n",
" attention_mechanism,\n",
" attention_layer_size=num_units // 2)\n",
" return tf.nn.rnn_cell.DropoutWrapper(\n",
" cell,\n",
" input_keep_prob=keep_prob,\n",
" output_keep_prob=keep_prob)\n",
" \n",
" attn_cell = tf.contrib.rnn.MultiRNNCell([_decoder_cell() for _ in range(num_layers)])\n",
"\n",
" state = tuple(attn_s.clone(cell_state=s)\n",
" for s, attn_s in zip(state, attn_cell.zero_state(batch_size, tf.float32)))\n",
"\n",
" (logits, sample_id), _, _ = decode(\n",
" cell=attn_cell,\n",
" helper=helper,\n",
" state=state,\n",
" dec_vocab_size=dec_vocab_size)\n",
" \n",
" loss = loss_fn(logits, targets, weights)\n",
" tf.summary.scalar(\"loss\", loss)\n",
" train_op, _ = train(loss, learning_rate=learning_rate,\n",
" max_gradients_norm=max_gradients_norm)\n",
"\n",
" return (\n",
" loss,\n",
" sample_id,\n",
" train_op,\n",
" tf.summary.merge_all())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /tmp/work/model/model.ckpt-0\n",
"INFO:tensorflow:Starting standard services.\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/model/model.ckpt\n",
"INFO:tensorflow:Starting queue runners.\n",
"INFO:tensorflow:test_model/global_step/sec: 0\n",
"Step 10, loss 2.475779151916504\n",
"Step 20, loss 2.261063575744629\n",
"sally madly loves joe \n",
" MADLY MADLY LOVES <EOS>\n",
"Step 30, loss 2.249356508255005\n",
"Step 40, loss 1.9047452926635742\n",
"sally madly loves joe \n",
" SALLY MADLY LOVES <EOS>\n",
"Step 50, loss 1.8361172556877137\n",
"Step 60, loss 1.7297092795372009\n",
"sally madly loves joe \n",
" SALLY MADLY LOVES JOE <EOS>\n",
"Step 70, loss 1.6062021732330323\n",
"Step 80, loss 1.6975791692733764\n",
"sally madly loves joe \n",
" BILL LOVES LOVES <EOS>\n",
"Step 90, loss 1.7016813635826111\n",
"Step 100, loss 1.515825343132019\n",
"sally madly loves joe \n",
" JOE TRULY HATES JOE <EOS>\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"\n",
"data_size = len(src_sentences)\n",
"batch_size = 10\n",
"num_epochs = 10\n",
"enc_vocab_size = len(src_processor.vocabulary_)\n",
"dec_vocab_size = len(dst_processor.vocabulary_)\n",
"embed_dim = 10\n",
"num_layers = 2\n",
"num_units = 100\n",
"learning_rate = 0.01\n",
"\n",
"def graph(\n",
" encoder_input,\n",
" encoder_sequence_length,\n",
" decoder_input,\n",
" decoder_sequence_length,\n",
" targets,\n",
" keep_prob,\n",
" is_inferring=False):\n",
" \n",
" return seq2seq(\n",
" encoder_input=encoder_input,\n",
" decoder_input=decoder_input,\n",
" targets=targets,\n",
" encoder_sequence_length=encoder_sequence_length,\n",
" decoder_sequence_length=decoder_sequence_length,\n",
" keep_prob=keep_prob,\n",
" enc_vocab_size=enc_vocab_size,\n",
" dec_vocab_size=dec_vocab_size,\n",
" embed_dim=embed_dim,\n",
" num_layers=num_layers,\n",
" num_units=num_units,\n",
" learning_rate=learning_rate,\n",
" is_inferring=is_inferring)\n",
"\n",
"keep_prob = tf.placeholder(tf.float32)\n",
"\n",
"with tf.variable_scope(\"test_model\"):\n",
" loss, _, train_op, _ = graph(**{\n",
" **inputs(batch_size, \"truely_madly_deeply.tfrecords\", num_epochs),\n",
" **{\"keep_prob\": keep_prob}})\n",
"\n",
"with tf.variable_scope(\"test_model\", reuse=True):\n",
" inferring_input_placeholders = inferring_inputs()\n",
" _, sample_id, _, _ = graph(**{\n",
" **inferring_input_placeholders,\n",
" **{\"keep_prob\": keep_prob, \"is_inferring\": True}})\n",
" \n",
"def print_sample(sess):\n",
" src_ids = list(src_processor.transform([src_sentences[0]]))\n",
" sampled = sess.run(sample_id, feed_dict={\n",
" inferring_input_placeholders[\"encoder_input\"]: src_ids,\n",
" inferring_input_placeholders[\"encoder_sequence_length\"]: [len(src_ids[0])],\n",
" keep_prob: 1.0})\n",
" print(src_sentences[0],\n",
" \" \".join(list(dst_processor.reverse([sampled[0]]))))\n",
"\n",
"sv = tf.train.Supervisor(logdir=\"/tmp/work/model\", summary_op=None)\n",
"with sv.managed_session() as sess:\n",
" total_loss = 0.0\n",
" step = 1\n",
" while True:\n",
" if sv.should_stop():\n",
" break\n",
"\n",
" _loss, _ = sess.run([loss, train_op], feed_dict={keep_prob: 0.5})\n",
"\n",
" total_loss += _loss\n",
" \n",
" if step % 10 == 0:\n",
" print(\"Step {}, loss {}\".format(step, total_loss / 10))\n",
" total_loss = 0.0\n",
" if step % 20 == 0: \n",
" print_sample(sess)\n",
"\n",
" step += 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"[京都フリー翻訳タスク(KFTT)](http://www.phontron.com/kftt/index-ja.html)を使ってみる。\n",
"\n",
"量が多いからtrainファイルの頭1/2で試す。"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"164941\r\n"
]
}
],
"source": [
"! echo $(wc /tmp/work/kftt-data-1.0/data/tok/kyoto-train.cln.ja | gawk '{print $1}') \" / 2\" | bc"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"! split -l 164941 /tmp/work/kftt-data-1.0/data/tok/kyoto-train.cln.ja\n",
"! mv xaa kyoto-train.cln.harf.ja\n",
"! split -l 164941 /tmp/work/kftt-data-1.0/data/tok/kyoto-train.cln.en\n",
"! mv xaa kyoto-train.cln.harf.en"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Japanese vocabylary size: 14916\n",
"English vocabylary size: 15485\n"
]
}
],
"source": [
"TRAIN_SRC_DATA = \"/tmp/work/kyoto-train.cln.harf.ja\"\n",
"TRAIN_DST_DATA = \"/tmp/work/kyoto-train.cln.harf.en\"\n",
"TEST_SRC_DATA = \"/tmp/work/kftt-data-1.0/data/tok/kyoto-test.ja\"\n",
"TEST_DST_DATA = \"/tmp/work/kftt-data-1.0/data/tok/kyoto-test.en\"\n",
"\n",
"SRC_VOCAB = \"/tmp/work/vocab.ja.pickle\"\n",
"DST_VOCAB = \"/tmp/work/vocab.en.pickle\"\n",
"\n",
"logdir = \"/tmp/work/kftt-param\"\n",
"\n",
"max_document_length = 100\n",
"\n",
"def create_vocabulary(path):\n",
" with open(path) as f:\n",
" return NoPadVocabularyProcessor(\n",
" max_document_length=100,\n",
" min_frequency=10,\n",
" vocabulary=Vocabulary(),\n",
" tokenizer_fn=tokenizer).fit(f)\n",
"\n",
"if not os.path.exists(SRC_VOCAB):\n",
" src_processor = create_vocabulary(TRAIN_SRC_DATA)\n",
" src_processor.save(SRC_VOCAB)\n",
"else:\n",
" src_processor = NoPadVocabularyProcessor.restore(SRC_VOCAB)\n",
"\n",
"if not os.path.exists(DST_VOCAB):\n",
" dst_processor = create_vocabulary(TRAIN_DST_DATA)\n",
" dst_processor.save(DST_VOCAB)\n",
"else:\n",
" dst_processor = NoPadVocabularyProcessor.restore(DST_VOCAB)\n",
"\n",
"go_id = dst_processor.vocabulary_.get(\"<GO>\")\n",
"eos_id = dst_processor.vocabulary_.get(\"<EOS>\")\n",
"enc_vocab_size = len(src_processor.vocabulary_)\n",
"dec_vocab_size = len(dst_processor.vocabulary_)\n",
"\n",
"print(\"Japanese vocabylary size: {}\".format(enc_vocab_size))\n",
"print(\"English vocabylary size: {}\".format(dec_vocab_size))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"TRAIN_DATA_TFRECORD = \"/tmp/work/kyoto-train.cln.harf.tfrecord\"\n",
"TEST_DATA_TFRECORD = \"/tmp/work/kyoto-test.tfrecord\"\n",
"\n",
"def create_tfrecord(in_src_path, in_dst_path, out_path):\n",
" with open(in_src_path) as f1, open(in_dst_path) as f2, open(out_path, \"wb\") as f3:\n",
" src_sentence_ids_list = src_processor.transform(f1)\n",
" dst_sentence_ids_list = dst_processor.transform(f2)\n",
" writer = tf.python_io.TFRecordWriter(f3.name)\n",
" for src_sentence_ids, dst_sentence_ids in zip(src_sentence_ids_list,\\\n",
" dst_sentence_ids_list):\n",
" ex = make_example(src_sentence_ids, dst_sentence_ids)\n",
" writer.write(ex.SerializeToString())\n",
" writer.close()\n",
"\n",
"if not os.path.exists(TRAIN_DATA_TFRECORD):\n",
" create_tfrecord(TRAIN_SRC_DATA, TRAIN_DST_DATA, TRAIN_DATA_TFRECORD)\n",
"\n",
"if not os.path.exists(TEST_DATA_TFRECORD):\n",
" create_tfrecord(TEST_SRC_DATA, TEST_DST_DATA, TEST_DATA_TFRECORD)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /tmp/work/kftt-param/model.ckpt-119702\n",
"INFO:tensorflow:Starting standard services.\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"INFO:tensorflow:Starting queue runners.\n",
"Step 32000 / 164941, training loss 2.955790, test loss 3.610759, elapsed 155.859858s\n",
"Step 64000 / 164941, training loss 3.214088, test loss 3.303418, elapsed 154.051139s\n",
"Step 96000 / 164941, training loss 3.179242, test loss 3.179810, elapsed 155.419681s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 128000 / 164941, training loss 2.963753, test loss 3.454105, elapsed 155.601932s\n",
"Step 160000 / 164941, training loss 3.129966, test loss 3.428210, elapsed 156.284561s\n",
"Step 27059 / 164941, training loss 3.043962, test loss 3.286095, elapsed 155.217415s\n",
"Step 59059 / 164941, training loss 3.239690, test loss 3.404328, elapsed 154.764190s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 91059 / 164941, training loss 3.121370, test loss 3.437113, elapsed 155.464047s\n",
"Step 123059 / 164941, training loss 2.966356, test loss 3.358104, elapsed 155.541891s\n",
"Step 155059 / 164941, training loss 3.095634, test loss 3.266939, elapsed 156.200432s\n",
"Step 22118 / 164941, training loss 3.076107, test loss 3.277872, elapsed 154.762615s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 54118 / 164941, training loss 3.247306, test loss 3.299399, elapsed 155.225333s\n",
"Step 86118 / 164941, training loss 3.066488, test loss 3.385982, elapsed 155.067276s\n",
"Step 118118 / 164941, training loss 2.987248, test loss 3.351193, elapsed 155.344533s\n",
"Step 150118 / 164941, training loss 3.068154, test loss 3.274154, elapsed 155.232993s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 17177 / 164941, training loss 3.092844, test loss 3.673700, elapsed 155.766533s\n",
"Step 49177 / 164941, training loss 3.244971, test loss 3.511972, elapsed 155.381531s\n",
"Step 81177 / 164941, training loss 3.026749, test loss 3.442863, elapsed 155.906324s\n",
"Step 113177 / 164941, training loss 3.017391, test loss 3.648664, elapsed 155.843899s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 145177 / 164941, training loss 3.054037, test loss 3.330358, elapsed 155.839185s\n",
"Step 12236 / 164941, training loss 3.101636, test loss 3.361992, elapsed 154.582553s\n",
"Step 44236 / 164941, training loss 3.250136, test loss 3.398018, elapsed 155.170406s\n",
"Step 76236 / 164941, training loss 2.970292, test loss 3.257153, elapsed 154.940045s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 108236 / 164941, training loss 3.046880, test loss 3.396551, elapsed 156.018393s\n",
"Step 140236 / 164941, training loss 3.031431, test loss 3.311976, elapsed 155.616624s\n",
"Step 7295 / 164941, training loss 3.130935, test loss 3.317468, elapsed 155.316411s\n",
"Step 39295 / 164941, training loss 3.237224, test loss 3.201086, elapsed 155.159777s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 71295 / 164941, training loss 2.923450, test loss 3.576703, elapsed 155.130497s\n",
"Step 103295 / 164941, training loss 3.077371, test loss 3.190845, elapsed 156.320625s\n",
"Step 135295 / 164941, training loss 3.009074, test loss 3.435367, elapsed 155.930322s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 2354 / 164941, training loss 3.154707, test loss 3.442424, elapsed 155.266500s\n",
"Step 34354 / 164941, training loss 3.236584, test loss 3.706930, elapsed 154.725371s\n",
"Step 66354 / 164941, training loss 2.873634, test loss 3.417327, elapsed 155.217429s\n",
"Step 98354 / 164941, training loss 3.105311, test loss 3.387954, elapsed 156.290312s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 130354 / 164941, training loss 2.994145, test loss 3.139106, elapsed 156.059800s\n",
"Step 162354 / 164941, training loss 3.165069, test loss 3.503583, elapsed 154.644478s\n",
"Step 29413 / 164941, training loss 3.229005, test loss 3.447956, elapsed 154.803382s\n",
"Step 61413 / 164941, training loss 2.842835, test loss 3.323263, elapsed 155.632782s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 93413 / 164941, training loss 3.117907, test loss 3.292005, elapsed 155.991641s\n",
"Step 125413 / 164941, training loss 2.978633, test loss 3.525044, elapsed 156.009066s\n",
"Step 157413 / 164941, training loss 3.179128, test loss 3.470156, elapsed 154.802577s\n",
"Step 24472 / 164941, training loss 3.204044, test loss 3.392745, elapsed 154.570330s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 56472 / 164941, training loss 2.836560, test loss 3.612292, elapsed 155.624684s\n",
"Step 88472 / 164941, training loss 3.128556, test loss 3.367498, elapsed 155.831625s\n",
"Step 120472 / 164941, training loss 2.983065, test loss 3.390799, elapsed 155.163952s\n",
"Step 152472 / 164941, training loss 3.194086, test loss 3.197423, elapsed 154.228353s\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"Step 19531 / 164941, training loss 3.171189, test loss 3.395420, elapsed 155.323945s\n",
"Step 51531 / 164941, training loss 2.843658, test loss 3.587918, elapsed 155.822081s\n",
"Step 83531 / 164941, training loss 3.115164, test loss 3.615798, elapsed 155.848897s\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"\n",
"batch_size = 64\n",
"#num_epochs = 20\n",
"#num_epochs = 20\n",
"#num_epochs = 10\n",
"num_epochs = 10\n",
"embed_dim = 512\n",
"num_layers = 2\n",
"num_units = 512\n",
"learning_rate = 0.001\n",
"max_gradients_norm = 1.0\n",
"init_scale = 0.04\n",
"\n",
"keep_prob = tf.placeholder(tf.float32)\n",
"\n",
"def graph(\n",
" encoder_input,\n",
" encoder_sequence_length,\n",
" decoder_input,\n",
" decoder_sequence_length,\n",
" targets,\n",
" keep_prob,\n",
" is_inferring=False):\n",
" return seq2seq(\n",
" encoder_input=encoder_input,\n",
" decoder_input=decoder_input,\n",
" targets=targets,\n",
" encoder_sequence_length=encoder_sequence_length,\n",
" decoder_sequence_length=decoder_sequence_length,\n",
" keep_prob=keep_prob,\n",
" enc_vocab_size=enc_vocab_size,\n",
" dec_vocab_size=dec_vocab_size,\n",
" embed_dim=embed_dim,\n",
" num_layers=num_layers,\n",
" num_units=num_units,\n",
" learning_rate=learning_rate,\n",
" is_inferring=is_inferring,\n",
" max_gradients_norm=max_gradients_norm)\n",
"\n",
"initializer = tf.random_uniform_initializer(init_scale, init_scale)\n",
"\n",
"with tf.name_scope(\"Train\"):\n",
" with tf.variable_scope(\"inputs\"):\n",
" training_inputs = inputs(batch_size, TRAIN_DATA_TFRECORD, num_epochs)\n",
" with tf.variable_scope(\"model\", initializer=initializer):\n",
" loss, _, train_op, training_summary_op = graph(**{\n",
" **training_inputs,\n",
" **{\"keep_prob\": keep_prob}})\n",
"\n",
"with tf.name_scope(\"Test\"):\n",
" with tf.variable_scope(\"inputs\"): # inputsはreuse=None\n",
" test_inputs = inputs(batch_size, TEST_DATA_TFRECORD, None)\n",
" with tf.variable_scope(\"model\", reuse=True, initializer=initializer): # modelは当然reuse=True\n",
" test_loss, _, _, testing_summary_op = graph(**{\n",
" **test_inputs,\n",
" **{\"keep_prob\": keep_prob}})\n",
"\n",
"with tf.name_scope(\"Inferring\"):\n",
" with tf.variable_scope(\"inputs\"):\n",
" inferring_input_placeholders = inferring_inputs()\n",
" with tf.variable_scope(\"model\", reuse=True, initializer=initializer):\n",
" _, sample_id, _, _ = graph(**{\n",
" **inferring_input_placeholders,\n",
" **{\"keep_prob\": keep_prob, \"is_inferring\": True}})\n",
"\n",
"sv = tf.train.Supervisor(\n",
" logdir=logdir,\n",
" summary_op=None,\n",
" global_step=tf.train.get_global_step())\n",
"\n",
"with sv.managed_session() as sess:\n",
" start_time = time.time()\n",
" total_loss = 0.0\n",
" step = 1\n",
" while True:\n",
" if sv.should_stop():\n",
" break\n",
"\n",
" _loss, _, training_summary = sess.run([loss, train_op, training_summary_op],\n",
" feed_dict={keep_prob: 0.5})\n",
" \n",
" if np.isnan(_loss):\n",
" raise RuntimeError(\"loss is nan\")\n",
" \n",
" sv.summary_computed(sess, training_summary)\n",
" \n",
" total_loss += _loss\n",
" \n",
" if step % 20 == 0:\n",
" _test_loss, testing_summary = sess.run([test_loss, testing_summary_op],\n",
" feed_dict={keep_prob: 1.0})\n",
" sv.summary_computed(sess, testing_summary)\n",
"\n",
" if step % 500 == 0:\n",
" print(\"Step {} / 164941, training loss {:f}, test loss {:f}, elapsed {:f}s\".format(\n",
" (step * batch_size) % 164941, total_loss / 500, _test_loss, time.time() - start_time))\n",
" start_time = time.time()\n",
" total_loss = 0.0\n",
"\n",
" step += 1"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /tmp/work/kftt-param/model.ckpt-142870\n",
"INFO:tensorflow:Starting standard services.\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"INFO:tensorflow:Starting queue runners.\n",
"日本 の 水墨 画 を 一変 さ せ た 。\n",
" Truth: He revolutionized the Japanese ink painting .\n",
" Predicted: He changed the ink painting in Japan . <EOS>\n",
"\n",
"諱 は 「 等楊 ( とうよう ) 」 、 もしくは 「 拙宗 ( せっしゅう ) 」 と 号 し た 。\n",
" Truth: He was given the posthumous name \" Toyo \" or \" Sesshu ( 拙宗 ) . \"\n",
" Predicted: His imina ( personal name ) was <UNK> and <UNK> , and he was called ' <UNK> . ' <EOS>\n",
"\n",
"備中 国 に 生まれ 、 京都 ・ 相国 寺 に 入 っ て から 周防 国 に 移 る 。\n",
" Truth: Born in Bicchu Province , he moved to Suo Province after entering SShokoku-ji Temple in Kyoto .\n",
" Predicted: He was born in Bicchu Province , and moved to Suo Province in Kyoto . <EOS>\n",
"\n",
"その 後 遣明 使 に 随行 し て 中国 ( 明 ) に 渡 っ て 中国 の 水墨 画 を 学 ん だ 。\n",
" Truth: Later he accompanied a mission to Ming Dynasty China and learned Chinese ink painting .\n",
" Predicted: Later , he went to China and studied the Chinese ink painting in China . <EOS>\n",
"\n",
"作品 は 数 多 く 、 中国 風 の 山水 画 だけ で な く 人物 画 や 花鳥 画 も よ く し た 。\n",
" Truth: His works were many , including not only Chinese-style landscape paintings , but also portraits and pictures of flowers and birds .\n",
" Predicted: Many works were not produced , and they were painted in Chinese paintings and paintings of paintings and paintings . <EOS>\n",
"\n",
"大胆 な 構図 と 力強 い 筆線 は 非常 に 個性 的 な 画風 を 作り出 し て い る 。\n",
" Truth: His bold compositions and strong brush strokes constituted an extremely distinctive style .\n",
" Predicted: The <UNK> and bold bold bold style were characterized by <UNK> . <EOS>\n",
"\n",
"現存 する 作品 の うち 6 点 が 国宝 に 指定 さ れ て お り 、 日本 の 画家 の なか で も 別格 の 評価 を 受け て い る と いえ る 。\n",
" Truth: 6 of his extant works are designated national treasures . Indeed , he is considered to be extraordinary among Japanese painters .\n",
" Predicted: Among the existing works of existing works , the extant works are designated as national treasures , and it is recognized as a national treasure . <EOS>\n",
"\n",
"この ため 、 花鳥 図 屏風 など に 「 伝 雪舟 筆 」 さ れ る 作品 は 大変 多 い 。\n",
" Truth: For this reason , there are a great many artworks that are attributed to him , such as folding screens with pictures of flowers and that birds are painted on them .\n",
" Predicted: Therefore , many works of Sesshu are often known as ' Sesshu ' ( <UNK> painting ) and <UNK> ( folding screen ) . <EOS>\n",
"\n",
"真筆 で あ る か 専門 家 の 間 で も 意見 の 分かれ る もの も 多々 あ る 。\n",
" Truth: There are many works that even experts cannot agree if they are really his work or not .\n",
" Predicted: There are many opinions that <UNK> the <UNK> of the <UNK> . <EOS>\n",
"\n",
"弟子 に 、 秋月 、 宗 淵 、 等 春 ら が い る 。\n",
" Truth: His disciples include Shugetsu , Soen , and Toshun .\n",
" Predicted: His disciples include <UNK> , <UNK> , <UNK> , and <UNK> . <EOS>\n",
"\n"
]
}
],
"source": [
"def print_sample(\n",
"        sess,\n",
"        sample_id,\n",
"        inferring_input_placeholders,\n",
"        src_data_path,\n",
"        dst_data_path,\n",
"        n=10):\n",
"    \"\"\"Decode up to `n` source sentences and print each beside its reference.\n",
"\n",
"    Args:\n",
"        sess: session used to evaluate `sample_id`.\n",
"        sample_id: tensor of decoded target ids; row 0 of the result is printed.\n",
"        inferring_input_placeholders: dict holding the \"encoder_input\" and\n",
"            \"encoder_sequence_length\" placeholders to feed.\n",
"        src_data_path: source-side text file, one sentence per line.\n",
"        dst_data_path: reference translations, aligned line-by-line with the source.\n",
"        n: maximum number of sentence pairs to print; stops early at end of file.\n",
"\n",
"    Also reads the notebook globals `src_processor`, `dst_processor` and `keep_prob`.\n",
"    \"\"\"\n",
"    # Explicit UTF-8: the source side is Japanese, so don't depend on the locale default.\n",
"    with open(src_data_path, encoding=\"utf-8\") as src_file, open(dst_data_path, encoding=\"utf-8\") as dst_file:\n",
"        for _ in range(n):\n",
"            src_sentence = src_file.readline()\n",
"            dst_sentence = dst_file.readline()\n",
"            if not src_sentence or not dst_sentence:\n",
"                break  # a file ran out of lines before `n` pairs were shown\n",
"            src_ids = list(src_processor.transform([src_sentence]))\n",
"            sampled = sess.run(sample_id, feed_dict={\n",
"                inferring_input_placeholders[\"encoder_input\"]: src_ids,\n",
"                # NOTE(review): if src_processor pads to a fixed length, this is the padded\n",
"                # length rather than the real token count -- confirm.\n",
"                inferring_input_placeholders[\"encoder_sequence_length\"]: [len(src_ids[0])],\n",
"                keep_prob: 1.0\n",
"            })\n",
"            predicted = \" \".join(dst_processor.reverse([sampled[0]]))\n",
"            # Strip line terminators and print each part explicitly so a final line\n",
"            # without a trailing newline is laid out like every other sample.\n",
"            print(src_sentence.rstrip(\"\\r\\n\"))\n",
"            print(\" Truth: {}\".format(dst_sentence.rstrip(\"\\r\\n\")))\n",
"            print(\" Predicted: {}\".format(predicted))\n",
"            print()\n",
"\n",
"# Sanity-check decoding on the first sentence pairs of the training files.\n",
"with sv.managed_session() as sess:\n",
"    print_sample(\n",
"        sess,\n",
"        sample_id,\n",
"        inferring_input_placeholders,\n",
"        src_data_path=TRAIN_SRC_DATA,\n",
"        dst_data_path=TRAIN_DST_DATA)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /tmp/work/kftt-param/model.ckpt-142870\n",
"INFO:tensorflow:Starting standard services.\n",
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n",
"INFO:tensorflow:Starting queue runners.\n",
"Infobox Buddhist\n",
" Truth: Infobox Buddhist\n",
" Predicted: <UNK> <UNK> <EOS>\n",
"\n",
"道元 ( どうげん ) は 、 鎌倉 時代 初期 の 禅僧 。\n",
" Truth: Dogen was a Zen monk in the early Kamakura period .\n",
" Predicted: Dogen was a priest in the Kamakura period . <EOS>\n",
"\n",
"曹洞 宗 の 開祖 。\n",
" Truth: The founder of Soto Zen\n",
" Predicted: He was the founder of the Soto sect . <EOS>\n",
"\n",
"晩年 に 希 玄 と い う 異称 も 用い た 。\n",
" Truth: Later in his life he also went by the name Kigen .\n",
" Predicted: In the later years , he was also called <UNK> . <EOS>\n",
"\n",
"同宗旨 で は 高祖 と 尊称 さ れ る 。\n",
" Truth: Within the sect he is referred to by the honorary title Koso .\n",
" Predicted: In the <UNK> , the title of the title is given the title of the title of the title . <EOS>\n",
"\n",
"諡 は 、 仏性 伝 東 国師 、 承陽 大師 _ ( 僧 ) 。\n",
" Truth: Posthumously named Bussho Dento Kokushi , or Joyo-Daishi .\n",
" Predicted: His shigo ( posthumous name ) was <UNK> <UNK> , <UNK> ( <UNK> ) , and <UNK> ( a Buddhist priest ) . <EOS>\n",
"\n",
"一般 に は 道元 禅師 と 呼 ば れ る 。\n",
" Truth: He is generally called Dogen Zenji .\n",
" Predicted: He is generally called Dogen Zenji . <EOS>\n",
"\n",
"日本 に 歯磨き 洗面 、 食事 の 際 の 作法 や 掃除 の 習慣 を 広め た と い わ れ る 。\n",
" Truth: He is reputed to have been the one that spread the practices of tooth brushing , face washing , table manners and cleaning in Japan .\n",
" Predicted: It is said that the custom of <UNK> and <UNK> was spread in Japan and the manners of eating <UNK> and manners . <EOS>\n",
"\n",
"最初 に モウ ソウチク ( 孟宗 竹 ) を 持ち帰 っ た と する 説 も あ る 。\n",
" Truth: Another story has it that he was the first one to bring Moso-chiku ( Moso bamboo ) to Japan .\n",
" Predicted: There is a theory that the first <UNK> was <UNK> <UNK> ( <UNK> <UNK> ) . <EOS>\n",
"\n",
"道元 の 出生 に は 不明 の 点 が 多 い が 、 内 大臣 土御門 通親 ( 源 通親 あるいは 久我 通親 ) の 嫡流 に 生まれ た と する 点 で は 諸説 が 一致 し て い る 。\n",
" Truth: Though some points are unclear about Dogen 's birth , all accounts agree that he was born in the line of Udaijin ( Minister of the Right ) Michichika TSUCHIMIKADO ( MINAMOTO no Michichika or Michichika KOGA ) .\n",
" Predicted: There are many theories that the origin of Dogen was unknown , and the theory of the <UNK> was <UNK> by the <UNK> ( Michichika TSUCHIMIKADO ) , and the <UNK> of the <UNK> . <EOS>\n",
"\n"
]
}
],
"source": [
"# Same qualitative check on the first sentence pairs of the test files.\n",
"with sv.managed_session() as sess:\n",
"    print_sample(\n",
"        sess,\n",
"        sample_id,\n",
"        inferring_input_placeholders,\n",
"        src_data_path=TEST_SRC_DATA,\n",
"        dst_data_path=TEST_DST_DATA)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"65013 / 3436"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"15000 / 3436"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment