Created
December 4, 2018 09:54
-
-
Save NeuroWhAI/30ff6189e5f1f38e5f582de07129fcb0 to your computer and use it in GitHub Desktop.
(Keras) Seq2Seq with Attention!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras import layers, models | |
from keras import datasets | |
from keras import backend as K | |
from keras.utils import plot_model | |
import matplotlib | |
from matplotlib import ticker | |
import matplotlib.pyplot as plt | |
import numpy as np | |
#from IPython.display import Image, display | |
def get_data(file_id):
    """
    Fetch the Google Drive file identified by ``file_id`` and return its
    contents as a string.

    Needed to run this script on Colab: the Colab user is authenticated
    first and those credentials are handed to PyDrive.
    """
    # The PyDrive / Colab helpers are imported lazily so that the module can
    # be loaded without them; authentication happens on every call.
    from pydrive.auth import GoogleAuth
    from pydrive.drive import GoogleDrive
    from google.colab import auth
    from oauth2client.client import GoogleCredentials

    # Authenticate the Colab user and build the PyDrive client.
    auth.authenticate_user()
    google_auth = GoogleAuth()
    google_auth.credentials = GoogleCredentials.get_application_default()
    drive_client = GoogleDrive(google_auth)

    # Look the file up by its Drive file ID
    # (a file ID looks like: laggVyWshwcyP6kEI-y_W3P8D26sz).
    return drive_client.CreateFile({'id': file_id}).GetContentString()
# Training hyper-parameters.
batch_size = 100
epochs = 24
latent_dim = 256
# Upper bound on the number of corpus lines to read; replaced further down by
# the number of sentence pairs actually loaded.
num_samples = 4096

# Sentence vectorization.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()

# Google Drive file ids passed to get_data(), one per corpus
# (keys look like language codes -- TODO confirm).
data_set = {
    'kor': '12UL7MH38rxRFskhYuj1B9zU7wUXgv8Q7',
    'jpn': '1ULdUnurb_DEDTzJFaycdWHWkYWqPu4Bo',
    'deu': '1OhP8vPwTGrLyweo0HuLhhGpe0VRWn4k6',
}

lines = get_data(data_set['jpn']).split('\n')
for line in lines[:num_samples]:
    fields = line.rstrip('\r').split('\t')
    if len(fields) < 2:
        # Blank or malformed line (e.g. the empty string left after the
        # final newline): skip it instead of failing on tuple unpacking.
        continue
    # Only the first two columns are used; any extra column is ignored.
    input_text, target_text = fields[0], fields[1]
    # target_text = input_text  # temporary
    # "\t" is used as the start character and "\n" as the end character.
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    # Grow the character vocabularies.
    input_characters.update(input_text)
    target_characters.update(target_text)

if not input_texts:
    raise ValueError('no usable sentence pairs were found in the data set')

# Number of training samples actually loaded.
num_samples = len(input_texts)
input_characters = sorted(input_characters)
target_characters = sorted(target_characters)
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max(len(txt) for txt in input_texts)
max_decoder_seq_length = max(len(txt) for txt in target_texts)

print('Number of samples:', num_samples)
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)

# Character -> index lookup tables.
input_token_index = {char: i for i, char in enumerate(input_characters)}
target_token_index = {char: i for i, char in enumerate(target_characters)}

