sequence to sequence with attention in Keras
import tensorflow as tf
import os
import numpy as np
# settings for GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.9
sess = tf.Session(config=config)
# register the configured session with Keras so the GPU options take effect
tf.keras.backend.set_session(sess)
# hyper-parameters
batch_size = 256
epochs = 30
latent_dim = 256
embedding_size = 128
num_samples = 10000
pad_token = "<PAD>"
data_path = "fra-eng/fra.txt"
input_texts = []
target_texts = []
input_words = set()
target_words = set()
with open(data_path, "r", encoding="utf-8") as f:
    lines = f.read().split("\n")

for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text = line.split("\t")
    # <GO> marks the start of a target sequence, <EOS> marks its end
    target_text = "<GO> " + target_text + " <EOS>"
    input_texts.append(input_text)
    target_texts.append(target_text)
    # collect the set of words for each language
    for word in input_text.split():
        if word not in input_words:
            input_words.add(word)
    for word in target_text.split():
        if word not in target_words:
            target_words.add(word)
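# Quick illustrative check (an added sketch, not part of the original data prep):
# print the first parsed pair to confirm the <GO>/<EOS> markers were added.
print("example pair:", input_texts[0], "->", target_texts[0])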
input_words = sorted(list(input_words))
target_words = sorted(list(target_words))
num_encoder_tokens = len(input_words)
num_decoder_tokens = len(target_words)
# sequence lengths are measured in words, since this is a word-level model
max_encoder_seq_length = max(len(txt.split()) for txt in input_texts)
max_decoder_seq_length = max(len(txt.split()) for txt in target_texts)
print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)
# token2idx dictionary
input_token_index = {word: i for i, word in enumerate(input_words, start=1)}
target_token_index = {word: i for i, word in enumerate(target_words, start=1)}
# index 0 is reserved for the pad token in both vocabularies
input_token_index[pad_token] = 0
target_token_index[pad_token] = 0
# construct zero numpy array
encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length),
                              dtype=np.float32)
decoder_input_data = np.zeros((len(target_texts), max_decoder_seq_length),
                              dtype=np.float32)
decoder_target_data = np.zeros((len(target_texts), max_decoder_seq_length, num_decoder_tokens + 1),
                               dtype=np.float32)
# fill in the zero arrays with token indices
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, word in enumerate(input_text.split()):
        encoder_input_data[i, t] = input_token_index[word]
    for t, word in enumerate(target_text.split()):
        decoder_input_data[i, t] = target_token_index[word]
        # the decoder target is one time step ahead of the decoder input
        if t > 0:
            decoder_target_data[i, t - 1, target_token_index[word]] = 1
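# Illustrative sanity check of the one-step shift (an added sketch, not part of
# the original preprocessing): for the first sample, the one-hot target at step
# t-1 should correspond to the decoder input token at step t.
_first_target_words = target_texts[0].split()
assert np.argmax(decoder_target_data[0, 0]) == target_token_index[_first_target_words[1]]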
# encoder parts
encoder_inputs = tf.keras.Input(shape=[None], name="encoder_inputs")
# 0 for pad_token so total input dim is V + 1
encoder_embedding = tf.keras.layers.Embedding(input_dim=num_encoder_tokens + 1,
                                              output_dim=embedding_size,
                                              mask_zero=True,
                                              name="encoder_embedding")
encoder_embedded = encoder_embedding(encoder_inputs)
encoder_lstm = tf.keras.layers.LSTM(latent_dim,
                                    return_state=True,
                                    return_sequences=True)
# bi-directional lstm
encoder = tf.keras.layers.Bidirectional(encoder_lstm)
encoder_outputs, fw_state_h, bw_state_h, fw_state_c, bw_state_c = encoder(encoder_embedded)
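# With the default merge mode ("concat"), encoder_outputs has shape
# [batch, t, 2 * latent_dim]; the forward/backward h and c states are returned
# separately and are not fed to the decoder in this gist.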
decoder_inputs = tf.keras.Input(shape=[None])
decoder_embedding = tf.keras.layers.Embedding(input_dim=num_decoder_tokens + 1,
                                              output_dim=embedding_size,
                                              mask_zero=True)
decoder_embedded = decoder_embedding(decoder_inputs)
decoder_lstm = tf.keras.layers.LSTM(latent_dim,
                                    return_sequences=True,
                                    return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_embedded)
# additive attention: e^t_i = v^T tanh(W [h^enc_i ; h^dec_t] + b)
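# The remaining attention steps, written out for reference (a sketch of what the
# Lambda layer below computes):
#   alpha^t  = softmax(e^t)               over encoder positions i
#   c^t      = sum_i alpha^t_i * h^enc_i  (context vector)
#   output_t = [h^dec_t ; c^t]            which feeds the final softmax layer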
# projection layers for the attention scores, defined once so the same weights
# are reused every time the attention Lambda is applied
attention_w = tf.keras.layers.Dense(latent_dim, activation=tf.nn.tanh, use_bias=True)
attention_v = tf.keras.layers.Dense(1)

def attention(inputs):
    # inputs: [encoder_outputs, decoder_outputs]
    # encoder_outputs: [b, t, 2d], decoder_outputs: [b, k, d]
    encoder_outputs = inputs[0]
    decoder_outputs = inputs[1]
    encoder_length = tf.shape(encoder_outputs)[1]
    decoder_length = tf.shape(decoder_outputs)[1]
    # encoder_hiddens: [b, t, 2d] -> [b, k, t, 2d]
    # decoder_hiddens: [b, k, d]  -> [b, k, t, d]
    encoder_hiddens = tf.tile(tf.expand_dims(encoder_outputs, axis=1), [1, decoder_length, 1, 1])
    decoder_hiddens = tf.tile(tf.expand_dims(decoder_outputs, axis=2), [1, 1, encoder_length, 1])
    # hidden_input: [b, k, t, 3d]
    hidden_input = tf.concat([encoder_hiddens, decoder_hiddens], axis=-1)
    # w := tanh(W[h_enc; h_dec] + b)
    attention_hidden = attention_w(hidden_input)
    # v^T dot w -> [b, k, t]
    attention_score = attention_v(attention_hidden)
    attention_score = tf.squeeze(attention_score, axis=-1)
    # attention masks derived from the hidden states
    encoder_mask = tf.sign(tf.abs(tf.reduce_sum(encoder_outputs, axis=2)))  # [b, t]
    decoder_mask = tf.sign(tf.abs(tf.reduce_sum(decoder_outputs, axis=2)))  # [b, k]
    query_masks = tf.expand_dims(decoder_mask, 2)               # [b, k, 1]
    query_masks = tf.tile(query_masks, [1, 1, encoder_length])  # [b, k, t]
    paddings = tf.ones_like(attention_score) * (-2 ** 32 + 1)
    attention_score = tf.where(tf.equal(query_masks, 0), paddings, attention_score)
    attention_score = tf.nn.softmax(attention_score, axis=-1)
    key_mask = tf.expand_dims(encoder_mask, 1)                  # [b, 1, t]
    attention_score *= key_mask
    # [b, k, t] x [b, t, 2d] -> [b, k, 2d]
    context_vectors = tf.matmul(attention_score, encoder_outputs)
    # concatenate each decoder state with its context vector: [b, k, 3d]
    augmented_outputs = tf.concat([decoder_outputs, context_vectors], axis=-1)
    return augmented_outputs
attention_layer = tf.keras.layers.Lambda(attention)
decoder_outputs = attention_layer([encoder_outputs, decoder_outputs])
decoder_dense = tf.keras.layers.Dense(num_decoder_tokens + 1, activation=tf.nn.softmax)
decoder_outputs = decoder_dense(decoder_outputs)
model = tf.keras.Model(inputs=[encoder_inputs, decoder_inputs],
                       outputs=decoder_outputs)
model.summary()
model.compile(optimizer=tf.train.RMSPropOptimizer(1e-3),
              loss=tf.keras.losses.categorical_crossentropy)
model.fit([encoder_input_data, decoder_input_data],
          decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
model.save("s2s.h5")
# Next: inference mode (sampling).
# Here's the drill:
# 1) encode input and retrieve initial decoder state
# 2) run one step of decoder with this initial state
# and a "start of sequence" token as target.
# Output will be the next target token
# 3) Repeat with the current target token and current states
# Define sampling models
encoder_model = tf.keras.Model(inputs=encoder_inputs,
                               outputs=encoder_outputs)
# to compute context vectors, we need encoder_outputs for each decoder time step
encoder_hidden_outputs = tf.keras.Input(shape=(None, 2 * latent_dim))
decoder_state_input_h = tf.keras.Input(shape=(latent_dim,))
decoder_state_input_c = tf.keras.Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_embedded = decoder_embedding(decoder_inputs)
decoder_outputs, state_h, state_c = decoder_lstm(decoder_embedded,
                                                 initial_state=decoder_states_inputs)
decoder_outputs = attention_layer([encoder_hidden_outputs, decoder_outputs])
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = tf.keras.Model(inputs=[decoder_inputs, encoder_hidden_outputs] + decoder_states_inputs,
                               outputs=[decoder_outputs] + decoder_states)
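# Rough shape sketch for the sampling models (illustrative, assuming batch size 1):
#   encoder_model: (1, t_enc) token ids -> (1, t_enc, 2 * latent_dim) hidden states
#   decoder_model: [(1, 1) token id, (1, t_enc, 2 * latent_dim), (1, latent_dim) h, (1, latent_dim) c]
#                  -> [(1, 1, num_decoder_tokens + 1) probabilities, new h, new c]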
reverse_input_word_index = {i: word for word, i in input_token_index.items()}
reverse_target_word_index = {i: word for word, i in target_token_index.items()}
def decode_sequence(input_seq):
    encoder_outputs = encoder_model.predict(input_seq)
    # for the first decoder time step the initial LSTM state is a zero vector
    # (batch size is 1 at inference time)
    zeros = np.zeros((1, latent_dim))
    states_value = [zeros, zeros]
    # the decoder consumes one token index per step, starting with <GO>
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = target_token_index["<GO>"]
    decoded_sentence = ""
    stop_condition = False
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict([target_seq, encoder_outputs] + states_value)
        # greedy sampling: pick the most likely token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_word = reverse_target_word_index[sampled_token_index]
        decoded_sentence += " " + sampled_word
        # stop on <EOS> or when the maximum target length is exceeded
        if sampled_word == "<EOS>" or len(decoded_sentence.split()) > max_decoder_seq_length:
            stop_condition = True
        # feed the sampled token and the updated states back in
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = target_token_index[sampled_word]
        states_value = [h, c]
    return decoded_sentence
with open("example.txt", "w", encoding="utf-8") as f:
for seq_index in range(100):
input_seq = encoder_input_data[seq_index: seq_index + 1]
decoded_sentence = decode_sequence(input_seq)
print("-")
input_text = input_texts[seq_index]
f.write(input_text + "\t" + decoded_sentence + "\n")
print("input sentence:", input_texts[seq_index])
print("decoded sentence:", decoded_sentence)