Skip to content

Instantly share code, notes, and snippets.

@lukas
Created October 30, 2019 18:35
Show Gist options
  • Save lukas/5c5e5a7b57343ed5fb2cfaac3f40f3f0 to your computer and use it in GitHub Desktop.
Save lukas/5c5e5a7b57343ed5fb2cfaac3f40f3f0 to your computer and use it in GitHub Desktop.
# Keras' imdb.load_data() offsets every raw word rank by 3 (index_from=3) so the
# low ids stay free for special markers; apply the same shift so ids coming out
# of the dataset line up with this vocabulary.
word_to_id = {word: rank + 3 for word, rank in imdb.get_word_index().items()}

# Inverse lookup (id -> word), then overlay the reserved low ids with the
# strings used when rendering decoded text.
id_to_word = {token_id: word for word, token_id in word_to_id.items()}
id_to_word.update({
    0: "",   # padding
    1: "",   # start-of-sequence marker
    2: "�",  # out-of-vocabulary word
    3: "",   # labelled "End token" upstream; load_data's defaults only use 1 (start) and 2 (OOV)
})
def decode(word, vocab=None):
    """Turn a sequence of word ids back into space-separated text.

    Args:
        word: iterable of integer token ids. Despite the name, this is a whole
            encoded sequence, not a single word.
        vocab: optional id -> word mapping. Defaults to the module-level
            ``id_to_word`` table.

    Returns:
        The decoded string. Ids <= 0 (padding) are dropped, and so are ids that
        map to an empty string (e.g. the start marker). This prevents stray
        leading or doubled spaces in the output.

    Raises:
        KeyError: if a positive id is missing from ``vocab``.
    """
    if vocab is None:
        vocab = id_to_word
    # Named token_id rather than `id` to avoid shadowing the builtin.
    tokens = (vocab[token_id] for token_id in word if token_id > 0)
    return " ".join(tok for tok in tokens if tok)
class TextLogger(tf.keras.callbacks.Callback):
    """Keras callback that logs per-example predictions to a W&B table.

    At the end of every epoch the model is run on ``inp``. A ``wandb.Table`` is
    logged under the key ``"text"`` with one row per example:
    ``[decode(input row), model prediction, expected value]``.

    Args:
        inp: inputs passed to ``self.model.predict``. Each row is also passed
            to ``decode``, so presumably a batch of integer id sequences
            (TODO confirm against callers).
        out: expected values, indexed in parallel with ``inp``.
    """

    def __init__(self, inp, out):
        # Let the Keras base class set up its own state before adding ours.
        super().__init__()
        self.inp = inp
        self.out = out

    def on_epoch_end(self, epoch, logs=None):
        """Predict on ``self.inp`` and log the decoded results table.

        Keras passes ``(epoch, logs)`` positionally. Neither is used here.
        """
        preds = self.model.predict(self.inp)
        rows = [
            [decode(sample), pred, expected]
            for sample, pred, expected in zip(self.inp, preds, self.out)
        ]
        # commit=False defers committing, so the table is recorded together
        # with the next committing wandb.log call instead of as its own step.
        wandb.log({"text": wandb.Table(rows=rows)}, commit=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment