look_left_only_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
look_left_only_mask = tf.constant([[1] * i + [0] * (seq_len - i) for i in range(seq_len)], dtype=tf.float32)
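Both lines build the same lower-triangular ("look-ahead") mask, in which position i may attend only to positions 0..i. A minimal NumPy sketch (for illustration only; the tutorial's code uses TensorFlow) shows what the mask looks like for a short sequence:

```python
import numpy as np

seq_len = 4
# Lower-triangular matrix of ones: row i has ones in columns 0..i.
# Equivalent to tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0).
mask = np.tril(np.ones((seq_len, seq_len), dtype=np.float32))
print(mask)
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
```

Note the off-by-one difference in the list-comprehension variant above: row i there has i ones, so the diagonal is excluded; whether the current position may attend to itself depends on which variant you pick.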
def call(self, query, value, mask=None):
    # query has shape (batch, query_len, model_size)
    # value has shape (batch, value_len, model_size)
    heads = []
    for i in range(self.h):
        score = tf.matmul(self.wq[i](query), self.wk[i](value), transpose_b=True)
        # Here we scale the score as described in the paper
        score /= tf.math.sqrt(tf.dtypes.cast(self.key_size, tf.float32))
        # score has shape (batch, query_len, value_len)
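The snippet stops at the scaled score; what follows in scaled dot-product attention is masking, a softmax over the value axis, and a weighted sum of the values. A self-contained NumPy sketch of that full step (this is the standard formulation from the paper, not the tutorial's exact TensorFlow code):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, mask=None):
    # q: (batch, query_len, key_size); k, v: (batch, value_len, key_size)
    score = q @ k.transpose(0, 2, 1) / np.sqrt(k.shape[-1])
    if mask is not None:
        # Push masked-out positions toward -inf so softmax zeroes them
        score += (1.0 - mask) * -1e9
    # Numerically stable softmax over the value_len axis
    weights = np.exp(score - score.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v  # (batch, query_len, key_size)

q = np.random.rand(1, 3, 8)
out = scaled_dot_product_attention(q, q, q)
print(out.shape)  # (1, 3, 8)
```

In the multi-head loop above, each head's output would be appended to `heads` and the heads concatenated and projected at the end.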
def call(self, sequence):
    embed_out = self.embedding(sequence)
    embed_out += pes[:sequence.shape[1], :]
    sub_in = embed_out
    for i in range(self.num_layers):
        sub_out = self.attention[i](sub_in, sub_in)
        sub_out = sub_in + sub_out
        sub_out = self.attention_norm[i](sub_out)
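Each encoder layer follows the attention sub-layer with a position-wise feed-forward network, using the same residual-then-normalize pattern. A NumPy sketch of that pattern (the layer names and sizes here are illustrative, not the tutorial's):

```python
import numpy as np

rng = np.random.default_rng(0)
model_size, ffn_size = 16, 32

def layer_norm(x, eps=1e-6):
    # Normalize over the feature axis, as Keras LayerNormalization does
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

W1 = rng.standard_normal((model_size, ffn_size)) * 0.1
W2 = rng.standard_normal((ffn_size, model_size)) * 0.1

x = rng.standard_normal((2, 5, model_size))   # attention sub-layer output
ffn_out = np.maximum(0.0, x @ W1) @ W2        # position-wise ReLU FFN
out = layer_norm(x + ffn_out)                 # residual connection + norm
print(out.shape)  # (2, 5, 16)
```

The output of one layer becomes `sub_in` for the next iteration of the loop.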
def call(self, sequence, padding_mask):
    # padding_mask has the same shape as the input sequence.
    # It will be used in the Decoder too,
    # so we need to create it outside the Encoder.
    embed_out = self.embedding(sequence)
    embed_out += pes[:sequence.shape[1], :]
    sub_in = embed_out
    for i in range(self.num_layers):
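Since `padding_mask` is created outside the Encoder, it is worth seeing what it looks like. Assuming token id 0 marks padding (the Keras `Tokenizer` default pad value), a minimal sketch:

```python
import numpy as np

# Two padded source sequences; 0 is the padding token (assumption)
sequence = np.array([[5, 29, 30, 0, 0],
                     [8, 5, 21, 22, 23]])
# 1.0 where there is a real token, 0.0 where there is padding
padding_mask = (sequence != 0).astype(np.float32)
print(padding_mask)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 1.]]
```

During attention, positions where the mask is 0 are suppressed so the model never attends to padding.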
What a ridiculous concept !
[[8, 5, 21, 22, 23]]
ce concept ridicule ! <end>
Your idea is not entirely crazy .
[[24, 25, 6, 26, 27, 28, 1]]
votre idee n est pas completement pas . . . <end>
A man s worth lies in what he is .
[[5, 29, 30, 31, 32, 9, 8, 7, 6, 1]]
la valeur d un homme reside dans ce qu il est . <end>
What he did is very wrong . |
Epoch 50 Loss 0.0141
Average elapsed time: 4.84s
I can t believe you re giving up .
[[16, 11, 12, 95, 2, 96, 14, 97, 1]]
je n arrive pas a croire que vous abandonniez . <end>
Epoch 60 Loss 0.0063
Average elapsed time: 4.66s
What a ridiculous concept !
[[8, 5, 21, 22, 23]]
Epoch 10 Loss 2.2374
Average elapsed time: 8.79s
He acted like he owned the place .
[[7, 82, 83, 7, 84, 4, 85, 1]]
l vous vous vous vous les les . <end>
Epoch 20 Loss 2.0102
Average elapsed time: 6.29s
Did you plant pumpkins this year ?
[[13, 2, 56, 57, 19, 58, 3]]
NUM_EPOCHS = 100
start_time = time.time()
for e in range(NUM_EPOCHS):
    for batch, (source_seq, target_seq_in, target_seq_out) in enumerate(dataset.take(-1)):
        loss = train_step(source_seq, target_seq_in,
                          target_seq_out)
    print('Epoch {} Loss {:.4f}'.format(
        e + 1, loss.numpy()))
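The training logs above also report "Average elapsed time" every ten epochs. One way to produce that (a sketch of the reporting pattern, not the tutorial's exact code) is to reset a timer each time the average is printed:

```python
import time

NUM_EPOCHS = 30
start_time = time.time()
logs = []
for e in range(NUM_EPOCHS):
    # ... run train_step over every batch here ...
    if (e + 1) % 10 == 0:
        # Average wall-clock seconds per epoch over the last 10 epochs
        avg_elapsed = (time.time() - start_time) / 10
        logs.append('Epoch {} Average elapsed time: {:.2f}s'.format(e + 1, avg_elapsed))
        start_time = time.time()
print(logs[-1])
```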
def predict(test_source_text=None):
    # If no test sentence is provided,
    # randomly pick one from the training data
    if test_source_text is None:
        test_source_text = raw_data_en[np.random.choice(len(raw_data_en))]
    print(test_source_text)
    # Tokenize the test sentence to obtain the source sequence
    test_source_seq = en_tokenizer.texts_to_sequences([test_source_text])
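After tokenizing the source, inference proceeds token by token: feed the decoder a start token, take the argmax of its output as the next token, append it, and repeat until an end token (this is greedy decoding, which is what the `<end>`-terminated outputs above suggest). A minimal sketch with a stand-in `next_token_logits` function (hypothetical; the real code would run the trained encoder/decoder):

```python
import numpy as np

START_TOKEN, END_TOKEN, MAX_LEN = 1, 2, 10

def next_token_logits(source_seq, decoded):
    # Stand-in for the encoder/decoder forward pass (hypothetical):
    # always favors <end> so this demo terminates immediately.
    logits = np.zeros(100)
    logits[END_TOKEN] = 1.0
    return logits

def greedy_decode(source_seq):
    decoded = [START_TOKEN]
    for _ in range(MAX_LEN):
        nxt = int(np.argmax(next_token_logits(source_seq, decoded)))
        decoded.append(nxt)
        if nxt == END_TOKEN:
            break
    return decoded

print(greedy_decode([[8, 5, 21, 22, 23]]))  # [1, 2]
```

The decoded ids are then mapped back to words with the French tokenizer's index-to-word table.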