# -*- coding: utf-8 -*-
# Gist by @KentaKudo, created February 18, 2018
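"""Sequence-to-sequence machine translation (English to Japanese) with Keras.

Trains an LSTM encoder-decoder on the Tanaka corpus with teacher forcing,
then builds separate inference models and decodes greedily, token by token.
"""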
import numpy as np
from sklearn.model_selection import train_test_split
from keras.models import Model
from keras.layers import Input, Embedding, Dense, LSTM
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
def load_dataset(file_path):
    """Read one sentence per line, wrap it with BOS/EOS markers, and tokenize."""
    tokenizer = Tokenizer(filters="")  # keep punctuation and the <s>/</s> markers intact
    texts = []
    with open(file_path, 'r') as f:
        for line in f:
            texts.append("<s> " + line.strip() + " </s>")
    tokenizer.fit_on_texts(texts)
    return tokenizer.texts_to_sequences(texts), tokenizer
def decode_sequence(input_seq):
    """Greedily decode one encoded sentence, token by token, until EOS.

    Relies on the globals encoder_model, decoder_model and tokenizer_j
    defined further down.
    """
    # Encode the source sentence into the decoder's initial states.
    states_value = encoder_model.predict(input_seq)
    bos_eos = tokenizer_j.texts_to_sequences(["<s>", "</s>"])
    target_seq = np.array([bos_eos[0]])  # shape (1, 1): start from the BOS token
    output_seq = bos_eos[0]
    while True:
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
        # Greedy search: take the most probable next token.
        sampled_token_index = [np.argmax(output_tokens[0, -1, :])]
        output_seq += sampled_token_index
        # Stop at EOS, or once the output grows unreasonably long.
        if sampled_token_index == bos_eos[1] or len(output_seq) > 1000:
            break
        # Feed the sampled token and the updated LSTM states back in.
        target_seq = np.array([sampled_token_index])
        states_value = [h, c]
    return output_seq
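# Load the parallel Tanaka corpus: English (e) and Japanese (j), one sentence
# per line, then hold out 2% of the pairs for testing.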
train_X, tokenizer_e = load_dataset('tanaka_corpus_e.txt')
train_Y, tokenizer_j = load_dataset('tanaka_corpus_j.txt')
train_X, test_X, train_Y, test_Y = train_test_split(train_X, train_Y, test_size=0.02, random_state=42)
# Zero-pad every sequence at the end, up to the length of the longest one.
train_X = pad_sequences(train_X, padding='post')
train_Y = pad_sequences(train_Y, padding='post')
seqX_len = len(train_X[0])
seqY_len = len(train_Y[0])
word_num_e = len(tokenizer_e.word_index) + 1  # +1: index 0 is reserved for padding
word_num_j = len(tokenizer_j.word_index) + 1
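# Training model: the encoder reads the source sentence; the decoder is
# trained with teacher forcing, receiving the target sequence as input.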
emb_dim = 256
hid_dim = 256
encoder_inputs = Input(shape=(seqX_len,))
# mask_zero=True makes the LSTM skip the zero padding.
encoder_embedded = Embedding(word_num_e, emb_dim, mask_zero=True)(encoder_inputs)
encoder = LSTM(hid_dim, return_state=True)
_, state_h, state_c = encoder(encoder_embedded)  # keep only the final states
encoder_states = [state_h, state_c]
decoder_inputs = Input(shape=(seqY_len,))
# Layer objects are kept so the trained weights can be reused at inference time.
decoder_embedding = Embedding(word_num_j, emb_dim)
decoder_embedded = decoder_embedding(decoder_inputs)
decoder = LSTM(hid_dim, return_sequences=True, return_state=True)
# The decoder starts from the encoder's final states.
decoder_outputs, _, _ = decoder(decoder_embedded, initial_state=encoder_states)
decoder_dense = Dense(word_num_j, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy')
# Targets are train_Y shifted left by one step, zero-padded at the end.
decoder_target_data = np.hstack((train_Y[:, 1:], np.zeros((len(train_Y), 1), dtype=np.int32)))
model.fit([train_X, train_Y], np.expand_dims(decoder_target_data, -1),
          batch_size=128, epochs=1, verbose=2, validation_split=0.2)
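# Inference reuses the trained layers but rewires them: the decoder now runs
# one step at a time, taking the previous token and states as inputs.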
# The inference encoder maps an input sequence to its final LSTM states.
encoder_model = Model(encoder_inputs, encoder_states)
decoder_state_input_h = Input(shape=(hid_dim,))  # previous hidden state h (the encoder's at step 1)
decoder_state_input_c = Input(shape=(hid_dim,))  # previous cell state c (the encoder's at step 1)
decoder_state_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_inputs = Input(shape=(1,))  # one token per decoding step
decoder_embedded = decoder_embedding(decoder_inputs)
decoder_outputs, state_h, state_c = decoder(
    decoder_embedded, initial_state=decoder_state_inputs
)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_state_inputs,
    [decoder_outputs] + decoder_states
)
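# Invert the word indices so token ids can be mapped back to words.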
detokenizer_e = dict(map(reversed, tokenizer_e.word_index.items()))
detokenizer_j = dict(map(reversed, tokenizer_j.word_index.items()))
detokenizer_j[0] = ''  # hack: the padding index 0 is sometimes emitted, so map it to ''
# Translate the first held-out sentence and compare with the reference.
input_seq = pad_sequences([test_X[0]], seqX_len, padding='post')
print(' '.join([detokenizer_e[i] for i in test_X[0]]))                   # source
print(' '.join([detokenizer_j[i] for i in decode_sequence(input_seq)]))  # model output
print(' '.join([detokenizer_j[i] for i in test_Y[0]]))                   # reference
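# A minimal sketch (not part of the original gist): decode a few more held-out
# sentences the same way. Assumes test_X holds at least four sentences.
for i in range(1, 4):
    seq = pad_sequences([test_X[i]], seqX_len, padding='post')
    print('>', ' '.join(detokenizer_e[t] for t in test_X[i]))
    print('<', ' '.join(detokenizer_j[t] for t in decode_sequence(seq)))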