# 3-D arrays (sample, timestep, one-hot character) holding the training data.
encoder_input_data = np.zeros(
    (num_samples, max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_input_data = np.zeros(
    (num_samples, max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
decoder_target_data = np.zeros(
    (num_samples, max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')

# One-hot encode every sentence character by character.
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    for t, char in enumerate(target_text):
        decoder_input_data[i, t, target_token_index[char]] = 1.
        if t > 0:
            # The target is the decoder input shifted one step to the left,
            # so it does not contain the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.

# Index -> character lookup tables.
reverse_input_char_index = {i: char for char, i in input_token_index.items()}
reverse_target_char_index = {i: char for char, i in target_token_index.items()}
def RepeatVectorLayer(rep, axis):
    """
    Build a Lambda layer that inserts a new dimension at ``axis`` and repeats
    the tensor ``rep`` times along it.

    The second callable tells Keras the resulting shape: the input shape with
    ``rep`` inserted at position ``axis``.
    """
    def repeat(x):
        expanded = K.expand_dims(x, axis)
        return K.repeat_elements(expanded, rep, axis)

    def repeated_shape(x):
        return tuple((x[0],) + x[1:axis] + (rep,) + x[axis:])

    return layers.Lambda(repeat, output_shape=repeated_shape)
# Shape legend for the comments below:
#   Te = max_encoder_seq_length, Td = max_decoder_seq_length, H = latent_dim.
# Layers are created in the same order as before so that auto-generated layer
# names, and therefore previously saved weight files, stay compatible.

# Encoder: a GRU returning every timestep plus its final state.
encoder_inputs = layers.Input(shape=(max_encoder_seq_length, num_encoder_tokens))
encoder = layers.GRU(latent_dim, return_sequences=True, return_state=True)
encoder_outputs, state_h = encoder(encoder_inputs)

# Decoder: a GRU initialised with the encoder's final state.
decoder_inputs = layers.Input(shape=(max_decoder_seq_length, num_decoder_tokens))
decoder = layers.GRU(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _ = decoder(decoder_inputs, initial_state=state_h)

# Attention mechanism.
# Pair every decoder step with every encoder step: both become (batch, Td, Te, H).
repeat_d = RepeatVectorLayer(max_encoder_seq_length, 2)(decoder_outputs)
repeat_e = RepeatVectorLayer(max_decoder_seq_length, 1)(encoder_outputs)
concat_for_score = layers.Concatenate(axis=-1)([repeat_d, repeat_e])

# Two dense layers turn each (decoder step, encoder step) pair into a scalar score.
dense1_t_score_layer = layers.Dense(latent_dim // 2, activation='tanh')
dense1_score = layers.TimeDistributed(dense1_t_score_layer)(concat_for_score)
dense2_t_score_layer = layers.Dense(1)
dense2_score = layers.TimeDistributed(dense2_t_score_layer)(dense1_score)
dense2_score = layers.Reshape(
    (max_decoder_seq_length, max_encoder_seq_length))(dense2_score)

# Normalise the scores over the encoder steps. The layer object keeps its
# own name because the attention visualisation reads its output later.
softmax_score_layer = layers.Softmax(axis=-1)
softmax_score = softmax_score_layer(dense2_score)

# Weight the encoder outputs by the scores and sum over the encoder steps,
# giving one context vector per decoder step: (batch, Td, H).
repeat_score = RepeatVectorLayer(latent_dim, 2)(softmax_score)
permute_e = layers.Permute((2, 1))(encoder_outputs)
repeat_e = RepeatVectorLayer(max_decoder_seq_length, 1)(permute_e)
attended_mat = layers.Multiply()([repeat_score, repeat_e])
context = layers.Lambda(
    lambda x: K.sum(x, axis=-1),
    output_shape=lambda x: tuple(x[:-1]))(attended_mat)

# Combine the context with the decoder output and project onto the vocabulary.
concat_context = layers.Concatenate(axis=-1)([context, decoder_outputs])
attention_dense_output_layer = layers.Dense(latent_dim, activation='tanh')
attention_output = layers.TimeDistributed(
    attention_dense_output_layer)(concat_context)
decoder_dense = layers.Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(attention_output)

# Training model.
model = models.Model([encoder_inputs, decoder_inputs], decoder_outputs)
# Optionally restore weights saved by an earlier run.
choice = input("Load weights?")
if choice in ('y', 'Y'):
    model.load_weights('att_seq2seq_weights.h5')

model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['acc'])
#model.summary()
#plot_model(model, show_shapes=True, to_file='model.png')
#display(Image(filename='model.png'))

choice = input("Train?")
if choice in ('y', 'Y'):
    # Train, holding out 20% of the samples for validation.
    history = model.fit(
        [encoder_input_data, decoder_input_data],
        decoder_target_data,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.2,
        verbose=2)
    model.save_weights('att_seq2seq_weights.h5')

    # One figure per metric: loss first, then accuracy.
    logs = history.history
    for train_key, val_key, train_label, val_label in (
            ('loss', 'val_loss', 'train loss', 'val loss'),
            ('acc', 'val_acc', 'train acc', 'val acc')):
        plt.plot(logs[train_key], 'y', label=train_label)
        plt.plot(logs[val_key], 'r', label=val_label)
        plt.legend(loc='upper left')
        plt.show()
# Attention check: plot the attention weights for the longest input sentence.
# Pick the first sample whose input sentence is the longest.
test_data_num = 0
test_max_len = 0
for i, s in enumerate(input_texts):
    if len(s) > test_max_len:
        test_max_len = len(s)
        test_data_num = i

# Unpadded lengths of the chosen pair (the target includes '\t' and '\n').
test_input_len = len(input_texts[test_data_num])
test_target_len = len(target_texts[test_data_num])

test_enc_input = encoder_input_data[test_data_num].reshape(
    (1, max_encoder_seq_length, num_encoder_tokens))
test_dec_input = decoder_input_data[test_data_num].reshape(
    (1, max_decoder_seq_length, num_decoder_tokens))

# Read the output of the attention softmax directly from the training graph.
attention_layer = softmax_score_layer
func = K.function([encoder_inputs, decoder_inputs] + [K.learning_phase()],
                  [attention_layer.output])
# Learning phase 0 = inference mode: the weights are only inspected here.
score_values = func([test_enc_input, test_dec_input, 0])[0]
score_values = score_values.reshape((max_decoder_seq_length, max_encoder_seq_length))
# Drop the padded rows/columns. There is one row per predicted character,
# i.e. one fewer than the target length because '\t' is never predicted.
score_values = score_values[:test_target_len - 1, :test_input_len]

fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(score_values, interpolation='nearest')
fig.colorbar(cax)

# Recover the characters from the one-hot rows. Padded (all-zero) rows decode
# to index 0, so only the unpadded prefix of each list is meaningful.
test_enc_names = [reverse_input_char_index[np.argmax(vec)]
                  for vec in test_enc_input[0]]
test_dec_names = [reverse_target_char_index[np.argmax(vec)]
                  for vec in test_dec_input[0]]
print(test_dec_names[1:test_target_len])

ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
# The leading '' is the label of the tick the locator places just before
# index 0. Rows are labelled with the character predicted at that step; the
# final row predicts '\n', shown as '<END>'. The slice stops at the real
# target length so that '<END>' lands on that row even when this target is
# shorter than the padded maximum.
ax.set_yticklabels([''] + test_dec_names[1:test_target_len - 1] + ['<END>'])
ax.set_xticklabels([''] + test_enc_names[:test_input_len])
plt.show()
# Inference (test).
# Build the inference models. They reuse the trained layers (decoder GRU,
# scoring dense layers, attention output dense and the final softmax dense),
# but the decoder is wired to process a single timestep per call.
encoder_model = models.Model(encoder_inputs, [encoder_outputs, state_h])

# Inputs of the single-step decoder: all encoder outputs, the previous
# character (one-hot) and the previous decoder state.
encoder_outputs_input = layers.Input(shape=(max_encoder_seq_length, latent_dim))
decoder_inputs = layers.Input(shape=(1, num_decoder_tokens))
decoder_state_input_h = layers.Input(shape=(latent_dim,))
decoder_outputs, decoder_h = decoder(
    decoder_inputs, initial_state=decoder_state_input_h)

# Attention scores between the single decoder step and every encoder step.
repeat_d = RepeatVectorLayer(max_encoder_seq_length, 2)(decoder_outputs)
repeat_e = RepeatVectorLayer(1, axis=1)(encoder_outputs_input)
concat_for_score = layers.Concatenate(axis=-1)([repeat_d, repeat_e])
dense1_score = layers.TimeDistributed(dense1_t_score_layer)(concat_for_score)
dense2_score = layers.TimeDistributed(dense2_t_score_layer)(dense1_score)
dense2_score = layers.Reshape((1, max_encoder_seq_length))(dense2_score)
softmax_score_layer = layers.Softmax(axis=-1)
softmax_score = softmax_score_layer(dense2_score)

# Context vector: encoder outputs weighted by the scores, summed over the
# encoder steps.
repeat_score = RepeatVectorLayer(latent_dim, 2)(softmax_score)
permute_e = layers.Permute((2, 1))(encoder_outputs_input)
repeat_e = RepeatVectorLayer(1, axis=1)(permute_e)
attended_mat = layers.Multiply()([repeat_score, repeat_e])
context = layers.Lambda(
    lambda x: K.sum(x, axis=-1),
    output_shape=lambda x: tuple(x[:-1]))(attended_mat)

# Merge the context with the decoder output and predict the next character.
concat_context = layers.Concatenate(axis=-1)([context, decoder_outputs])
attention_output = layers.TimeDistributed(
    attention_dense_output_layer)(concat_context)
decoder_att_outputs = decoder_dense(attention_output)

decoder_model = models.Model(
    [decoder_inputs, decoder_state_input_h, encoder_outputs_input],
    [decoder_outputs, decoder_h, decoder_att_outputs])
#decoder_model.summary()
#plot_model(decoder_model, show_shapes=True, to_file='decoder_model.png')
#display(Image(filename='decoder_model.png'))
def decode_sequence(input_seq):
    """
    Greedily decode one encoded input sentence into a target string.

    ``input_seq`` is passed to ``encoder_model.predict``; the decoder is then
    run one character at a time, feeding back the most probable character
    and the new state, until the end character '\\n' is produced or the
    output grows longer than ``max_decoder_seq_length``. The returned string
    includes the end character when one was produced.
    """
    # Encode the input sentence.
    enc_outputs, states_value = encoder_model.predict(input_seq)

    # The first decoder input is the start character '\t'.
    next_token_index = target_token_index['\t']
    decoded_chars = []
    while True:
        # One-hot encode the single character that is fed to the decoder.
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, next_token_index] = 1.

        # Previous character + state -> next character distribution + state.
        _, states_value, output_tokens = decoder_model.predict(
            [target_seq, states_value, enc_outputs])

        # Take the most probable character and map it back through the
        # index -> character table.
        next_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[next_token_index]
        decoded_chars.append(sampled_char)

        # Stop on the end character or once the length limit is exceeded.
        if sampled_char == '\n' or len(decoded_chars) > max_decoder_seq_length:
            return ''.join(decoded_chars)
# Decode the first few training sentences as a sanity check. The count is
# capped at the number of loaded samples so that a small data set does not
# run past the end of input_texts.
for seq_index in range(min(30, len(input_texts))):
    # Slice rather than index to keep the leading batch dimension of 1.
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('"{}" -> "{}"'.format(input_texts[seq_index], decoded_sentence.strip()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment