Last active
September 8, 2017 22:53
-
-
Save p-baleine/6a710a591549e66b1146d182e3baeef9 to your computer and use it in GitHub Desktop.
seq2seq new API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import os\n", | |
"import six\n", | |
"import tensorflow as tf\n", | |
"import time\n", | |
"\n", | |
"from tensorflow.python.layers import core as core_layers\n", | |
"from tensorflow.python.util import nest" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"SPECIAL_SYMBOLS = [\"<PAD>\", \"<GO>\", \"<EOS>\", \"<UNK>\"]\n", | |
"\n", | |
"def tokenizer(raw_sentences):\n", | |
" for sentence in raw_sentences:\n", | |
" yield sentence.split()\n", | |
"\n", | |
"class Vocabulary(tf.contrib.learn.preprocessing.CategoricalVocabulary):\n", | |
" def __init__(self):\n", | |
" super(Vocabulary, self).__init__(unknown_token=\"<UNK>\")\n", | |
" self._mapping = dict((w, id) for id, w in enumerate(SPECIAL_SYMBOLS))\n", | |
" self._reverse_mapping = SPECIAL_SYMBOLS[:]\n", | |
" \n", | |
" def get(self, category):\n", | |
" if category not in self._mapping and self._freeze:\n", | |
" return SPECIAL_SYMBOLS.index(\"<UNK>\")\n", | |
" else:\n", | |
" return super(Vocabulary, self).get(category)\n", | |
"\n", | |
" def trim(self, min_frequency, max_frequency=-1):\n", | |
" self._freq = sorted(\n", | |
" sorted(\n", | |
" six.iteritems(self._freq),\n", | |
" key=lambda x: (isinstance(x[0], str), x[0])),\n", | |
" key=lambda x: x[1],\n", | |
" reverse=True)\n", | |
" self._mapping = dict((w, id) for id, w in enumerate(SPECIAL_SYMBOLS))\n", | |
" self._reverse_mapping = SPECIAL_SYMBOLS[:]\n", | |
" idx = 4 # after special symbols\n", | |
" for category, count in self._freq:\n", | |
" if max_frequency > 0 and count >= max_frequency:\n", | |
" continue\n", | |
" if count <= min_frequency:\n", | |
" break\n", | |
" self._mapping[category] = idx\n", | |
" idx += 1\n", | |
" self._reverse_mapping.append(category)\n", | |
" self._freq = dict(self._freq[:idx - 1])\n", | |
" \n", | |
"class NoPadVocabularyProcessor(tf.contrib.learn.preprocessing.VocabularyProcessor):\n", | |
" \"\"\"パディングしない版\"\"\"\n", | |
" def transform(self, raw_documents):\n", | |
" for tokens in self._tokenizer(raw_documents):\n", | |
" word_ids = np.zeros(len(tokens), np.int64)\n", | |
" for idx, token in enumerate(tokens):\n", | |
" if idx >= self.max_document_length:\n", | |
" break\n", | |
" word_ids[idx] = self.vocabulary_.get(token)\n", | |
" yield word_ids" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"合成文章([ここ](http://www.cems.uwe.ac.uk/~cjwallac/apps/tools/bnf/bnf2sent.cgi)で作った)で試す。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"sally madly loves joe \r\n", | |
"pat loves joe \r\n", | |
"fred hates sally \r\n" | |
] | |
} | |
], | |
"source": [ | |
"# Peek at the first three lines of the synthetic source corpus.\n", | |
"! head -n 3 truely_madly_deeply.txt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"SALLY LOVES BILL \r\n", | |
"FRED TRULY DEEPLY MADLY DEEPLY LOVES BILL \r\n", | |
"PAT TRULY LOVES FRED \r\n" | |
] | |
} | |
], | |
"source": [ | |
"# dstは大文字版、辞書のIDが異なるものに成るようにシャッフルしておく\n", | |
"! cat truely_madly_deeply.txt | awk '{print toupper($0)}' | shuf > truely_madly_deeply_upper.txt\n", | |
"! head -n 3 truely_madly_deeply_upper.txt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"with open(\"truely_madly_deeply.txt\") as f1,\\\n", | |
"open(\"truely_madly_deeply_upper.txt\") as f2:\n", | |
" src_processor = NoPadVocabularyProcessor(\n", | |
" max_document_length=100,\n", | |
" vocabulary=Vocabulary(),\n", | |
" tokenizer_fn=tokenizer).fit(f1)\n", | |
" dst_processor = NoPadVocabularyProcessor(\n", | |
" max_document_length=100,\n", | |
" vocabulary=Vocabulary(),\n", | |
" tokenizer_fn=tokenizer).fit(f2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"['sally madly loves joe', 'pat loves joe']\n", | |
"['SALLY MADLY LOVES JOE', 'PAT LOVES JOE']\n" | |
] | |
} | |
], | |
"source": [ | |
"print(list(src_processor.reverse(\n", | |
" src_processor.transform([\"sally madly loves joe\", \"pat loves joe\"]))))\n", | |
"print(list(dst_processor.reverse(\n", | |
" dst_processor.transform([x.upper() for x in [\"sally madly loves joe\", \"pat loves joe\"]]))))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def make_example(src_sentence_ids, dst_sentence_ids):\n", | |
" \"\"\"seq2seq用のExampleを作って返す\"\"\"\n", | |
" ex = tf.train.SequenceExample()\n", | |
" \n", | |
" target_ids = np.zeros(dst_sentence_ids.size + 1, dtype=np.int64)\n", | |
" target_ids[:-1] = dst_sentence_ids\n", | |
" target_ids[-1] = SPECIAL_SYMBOLS.index(\"<EOS>\")\n", | |
" _dst_sentence_ids = np.zeros(dst_sentence_ids.size + 1, dtype=np.int64)\n", | |
" _dst_sentence_ids[0] = SPECIAL_SYMBOLS.index(\"<GO>\")\n", | |
" _dst_sentence_ids[1:] = dst_sentence_ids\n", | |
" \n", | |
" ex.context.feature[\"src_sequence_length\"].int64_list.value.append(len(src_sentence_ids))\n", | |
" ex.context.feature[\"dst_sequence_length\"].int64_list.value.append(len(_dst_sentence_ids))\n", | |
"\n", | |
" fl_src_sentence_ids = ex.feature_lists.feature_list[\"src_sentence_ids\"]\n", | |
" fl_dst_sentence_ids = ex.feature_lists.feature_list[\"dst_sentence_ids\"]\n", | |
" fl_target_ids = ex.feature_lists.feature_list[\"target_ids\"]\n", | |
" \n", | |
" for id in src_sentence_ids:\n", | |
" fl_src_sentence_ids.feature.add().int64_list.value.append(id)\n", | |
" for id in _dst_sentence_ids:\n", | |
" fl_dst_sentence_ids.feature.add().int64_list.value.append(id)\n", | |
" for id in target_ids:\n", | |
" fl_target_ids.feature.add().int64_list.value.append(id)\n", | |
" \n", | |
" return ex\n", | |
"\n", | |
"def example_features():\n", | |
" \"\"\"seq2seqのfeature\"\"\"\n", | |
" context_features = {\n", | |
" \"src_sequence_length\": tf.FixedLenFeature([], dtype=tf.int64),\n", | |
" \"dst_sequence_length\": tf.FixedLenFeature([], dtype=tf.int64),\n", | |
" }\n", | |
"\n", | |
" sequence_features = {\n", | |
" \"src_sentence_ids\": tf.FixedLenSequenceFeature([], dtype=tf.int64),\n", | |
" \"dst_sentence_ids\": tf.FixedLenSequenceFeature([], dtype=tf.int64),\n", | |
" \"target_ids\": tf.FixedLenSequenceFeature([], dtype=tf.int64),\n", | |
" }\n", | |
"\n", | |
" return context_features, sequence_features" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"with open(\"truely_madly_deeply.txt\") as f1,\\\n", | |
"open(\"truely_madly_deeply.tfrecords\", \"wb\") as f2:\n", | |
" src_sentences = f1.readlines()\n", | |
" dst_sentences = [x.upper() for x in src_sentences] # srcと1対1で対応\n", | |
" src_sentence_ids_list = list(src_processor.transform(src_sentences))\n", | |
" dst_sentence_ids_list = list(dst_processor.transform(dst_sentences))\n", | |
" writer = tf.python_io.TFRecordWriter(f2.name)\n", | |
" for src_sentence_ids, dst_sentence_ids in zip(src_sentence_ids_list,\\\n", | |
" dst_sentence_ids_list):\n", | |
" ex = make_example(src_sentence_ids, dst_sentence_ids)\n", | |
" writer.write(ex.SerializeToString())\n", | |
" writer.close()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def inputs(batch_size, filename, num_epochs):\n", | |
" \"\"\"`batch_size`のインプットを返すinput pipeline\"\"\"\n", | |
" context_features, sequence_features = example_features()\n", | |
"\n", | |
" filename_queue = tf.train.string_input_producer([filename],\n", | |
" num_epochs=num_epochs)\n", | |
" reader = tf.TFRecordReader()\n", | |
" _, serialized_ex = reader.read(filename_queue)\n", | |
" context_parsed, sequence_parsed = tf.parse_single_sequence_example(\n", | |
" serialized=serialized_ex,\n", | |
" context_features=context_features,\n", | |
" sequence_features=sequence_features)\n", | |
" \n", | |
" # tf.train.shuffle_batchがdynamic_padをサポートしていないので\n", | |
" # 自前でshuffleする\n", | |
" # https://github.com/tensorflow/tensorflow/issues/5147#issuecomment-271086206\n", | |
" min_after_dequeue = 10000\n", | |
" capacity = min_after_dequeue + 3 * batch_size\n", | |
" inputs = {\n", | |
" **{\"encoder_input\": sequence_parsed[\"src_sentence_ids\"],\n", | |
" \"decoder_input\": sequence_parsed[\"dst_sentence_ids\"],\n", | |
" \"targets\": sequence_parsed[\"target_ids\"]},\n", | |
" **{\"encoder_sequence_length\": context_parsed[\"src_sequence_length\"],\n", | |
" \"decoder_sequence_length\": context_parsed[\"dst_sequence_length\"]}}\n", | |
" dtypes = [x.dtype for _, x in inputs.items()]\n", | |
" shapes = [x.get_shape() for _, x in inputs.items()]\n", | |
" queue = tf.RandomShuffleQueue(capacity, min_after_dequeue, dtypes,\n", | |
" names=list(inputs.keys()))\n", | |
" enqueue_op = queue.enqueue(inputs)\n", | |
" qr = tf.train.QueueRunner(queue, [enqueue_op])\n", | |
" tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, qr)\n", | |
" inputs = queue.dequeue()\n", | |
" for (name, tensor), shape in zip(inputs.items(), shapes):\n", | |
" tensor.set_shape(shape)\n", | |
" \n", | |
" batched = tf.train.batch(\n", | |
" tensors=inputs,\n", | |
" batch_size=batch_size,\n", | |
" dynamic_pad=True)\n", | |
" \n", | |
" return dict((k, tf.to_int32(t)) for k, t in batched.items())\n", | |
"\n", | |
"def inferring_inputs():\n", | |
" return {\n", | |
" \"encoder_input\": tf.placeholder(tf.int32, [None, None],\n", | |
" name=\"encoder_input\"),\n", | |
" \"encoder_sequence_length\": tf.placeholder(tf.int32, [None],\n", | |
" name=\"encoder_sequence_length\"),\n", | |
" \"decoder_input\": tf.placeholder(tf.int32, [None, None],\n", | |
" name=\"decoder_input\"),\n", | |
" \"decoder_sequence_length\": tf.placeholder(tf.int32, [None],\n", | |
" name=\"decoder_sequence_length\"),\n", | |
" \"targets\": tf.placeholder(tf.int32, [None, None],\n", | |
" name=\"targets\")\n", | |
" }" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def encode(cell, encoder_emb_input, sequence_length):\n", | |
" return tf.nn.dynamic_rnn(\n", | |
" cell,\n", | |
" encoder_emb_input,\n", | |
" sequence_length,\n", | |
" dtype=tf.float32)\n", | |
"\n", | |
"def decode(cell, helper, state, dec_vocab_size):\n", | |
" output_layer = core_layers.Dense(dec_vocab_size)\n", | |
" decoder = tf.contrib.seq2seq.BasicDecoder(\n", | |
" cell=cell,\n", | |
" helper=helper,\n", | |
" initial_state=state,\n", | |
" output_layer=output_layer)\n", | |
" return tf.contrib.seq2seq.dynamic_decode(\n", | |
" decoder=decoder,\n", | |
" impute_finished=True,\n", | |
" maximum_iterations=150)\n", | |
"\n", | |
"def loss_fn(logits, targets, weights):\n", | |
" return tf.contrib.seq2seq.sequence_loss(logits, targets, weights)\n", | |
"\n", | |
"def train(loss, learning_rate=0.01, max_gradients_norm=5.0):\n", | |
" global_step = tf.train.get_or_create_global_step()\n", | |
" optimizer = tf.train.AdamOptimizer(learning_rate)\n", | |
" params = tf.trainable_variables()\n", | |
" gradients = tf.gradients(loss, params)\n", | |
" clipped_gradients, _ = tf.clip_by_global_norm(gradients, max_gradients_norm)\n", | |
" update = optimizer.apply_gradients(\n", | |
" zip(clipped_gradients, params),\n", | |
" global_step=global_step)\n", | |
" return update, global_step\n", | |
"\n", | |
"def training_helper(decoder_emb_input, decoder_sequence_length):\n", | |
" return tf.contrib.seq2seq.TrainingHelper(\n", | |
" decoder_emb_input,\n", | |
" decoder_sequence_length)\n", | |
"\n", | |
"def infer_helper(dec_embeddings, batch_size, go_id, eos_id):\n", | |
" return tf.contrib.seq2seq.GreedyEmbeddingHelper(\n", | |
" embedding=dec_embeddings,\n", | |
" start_tokens=tf.tile([go_id], [batch_size]),\n", | |
" end_token=eos_id)\n", | |
"\n", | |
"def seq2seq(\n", | |
" encoder_input,\n", | |
" decoder_input,\n", | |
" targets,\n", | |
" encoder_sequence_length,\n", | |
" decoder_sequence_length,\n", | |
" keep_prob,\n", | |
" enc_vocab_size,\n", | |
" dec_vocab_size,\n", | |
" embed_dim,\n", | |
" num_layers,\n", | |
" num_units,\n", | |
" go_id=SPECIAL_SYMBOLS.index(\"<GO>\"),\n", | |
" eos_id=SPECIAL_SYMBOLS.index(\"<EOS>\"),\n", | |
" is_inferring=False,\n", | |
" learning_rate=0.01,\n", | |
" max_gradients_norm=5.0):\n", | |
" batch_size = tf.shape(encoder_input)[0]\n", | |
"\n", | |
" with tf.variable_scope(\"embedding\"):\n", | |
" enc_embeddings = tf.get_variable(\"enc_embedding\",\n", | |
" [enc_vocab_size, embed_dim])\n", | |
" dec_embeddings = tf.get_variable(\"dec_embedding\",\n", | |
" [dec_vocab_size, embed_dim])\n", | |
" \n", | |
" encoder_emb_input = tf.nn.embedding_lookup(\n", | |
" enc_embeddings,\n", | |
" encoder_input)\n", | |
" encoder_emb_input = tf.nn.dropout(\n", | |
" encoder_emb_input,\n", | |
" keep_prob)\n", | |
" decoder_emb_input = tf.nn.embedding_lookup(\n", | |
" dec_embeddings,\n", | |
" decoder_input)\n", | |
" decoder_emb_input = tf.nn.dropout(\n", | |
" decoder_emb_input,\n", | |
" keep_prob)\n", | |
" weights = tf.sign(tf.to_float(targets))\n", | |
" \n", | |
" def _encoder_cell():\n", | |
" cell = tf.nn.rnn_cell.GRUCell(num_units)\n", | |
" # TODO output_keep_probe, state_keep_prob\n", | |
" cell = tf.nn.rnn_cell.DropoutWrapper(\n", | |
" cell,\n", | |
" input_keep_prob=keep_prob,\n", | |
" output_keep_prob=keep_prob)\n", | |
" return cell\n", | |
"\n", | |
" encoder_outputs, state = encode(\n", | |
" cell=tf.contrib.rnn.MultiRNNCell([_encoder_cell() for _ in range(num_layers)]),\n", | |
" encoder_emb_input=encoder_emb_input,\n", | |
" sequence_length=encoder_sequence_length)\n", | |
"\n", | |
" if not is_inferring:\n", | |
" helper = training_helper(\n", | |
" decoder_emb_input,\n", | |
" decoder_sequence_length)\n", | |
" else:\n", | |
" helper = infer_helper(\n", | |
" dec_embeddings,\n", | |
" batch_size,\n", | |
" go_id, eos_id)\n", | |
" \n", | |
" def _decoder_cell():\n", | |
" attention_mechanism = tf.contrib.seq2seq.LuongAttention(\n", | |
" num_units,\n", | |
" encoder_outputs,\n", | |
" encoder_sequence_length)\n", | |
" cell = tf.nn.rnn_cell.GRUCell(num_units)\n", | |
" cell = tf.contrib.seq2seq.AttentionWrapper(\n", | |
" cell,\n", | |
" attention_mechanism,\n", | |
" attention_layer_size=num_units // 2)\n", | |
" return tf.nn.rnn_cell.DropoutWrapper(\n", | |
" cell,\n", | |
" input_keep_prob=keep_prob,\n", | |
" output_keep_prob=keep_prob)\n", | |
" \n", | |
" attn_cell = tf.contrib.rnn.MultiRNNCell([_decoder_cell() for _ in range(num_layers)])\n", | |
"\n", | |
" state = tuple(attn_s.clone(cell_state=s)\n", | |
" for s, attn_s in zip(state, attn_cell.zero_state(batch_size, tf.float32)))\n", | |
"\n", | |
" (logits, sample_id), _, _ = decode(\n", | |
" cell=attn_cell,\n", | |
" helper=helper,\n", | |
" state=state,\n", | |
" dec_vocab_size=dec_vocab_size)\n", | |
" \n", | |
" loss = loss_fn(logits, targets, weights)\n", | |
" tf.summary.scalar(\"loss\", loss)\n", | |
" train_op, _ = train(loss, learning_rate=learning_rate,\n", | |
" max_gradients_norm=max_gradients_norm)\n", | |
"\n", | |
" return (\n", | |
" loss,\n", | |
" sample_id,\n", | |
" train_op,\n", | |
" tf.summary.merge_all())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Restoring parameters from /tmp/work/model/model.ckpt-0\n", | |
"INFO:tensorflow:Starting standard services.\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/model/model.ckpt\n", | |
"INFO:tensorflow:Starting queue runners.\n", | |
"INFO:tensorflow:test_model/global_step/sec: 0\n", | |
"Step 10, loss 2.475779151916504\n", | |
"Step 20, loss 2.261063575744629\n", | |
"sally madly loves joe \n", | |
" MADLY MADLY LOVES <EOS>\n", | |
"Step 30, loss 2.249356508255005\n", | |
"Step 40, loss 1.9047452926635742\n", | |
"sally madly loves joe \n", | |
" SALLY MADLY LOVES <EOS>\n", | |
"Step 50, loss 1.8361172556877137\n", | |
"Step 60, loss 1.7297092795372009\n", | |
"sally madly loves joe \n", | |
" SALLY MADLY LOVES JOE <EOS>\n", | |
"Step 70, loss 1.6062021732330323\n", | |
"Step 80, loss 1.6975791692733764\n", | |
"sally madly loves joe \n", | |
" BILL LOVES LOVES <EOS>\n", | |
"Step 90, loss 1.7016813635826111\n", | |
"Step 100, loss 1.515825343132019\n", | |
"sally madly loves joe \n", | |
" JOE TRULY HATES JOE <EOS>\n" | |
] | |
} | |
], | |
"source": [ | |
"# Sanity check on the synthetic corpus: train a small model and print a greedy\n", | |
"# decode of the first sentence every 20 steps.\n", | |
"tf.reset_default_graph()\n", | |
"\n", | |
"# NOTE(review): data_size is not used in this cell.\n", | |
"data_size = len(src_sentences)\n", | |
"batch_size = 10\n", | |
"num_epochs = 10\n", | |
"enc_vocab_size = len(src_processor.vocabulary_)\n", | |
"dec_vocab_size = len(dst_processor.vocabulary_)\n", | |
"embed_dim = 10\n", | |
"num_layers = 2\n", | |
"num_units = 100\n", | |
"learning_rate = 0.01\n", | |
"\n", | |
"def graph(\n", | |
"        encoder_input,\n", | |
"        encoder_sequence_length,\n", | |
"        decoder_input,\n", | |
"        decoder_sequence_length,\n", | |
"        targets,\n", | |
"        keep_prob,\n", | |
"        is_inferring=False):\n", | |
"    \"\"\"Call seq2seq with this cell's hyperparameters bound; returns its 4-tuple.\"\"\"\n", | |
"    \n", | |
"    return seq2seq(\n", | |
"        encoder_input=encoder_input,\n", | |
"        decoder_input=decoder_input,\n", | |
"        targets=targets,\n", | |
"        encoder_sequence_length=encoder_sequence_length,\n", | |
"        decoder_sequence_length=decoder_sequence_length,\n", | |
"        keep_prob=keep_prob,\n", | |
"        enc_vocab_size=enc_vocab_size,\n", | |
"        dec_vocab_size=dec_vocab_size,\n", | |
"        embed_dim=embed_dim,\n", | |
"        num_layers=num_layers,\n", | |
"        num_units=num_units,\n", | |
"        learning_rate=learning_rate,\n", | |
"        is_inferring=is_inferring)\n", | |
"\n", | |
"# Dropout keep probability: fed 0.5 for training steps, 1.0 for sampling.\n", | |
"keep_prob = tf.placeholder(tf.float32)\n", | |
"\n", | |
"# Training graph: batches come from the TFRecord queue pipeline.\n", | |
"with tf.variable_scope(\"test_model\"):\n", | |
"    loss, _, train_op, _ = graph(**{\n", | |
"        **inputs(batch_size, \"truely_madly_deeply.tfrecords\", num_epochs),\n", | |
"        **{\"keep_prob\": keep_prob}})\n", | |
"\n", | |
"# Inference graph: shares the training variables (reuse=True), is fed through\n", | |
"# placeholders and decodes greedily (is_inferring=True).\n", | |
"with tf.variable_scope(\"test_model\", reuse=True):\n", | |
"    inferring_input_placeholders = inferring_inputs()\n", | |
"    _, sample_id, _, _ = graph(**{\n", | |
"        **inferring_input_placeholders,\n", | |
"        **{\"keep_prob\": keep_prob, \"is_inferring\": True}})\n", | |
"    \n", | |
"def print_sample(sess):\n", | |
"    \"\"\"Print the first training sentence and the model's greedy decode of it.\"\"\"\n", | |
"    src_ids = list(src_processor.transform([src_sentences[0]]))\n", | |
"    sampled = sess.run(sample_id, feed_dict={\n", | |
"        inferring_input_placeholders[\"encoder_input\"]: src_ids,\n", | |
"        inferring_input_placeholders[\"encoder_sequence_length\"]: [len(src_ids[0])],\n", | |
"        keep_prob: 1.0})\n", | |
"    print(src_sentences[0],\n", | |
"          \" \".join(list(dst_processor.reverse([sampled[0]]))))\n", | |
"\n", | |
"# The Supervisor restores/saves checkpoints under logdir and starts the queue\n", | |
"# runners (see the log); should_stop() presumably becomes true once num_epochs\n", | |
"# of input have been consumed.\n", | |
"sv = tf.train.Supervisor(logdir=\"/tmp/work/model\", summary_op=None)\n", | |
"with sv.managed_session() as sess:\n", | |
"    total_loss = 0.0\n", | |
"    step = 1\n", | |
"    while True:\n", | |
"        if sv.should_stop():\n", | |
"            break\n", | |
"\n", | |
"        _loss, _ = sess.run([loss, train_op], feed_dict={keep_prob: 0.5})\n", | |
"\n", | |
"        total_loss += _loss\n", | |
"        \n", | |
"        # Every 10 steps, report the mean loss over those 10 steps.\n", | |
"        if step % 10 == 0:\n", | |
"            print(\"Step {}, loss {}\".format(step, total_loss / 10))\n", | |
"            total_loss = 0.0\n", | |
"        if step % 20 == 0: \n", | |
"            print_sample(sess)\n", | |
"\n", | |
"        step += 1" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"---\n", | |
"\n", | |
"[京都フリー翻訳タスク(KFTT)](http://www.phontron.com/kftt/index-ja.html)を使ってみる。\n", | |
"\n", | |
"量が多いからtrainファイルの頭1/2で試す。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"164941\r\n" | |
] | |
} | |
], | |
"source": [ | |
"! echo $(wc /tmp/work/kftt-data-1.0/data/tok/kyoto-train.cln.ja | gawk '{print $1}') \" / 2\" | bc" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"! split -l 164941 /tmp/work/kftt-data-1.0/data/tok/kyoto-train.cln.ja\n", | |
"! mv xaa kyoto-train.cln.harf.ja\n", | |
"! split -l 164941 /tmp/work/kftt-data-1.0/data/tok/kyoto-train.cln.en\n", | |
"! mv xaa kyoto-train.cln.harf.en" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Japanese vocabylary size: 14916\n", | |
"English vocabylary size: 15485\n" | |
] | |
} | |
], | |
"source": [ | |
"TRAIN_SRC_DATA = \"/tmp/work/kyoto-train.cln.harf.ja\"\n", | |
"TRAIN_DST_DATA = \"/tmp/work/kyoto-train.cln.harf.en\"\n", | |
"TEST_SRC_DATA = \"/tmp/work/kftt-data-1.0/data/tok/kyoto-test.ja\"\n", | |
"TEST_DST_DATA = \"/tmp/work/kftt-data-1.0/data/tok/kyoto-test.en\"\n", | |
"\n", | |
"SRC_VOCAB = \"/tmp/work/vocab.ja.pickle\"\n", | |
"DST_VOCAB = \"/tmp/work/vocab.en.pickle\"\n", | |
"\n", | |
"logdir = \"/tmp/work/kftt-param\"\n", | |
"\n", | |
"max_document_length = 100\n", | |
"\n", | |
"def create_vocabulary(path):\n", | |
" with open(path) as f:\n", | |
" return NoPadVocabularyProcessor(\n", | |
" max_document_length=100,\n", | |
" min_frequency=10,\n", | |
" vocabulary=Vocabulary(),\n", | |
" tokenizer_fn=tokenizer).fit(f)\n", | |
"\n", | |
"if not os.path.exists(SRC_VOCAB):\n", | |
" src_processor = create_vocabulary(TRAIN_SRC_DATA)\n", | |
" src_processor.save(SRC_VOCAB)\n", | |
"else:\n", | |
" src_processor = NoPadVocabularyProcessor.restore(SRC_VOCAB)\n", | |
"\n", | |
"if not os.path.exists(DST_VOCAB):\n", | |
" dst_processor = create_vocabulary(TRAIN_DST_DATA)\n", | |
" dst_processor.save(DST_VOCAB)\n", | |
"else:\n", | |
" dst_processor = NoPadVocabularyProcessor.restore(DST_VOCAB)\n", | |
"\n", | |
"go_id = dst_processor.vocabulary_.get(\"<GO>\")\n", | |
"eos_id = dst_processor.vocabulary_.get(\"<EOS>\")\n", | |
"enc_vocab_size = len(src_processor.vocabulary_)\n", | |
"dec_vocab_size = len(dst_processor.vocabulary_)\n", | |
"\n", | |
"print(\"Japanese vocabylary size: {}\".format(enc_vocab_size))\n", | |
"print(\"English vocabylary size: {}\".format(dec_vocab_size))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"TRAIN_DATA_TFRECORD = \"/tmp/work/kyoto-train.cln.harf.tfrecord\"\n", | |
"TEST_DATA_TFRECORD = \"/tmp/work/kyoto-test.tfrecord\"\n", | |
"\n", | |
"def create_tfrecord(in_src_path, in_dst_path, out_path):\n", | |
" with open(in_src_path) as f1, open(in_dst_path) as f2, open(out_path, \"wb\") as f3:\n", | |
" src_sentence_ids_list = src_processor.transform(f1)\n", | |
" dst_sentence_ids_list = dst_processor.transform(f2)\n", | |
" writer = tf.python_io.TFRecordWriter(f3.name)\n", | |
" for src_sentence_ids, dst_sentence_ids in zip(src_sentence_ids_list,\\\n", | |
" dst_sentence_ids_list):\n", | |
" ex = make_example(src_sentence_ids, dst_sentence_ids)\n", | |
" writer.write(ex.SerializeToString())\n", | |
" writer.close()\n", | |
"\n", | |
"if not os.path.exists(TRAIN_DATA_TFRECORD):\n", | |
" create_tfrecord(TRAIN_SRC_DATA, TRAIN_DST_DATA, TRAIN_DATA_TFRECORD)\n", | |
"\n", | |
"if not os.path.exists(TEST_DATA_TFRECORD):\n", | |
" create_tfrecord(TEST_SRC_DATA, TEST_DST_DATA, TEST_DATA_TFRECORD)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Restoring parameters from /tmp/work/kftt-param/model.ckpt-119702\n", | |
"INFO:tensorflow:Starting standard services.\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"INFO:tensorflow:Starting queue runners.\n", | |
"Step 32000 / 164941, training loss 2.955790, test loss 3.610759, elapsed 155.859858s\n", | |
"Step 64000 / 164941, training loss 3.214088, test loss 3.303418, elapsed 154.051139s\n", | |
"Step 96000 / 164941, training loss 3.179242, test loss 3.179810, elapsed 155.419681s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 128000 / 164941, training loss 2.963753, test loss 3.454105, elapsed 155.601932s\n", | |
"Step 160000 / 164941, training loss 3.129966, test loss 3.428210, elapsed 156.284561s\n", | |
"Step 27059 / 164941, training loss 3.043962, test loss 3.286095, elapsed 155.217415s\n", | |
"Step 59059 / 164941, training loss 3.239690, test loss 3.404328, elapsed 154.764190s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 91059 / 164941, training loss 3.121370, test loss 3.437113, elapsed 155.464047s\n", | |
"Step 123059 / 164941, training loss 2.966356, test loss 3.358104, elapsed 155.541891s\n", | |
"Step 155059 / 164941, training loss 3.095634, test loss 3.266939, elapsed 156.200432s\n", | |
"Step 22118 / 164941, training loss 3.076107, test loss 3.277872, elapsed 154.762615s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 54118 / 164941, training loss 3.247306, test loss 3.299399, elapsed 155.225333s\n", | |
"Step 86118 / 164941, training loss 3.066488, test loss 3.385982, elapsed 155.067276s\n", | |
"Step 118118 / 164941, training loss 2.987248, test loss 3.351193, elapsed 155.344533s\n", | |
"Step 150118 / 164941, training loss 3.068154, test loss 3.274154, elapsed 155.232993s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 17177 / 164941, training loss 3.092844, test loss 3.673700, elapsed 155.766533s\n", | |
"Step 49177 / 164941, training loss 3.244971, test loss 3.511972, elapsed 155.381531s\n", | |
"Step 81177 / 164941, training loss 3.026749, test loss 3.442863, elapsed 155.906324s\n", | |
"Step 113177 / 164941, training loss 3.017391, test loss 3.648664, elapsed 155.843899s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 145177 / 164941, training loss 3.054037, test loss 3.330358, elapsed 155.839185s\n", | |
"Step 12236 / 164941, training loss 3.101636, test loss 3.361992, elapsed 154.582553s\n", | |
"Step 44236 / 164941, training loss 3.250136, test loss 3.398018, elapsed 155.170406s\n", | |
"Step 76236 / 164941, training loss 2.970292, test loss 3.257153, elapsed 154.940045s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 108236 / 164941, training loss 3.046880, test loss 3.396551, elapsed 156.018393s\n", | |
"Step 140236 / 164941, training loss 3.031431, test loss 3.311976, elapsed 155.616624s\n", | |
"Step 7295 / 164941, training loss 3.130935, test loss 3.317468, elapsed 155.316411s\n", | |
"Step 39295 / 164941, training loss 3.237224, test loss 3.201086, elapsed 155.159777s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 71295 / 164941, training loss 2.923450, test loss 3.576703, elapsed 155.130497s\n", | |
"Step 103295 / 164941, training loss 3.077371, test loss 3.190845, elapsed 156.320625s\n", | |
"Step 135295 / 164941, training loss 3.009074, test loss 3.435367, elapsed 155.930322s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 2354 / 164941, training loss 3.154707, test loss 3.442424, elapsed 155.266500s\n", | |
"Step 34354 / 164941, training loss 3.236584, test loss 3.706930, elapsed 154.725371s\n", | |
"Step 66354 / 164941, training loss 2.873634, test loss 3.417327, elapsed 155.217429s\n", | |
"Step 98354 / 164941, training loss 3.105311, test loss 3.387954, elapsed 156.290312s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 130354 / 164941, training loss 2.994145, test loss 3.139106, elapsed 156.059800s\n", | |
"Step 162354 / 164941, training loss 3.165069, test loss 3.503583, elapsed 154.644478s\n", | |
"Step 29413 / 164941, training loss 3.229005, test loss 3.447956, elapsed 154.803382s\n", | |
"Step 61413 / 164941, training loss 2.842835, test loss 3.323263, elapsed 155.632782s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 93413 / 164941, training loss 3.117907, test loss 3.292005, elapsed 155.991641s\n", | |
"Step 125413 / 164941, training loss 2.978633, test loss 3.525044, elapsed 156.009066s\n", | |
"Step 157413 / 164941, training loss 3.179128, test loss 3.470156, elapsed 154.802577s\n", | |
"Step 24472 / 164941, training loss 3.204044, test loss 3.392745, elapsed 154.570330s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 56472 / 164941, training loss 2.836560, test loss 3.612292, elapsed 155.624684s\n", | |
"Step 88472 / 164941, training loss 3.128556, test loss 3.367498, elapsed 155.831625s\n", | |
"Step 120472 / 164941, training loss 2.983065, test loss 3.390799, elapsed 155.163952s\n", | |
"Step 152472 / 164941, training loss 3.194086, test loss 3.197423, elapsed 154.228353s\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"Step 19531 / 164941, training loss 3.171189, test loss 3.395420, elapsed 155.323945s\n", | |
"Step 51531 / 164941, training loss 2.843658, test loss 3.587918, elapsed 155.822081s\n", | |
"Step 83531 / 164941, training loss 3.115164, test loss 3.615798, elapsed 155.848897s\n" | |
] | |
} | |
], | |
"source": [ | |
"tf.reset_default_graph()\n", | |
"\n", | |
"batch_size = 64\n", | |
"#num_epochs = 20\n", | |
"#num_epochs = 20\n", | |
"#num_epochs = 10\n", | |
"num_epochs = 10\n", | |
"embed_dim = 512\n", | |
"num_layers = 2\n", | |
"num_units = 512\n", | |
"learning_rate = 0.001\n", | |
"max_gradients_norm = 1.0\n", | |
"init_scale = 0.04\n", | |
"\n", | |
"keep_prob = tf.placeholder(tf.float32)\n", | |
"\n", | |
"def graph(\n", | |
" encoder_input,\n", | |
" encoder_sequence_length,\n", | |
" decoder_input,\n", | |
" decoder_sequence_length,\n", | |
" targets,\n", | |
" keep_prob,\n", | |
" is_inferring=False):\n", | |
" return seq2seq(\n", | |
" encoder_input=encoder_input,\n", | |
" decoder_input=decoder_input,\n", | |
" targets=targets,\n", | |
" encoder_sequence_length=encoder_sequence_length,\n", | |
" decoder_sequence_length=decoder_sequence_length,\n", | |
" keep_prob=keep_prob,\n", | |
" enc_vocab_size=enc_vocab_size,\n", | |
" dec_vocab_size=dec_vocab_size,\n", | |
" embed_dim=embed_dim,\n", | |
" num_layers=num_layers,\n", | |
" num_units=num_units,\n", | |
" learning_rate=learning_rate,\n", | |
" is_inferring=is_inferring,\n", | |
" max_gradients_norm=max_gradients_norm)\n", | |
"\n", | |
"# Uniform initializer over [-init_scale, init_scale]. minval must be negative:\n", | |
"# passing (init_scale, init_scale) collapses the range and sets every weight\n", | |
"# to the same constant.\n", | |
"initializer = tf.random_uniform_initializer(-init_scale, init_scale)\n", | |
"\n", | |
"# Training tower: creates the model variables.\n", | |
"with tf.name_scope(\"Train\"):\n", | |
"    with tf.variable_scope(\"inputs\"):\n", | |
"        training_inputs = inputs(batch_size, TRAIN_DATA_TFRECORD, num_epochs)\n", | |
"    with tf.variable_scope(\"model\", initializer=initializer):\n", | |
"        loss, _, train_op, training_summary_op = graph(\n", | |
"            **{**training_inputs, \"keep_prob\": keep_prob})\n", | |
"\n", | |
"# Test tower: reuses the model variables created by the training tower.\n", | |
"with tf.name_scope(\"Test\"):\n", | |
"    with tf.variable_scope(\"inputs\"):  # inputs keep reuse=None (separate input pipeline)\n", | |
"        test_inputs = inputs(batch_size, TEST_DATA_TFRECORD, None)\n", | |
"    with tf.variable_scope(\"model\", reuse=True, initializer=initializer):  # model weights are shared, hence reuse=True\n", | |
"        test_loss, _, _, testing_summary_op = graph(\n", | |
"            **{**test_inputs, \"keep_prob\": keep_prob})\n", | |
"\n", | |
"# Inference tower: fed from placeholders and built with is_inferring=True.\n", | |
"with tf.name_scope(\"Inferring\"):\n", | |
"    with tf.variable_scope(\"inputs\"):\n", | |
"        inferring_input_placeholders = inferring_inputs()\n", | |
"    with tf.variable_scope(\"model\", reuse=True, initializer=initializer):\n", | |
"        _, sample_id, _, _ = graph(\n", | |
"            **{**inferring_input_placeholders, \"keep_prob\": keep_prob, \"is_inferring\": True})\n", | |
"\n", | |
"sv = tf.train.Supervisor(\n", | |
"    logdir=logdir,\n", | |
"    summary_op=None,  # summaries are recorded manually with sv.summary_computed below\n", | |
"    global_step=tf.train.get_global_step())\n", | |
"\n", | |
"# NOTE(review): presumably the number of training examples; it is the denominator\n", | |
"# of the progress message below -- confirm against the dataset.\n", | |
"NUM_TRAIN_EXAMPLES = 164941\n", | |
"EVAL_EVERY = 20     # evaluate one test batch every N steps\n", | |
"REPORT_EVERY = 500  # print the averaged training loss every N steps\n", | |
"\n", | |
"with sv.managed_session() as sess:\n", | |
"    start_time = time.time()\n", | |
"    total_loss = 0.0\n", | |
"    # Defined up front so the report never references an unset test loss,\n", | |
"    # even if REPORT_EVERY is not a multiple of EVAL_EVERY.\n", | |
"    _test_loss = float(\"nan\")\n", | |
"    step = 1\n", | |
"    while not sv.should_stop():\n", | |
"        _loss, _, training_summary = sess.run(\n", | |
"            [loss, train_op, training_summary_op],\n", | |
"            feed_dict={keep_prob: 0.5})\n", | |
"\n", | |
"        if np.isnan(_loss):\n", | |
"            raise RuntimeError(\"loss is nan\")\n", | |
"\n", | |
"        sv.summary_computed(sess, training_summary)\n", | |
"        total_loss += _loss\n", | |
"\n", | |
"        if step % EVAL_EVERY == 0:\n", | |
"            _test_loss, testing_summary = sess.run(\n", | |
"                [test_loss, testing_summary_op],\n", | |
"                feed_dict={keep_prob: 1.0})\n", | |
"            sv.summary_computed(sess, testing_summary)\n", | |
"\n", | |
"        if step % REPORT_EVERY == 0:\n", | |
"            # First field: examples processed modulo NUM_TRAIN_EXAMPLES\n", | |
"            # (presumably the position within the current epoch).\n", | |
"            print(\"Step {} / {}, training loss {:f}, test loss {:f}, elapsed {:f}s\".format(\n", | |
"                (step * batch_size) % NUM_TRAIN_EXAMPLES,\n", | |
"                NUM_TRAIN_EXAMPLES,\n", | |
"                total_loss / REPORT_EVERY,\n", | |
"                _test_loss,\n", | |
"                time.time() - start_time))\n", | |
"            start_time = time.time()\n", | |
"            total_loss = 0.0\n", | |
"\n", | |
"        step += 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Restoring parameters from /tmp/work/kftt-param/model.ckpt-142870\n", | |
"INFO:tensorflow:Starting standard services.\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"INFO:tensorflow:Starting queue runners.\n", | |
"日本 の 水墨 画 を 一変 さ せ た 。\n", | |
" Truth: He revolutionized the Japanese ink painting .\n", | |
" Predicted: He changed the ink painting in Japan . <EOS>\n", | |
"\n", | |
"諱 は 「 等楊 ( とうよう ) 」 、 もしくは 「 拙宗 ( せっしゅう ) 」 と 号 し た 。\n", | |
" Truth: He was given the posthumous name \" Toyo \" or \" Sesshu ( 拙宗 ) . \"\n", | |
" Predicted: His imina ( personal name ) was <UNK> and <UNK> , and he was called ' <UNK> . ' <EOS>\n", | |
"\n", | |
"備中 国 に 生まれ 、 京都 ・ 相国 寺 に 入 っ て から 周防 国 に 移 る 。\n", | |
" Truth: Born in Bicchu Province , he moved to Suo Province after entering SShokoku-ji Temple in Kyoto .\n", | |
" Predicted: He was born in Bicchu Province , and moved to Suo Province in Kyoto . <EOS>\n", | |
"\n", | |
"その 後 遣明 使 に 随行 し て 中国 ( 明 ) に 渡 っ て 中国 の 水墨 画 を 学 ん だ 。\n", | |
" Truth: Later he accompanied a mission to Ming Dynasty China and learned Chinese ink painting .\n", | |
" Predicted: Later , he went to China and studied the Chinese ink painting in China . <EOS>\n", | |
"\n", | |
"作品 は 数 多 く 、 中国 風 の 山水 画 だけ で な く 人物 画 や 花鳥 画 も よ く し た 。\n", | |
" Truth: His works were many , including not only Chinese-style landscape paintings , but also portraits and pictures of flowers and birds .\n", | |
" Predicted: Many works were not produced , and they were painted in Chinese paintings and paintings of paintings and paintings . <EOS>\n", | |
"\n", | |
"大胆 な 構図 と 力強 い 筆線 は 非常 に 個性 的 な 画風 を 作り出 し て い る 。\n", | |
" Truth: His bold compositions and strong brush strokes constituted an extremely distinctive style .\n", | |
" Predicted: The <UNK> and bold bold bold style were characterized by <UNK> . <EOS>\n", | |
"\n", | |
"現存 する 作品 の うち 6 点 が 国宝 に 指定 さ れ て お り 、 日本 の 画家 の なか で も 別格 の 評価 を 受け て い る と いえ る 。\n", | |
" Truth: 6 of his extant works are designated national treasures . Indeed , he is considered to be extraordinary among Japanese painters .\n", | |
" Predicted: Among the existing works of existing works , the extant works are designated as national treasures , and it is recognized as a national treasure . <EOS>\n", | |
"\n", | |
"この ため 、 花鳥 図 屏風 など に 「 伝 雪舟 筆 」 さ れ る 作品 は 大変 多 い 。\n", | |
" Truth: For this reason , there are a great many artworks that are attributed to him , such as folding screens with pictures of flowers and that birds are painted on them .\n", | |
" Predicted: Therefore , many works of Sesshu are often known as ' Sesshu ' ( <UNK> painting ) and <UNK> ( folding screen ) . <EOS>\n", | |
"\n", | |
"真筆 で あ る か 専門 家 の 間 で も 意見 の 分かれ る もの も 多々 あ る 。\n", | |
" Truth: There are many works that even experts cannot agree if they are really his work or not .\n", | |
" Predicted: There are many opinions that <UNK> the <UNK> of the <UNK> . <EOS>\n", | |
"\n", | |
"弟子 に 、 秋月 、 宗 淵 、 等 春 ら が い る 。\n", | |
" Truth: His disciples include Shugetsu , Soen , and Toshun .\n", | |
" Predicted: His disciples include <UNK> , <UNK> , <UNK> , and <UNK> . <EOS>\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"def print_sample(\n", | |
"        sess,\n", | |
"        sample_id,\n", | |
"        inferring_input_placeholders,\n", | |
"        src_data_path,\n", | |
"        dst_data_path,\n", | |
"        n=10):\n", | |
"    \"\"\"Translate the first `n` lines of a parallel corpus and print the results.\n", | |
"\n", | |
"    Args:\n", | |
"        sess: session whose graph contains `sample_id`.\n", | |
"        sample_id: tensor of predicted target ids, converted back to words\n", | |
"            with `dst_processor.reverse`.\n", | |
"        inferring_input_placeholders: dict holding the `encoder_input` and\n", | |
"            `encoder_sequence_length` placeholders to feed.\n", | |
"        src_data_path: source-side text file, read one line per sentence.\n", | |
"        dst_data_path: reference file, read line-by-line in step with the source.\n", | |
"        n: maximum number of sentence pairs to print; stops early at EOF.\n", | |
"\n", | |
"    Relies on the notebook globals `src_processor`, `dst_processor` and `keep_prob`.\n", | |
"    \"\"\"\n", | |
"    # Explicit UTF-8: the corpus contains Japanese text and the platform's\n", | |
"    # default encoding is not guaranteed to be UTF-8.\n", | |
"    with open(src_data_path, encoding=\"utf-8\") as src_file:\n", | |
"        with open(dst_data_path, encoding=\"utf-8\") as dst_file:\n", | |
"            for _ in range(n):\n", | |
"                src_sentence = src_file.readline()\n", | |
"                dst_sentence = dst_file.readline()\n", | |
"                if not src_sentence or not dst_sentence:\n", | |
"                    break  # EOF on either file: nothing left to translate\n", | |
"                src_ids = list(src_processor.transform([src_sentence]))\n", | |
"                # NOTE(review): if src_processor pads to a fixed length, this\n", | |
"                # sequence length includes the padding -- confirm.\n", | |
"                sampled = sess.run(sample_id, feed_dict={\n", | |
"                    inferring_input_placeholders[\"encoder_input\"]: src_ids,\n", | |
"                    inferring_input_placeholders[\"encoder_sequence_length\"]: [len(src_ids[0])],\n", | |
"                    keep_prob: 1.0,\n", | |
"                })\n", | |
"                print(src_sentence,\n", | |
"                      \"Truth: {}\".format(dst_sentence),\n", | |
"                      \"Predicted: {}\".format(\" \".join(list(dst_processor.reverse([sampled[0]])))))\n", | |
"                print()\n", | |
"\n", | |
"with sv.managed_session() as sess:\n", | |
"    # Translations for the first lines of the training corpus.\n", | |
"    print_sample(\n", | |
"        sess=sess,\n", | |
"        sample_id=sample_id,\n", | |
"        inferring_input_placeholders=inferring_input_placeholders,\n", | |
"        src_data_path=TRAIN_SRC_DATA,\n", | |
"        dst_data_path=TRAIN_DST_DATA)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Restoring parameters from /tmp/work/kftt-param/model.ckpt-142870\n", | |
"INFO:tensorflow:Starting standard services.\n", | |
"INFO:tensorflow:Saving checkpoint to path /tmp/work/kftt-param/model.ckpt\n", | |
"INFO:tensorflow:Starting queue runners.\n", | |
"Infobox Buddhist\n", | |
" Truth: Infobox Buddhist\n", | |
" Predicted: <UNK> <UNK> <EOS>\n", | |
"\n", | |
"道元 ( どうげん ) は 、 鎌倉 時代 初期 の 禅僧 。\n", | |
" Truth: Dogen was a Zen monk in the early Kamakura period .\n", | |
" Predicted: Dogen was a priest in the Kamakura period . <EOS>\n", | |
"\n", | |
"曹洞 宗 の 開祖 。\n", | |
" Truth: The founder of Soto Zen\n", | |
" Predicted: He was the founder of the Soto sect . <EOS>\n", | |
"\n", | |
"晩年 に 希 玄 と い う 異称 も 用い た 。\n", | |
" Truth: Later in his life he also went by the name Kigen .\n", | |
" Predicted: In the later years , he was also called <UNK> . <EOS>\n", | |
"\n", | |
"同宗旨 で は 高祖 と 尊称 さ れ る 。\n", | |
" Truth: Within the sect he is referred to by the honorary title Koso .\n", | |
" Predicted: In the <UNK> , the title of the title is given the title of the title of the title . <EOS>\n", | |
"\n", | |
"諡 は 、 仏性 伝 東 国師 、 承陽 大師 _ ( 僧 ) 。\n", | |
" Truth: Posthumously named Bussho Dento Kokushi , or Joyo-Daishi .\n", | |
" Predicted: His shigo ( posthumous name ) was <UNK> <UNK> , <UNK> ( <UNK> ) , and <UNK> ( a Buddhist priest ) . <EOS>\n", | |
"\n", | |
"一般 に は 道元 禅師 と 呼 ば れ る 。\n", | |
" Truth: He is generally called Dogen Zenji .\n", | |
" Predicted: He is generally called Dogen Zenji . <EOS>\n", | |
"\n", | |
"日本 に 歯磨き 洗面 、 食事 の 際 の 作法 や 掃除 の 習慣 を 広め た と い わ れ る 。\n", | |
" Truth: He is reputed to have been the one that spread the practices of tooth brushing , face washing , table manners and cleaning in Japan .\n", | |
" Predicted: It is said that the custom of <UNK> and <UNK> was spread in Japan and the manners of eating <UNK> and manners . <EOS>\n", | |
"\n", | |
"最初 に モウ ソウチク ( 孟宗 竹 ) を 持ち帰 っ た と する 説 も あ る 。\n", | |
" Truth: Another story has it that he was the first one to bring Moso-chiku ( Moso bamboo ) to Japan .\n", | |
" Predicted: There is a theory that the first <UNK> was <UNK> <UNK> ( <UNK> <UNK> ) . <EOS>\n", | |
"\n", | |
"道元 の 出生 に は 不明 の 点 が 多 い が 、 内 大臣 土御門 通親 ( 源 通親 あるいは 久我 通親 ) の 嫡流 に 生まれ た と する 点 で は 諸説 が 一致 し て い る 。\n", | |
" Truth: Though some points are unclear about Dogen 's birth , all accounts agree that he was born in the line of Udaijin ( Minister of the Right ) Michichika TSUCHIMIKADO ( MINAMOTO no Michichika or Michichika KOGA ) .\n", | |
" Predicted: There are many theories that the origin of Dogen was unknown , and the theory of the <UNK> was <UNK> by the <UNK> ( Michichika TSUCHIMIKADO ) , and the <UNK> of the <UNK> . <EOS>\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"with sv.managed_session() as sess:\n", | |
"    # Translations for the first lines of the held-out test corpus.\n", | |
"    print_sample(\n", | |
"        sess=sess,\n", | |
"        sample_id=sample_id,\n", | |
"        inferring_input_placeholders=inferring_input_placeholders,\n", | |
"        src_data_path=TEST_SRC_DATA,\n", | |
"        dst_data_path=TEST_DST_DATA)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# NOTE(review): scratch arithmetic; the meaning of 65013 and 3436 is not stated\n", | |
"# in this notebook. Document the constants or delete this cell.\n", | |
"65013 / 3436" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# NOTE(review): scratch arithmetic; the meaning of 15000 and 3436 is not stated\n", | |
"# in this notebook. Document the constants or delete this cell.\n", | |
"15000 / 3436" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment