Skip to content

Instantly share code, notes, and snippets.

@akash-ch2812
Last active July 24, 2020 08:37
Show Gist options
  • Save akash-ch2812/4ff866b55bff52001ed308c543f7ba41 to your computer and use it in GitHub Desktop.
Save akash-ch2812/4ff866b55bff52001ed308c543f7ba41 to your computer and use it in GitHub Desktop.
from itertools import islice
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from PIL import Image
%matplotlib inline
def _pad_pre(seq, maxlen):
    """Left-pad / left-truncate one id sequence into an int32 array of shape (1, maxlen).

    Mirrors the defaults of Keras ``pad_sequences`` (padding='pre',
    truncating='pre', value=0, dtype='int32') for a single sequence.
    """
    padded = np.zeros((1, maxlen), dtype='int32')
    # Guard maxlen == 0: seq[-0:] would be the whole sequence.
    tail = seq[-maxlen:] if maxlen > 0 else []
    if tail:
        padded[0, -len(tail):] = tail
    return padded


# method for generating captions
def generate_captions(model, image, word_index, max_caption_length, index_word):
    """Greedily decode a caption for one image, one word per step.

    Starting from the '<start>' token, the partial caption is encoded with
    ``word_index``, left-padded to ``max_caption_length`` and passed to
    ``model.predict`` together with ``image``. The highest-scoring id is
    mapped back to a word through ``index_word`` and appended. Decoding stops
    when '<end>' is produced, when the predicted id has no word (e.g. the
    padding id 0), or after ``max_caption_length`` steps.

    Args:
        model: object exposing a Keras-style ``predict([image, sequence], verbose=0)``.
        image: image features passed unchanged to ``model.predict`` (the caller
            in this file adds a leading batch axis with ``np.expand_dims``).
        word_index: mapping word -> integer id (called with ``tokenizer.word_index``).
        max_caption_length: padded input length and maximum number of decoding steps.
        index_word: mapping integer id -> word (called with ``tokenizer.index_word``).

    Returns:
        The generated caption with the leading '<start>' and, if present, the
        trailing '<end>' removed.
    """
    # input is <start>
    input_text = '<start>'
    # keep generating words till we have encountered <end>
    for _ in range(max_caption_length):
        # Words missing from the vocabulary are skipped, as before.
        seq = [word_index[w] for w in input_text.split() if w in word_index]
        padded = _pad_pre(seq, max_caption_length)
        prediction = model.predict([image, padded], verbose=0)
        # int(): argmax returns a numpy integer; the mapping is keyed by ints.
        word = index_word.get(int(np.argmax(prediction)))
        if word is None:
            # The predicted id has no word (e.g. padding id 0); previously
            # this raised KeyError. Stop decoding instead.
            break
        input_text += ' ' + word
        if word == '<end>':
            break
    # remove <start> and (only if it was generated) <end>. The old
    # unconditional [1:-1] dropped a real word when no <end> was produced.
    words = input_text.split()[1:]
    if words and words[-1] == '<end>':
        words = words[:-1]
    return ' '.join(words)
# traverse through testing images to generate captions
# Number of test images to caption and display.
NUM_PREVIEW_IMAGES = 3
for key, test_image in islice(test_image_features.items(), NUM_PREVIEW_IMAGES):
    # Add a leading batch axis before handing the features to the model.
    test_image = np.expand_dims(test_image, axis=0)
    final_caption = generate_captions(
        predictive_model,
        test_image,
        tokenizer.word_index,
        max_caption_len,
        tokenizer.index_word,
    )
    plt.figure(figsize=(7, 7))
    # Close the image file after plotting; matplotlib converts a PIL image
    # to an array inside imshow, so the handle is no longer needed after it.
    with Image.open(Path(image_path) / f"{key}.jpg") as image:
        plt.imshow(image)
    plt.title(final_caption)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment