Create Image Captioning Models
import ssl
import time
from textwrap import wrap

import matplotlib.pylab as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow.keras import Input
from tensorflow.keras.layers import (
    GRU,
    Add,
    AdditiveAttention,
    Attention,
    Concatenate,
    Dense,
    Embedding,
    LayerNormalization,
    Reshape,
    StringLookup,
    TextVectorization,
)

# Workaround for environments with broken SSL certificate stores; this disables
# certificate verification for downloads, so avoid it in production.
ssl._create_default_https_context = ssl._create_unverified_context

print(tf.version.VERSION)

# Change these to control the accuracy/speed trade-off
VOCAB_SIZE = 20000  # use fewer words to speed up convergence
ATTENTION_DIM = 512  # size of the dense layer in the attention mechanism
WORD_EMBEDDING_DIM = 128  # defined for reference; the embedding below uses ATTENTION_DIM

# InceptionResNetV2 takes (299, 299, 3) images as input
# and returns features of shape (8, 8, 1536)
FEATURE_EXTRACTOR = tf.keras.applications.inception_resnet_v2.InceptionResNetV2(
    include_top=False, weights="imagenet"
)
IMG_HEIGHT = 299
IMG_WIDTH = 299
IMG_CHANNELS = 3
FEATURES_SHAPE = (8, 8, 1536)
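
# Optional sanity check (not in the original): a dummy image of the expected
# input size should yield features matching FEATURES_SHAPE.
dummy_img = tf.zeros((1, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
print(FEATURE_EXTRACTOR(dummy_img).shape)  # expected: (1, 8, 8, 1536)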

GCS_DIR = "gs://asl-public/data/tensorflow_datasets/"
BUFFER_SIZE = 1000

def get_image_label(example):
    caption = example["captions"]["text"][0]  # only the first caption per image
    img = example["image"]
    img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH))
    img = img / 255  # scale pixel values to [0, 1]
    return {"image_tensor": img, "caption": caption}

trainds = tfds.load("coco_captions", split="train", data_dir=GCS_DIR)
trainds = trainds.map(
    get_image_label, num_parallel_calls=tf.data.AUTOTUNE
).shuffle(BUFFER_SIZE)
trainds = trainds.prefetch(buffer_size=tf.data.AUTOTUNE)
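
# Optional check (not in the original): each element should now be a dict with
# a float32 image tensor and a scalar string caption.
print(trainds.element_spec)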

f, ax = plt.subplots(1, 4, figsize=(20, 5))
for idx, data in enumerate(trainds.take(4)):
    ax[idx].imshow(data["image_tensor"].numpy())
    caption = "\n".join(wrap(data["caption"].numpy().decode("utf-8"), 30))
    ax[idx].set_title(caption)
    ax[idx].axis("off")

def add_start_end_token(data):
    start = tf.convert_to_tensor("<start>")
    end = tf.convert_to_tensor("<end>")
    data["caption"] = tf.strings.join(
        [start, data["caption"], end], separator=" "
    )
    return data
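
# Illustration with a hypothetical caption (not in the original): the function
# simply wraps the text with the special markers.
print(add_start_end_token({"caption": tf.constant("a dog runs")})["caption"])
# -> tf.Tensor(b'<start> a dog runs <end>', shape=(), dtype=string)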

trainds = trainds.map(add_start_end_token)

MAX_CAPTION_LEN = 64

# Override the default standardization of TextVectorization so that "<" and
# ">" are preserved, keeping the <start> and <end> tokens intact.
def standardize(inputs):
    inputs = tf.strings.lower(inputs)
    return tf.strings.regex_replace(
        inputs, r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]", ""
    )
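
# Quick check (not in the original): punctuation is stripped while "<" and ">"
# survive, so the <start>/<end> tokens stay intact.
print(standardize(tf.constant("Hello, <start> World! <end>")))
# -> tf.Tensor(b'hello <start> world <end>', shape=(), dtype=string)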

# Choose the most frequent words from the vocabulary & remove punctuation etc.
tokenizer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    standardize=standardize,
    output_sequence_length=MAX_CAPTION_LEN,
)
tokenizer.adapt(trainds.map(lambda x: x["caption"]))

print(tokenizer(["<start> This is a sentence <end>"]))

sample_captions = []
for d in trainds.take(5):
    sample_captions.append(d["caption"].numpy())
print(tokenizer(sample_captions))

for wordid in tokenizer([sample_captions[0]])[0]:
    print(tokenizer.get_vocabulary()[wordid], end=" ")

# Lookup table: Word -> Index
word_to_index = StringLookup(
    mask_token="", vocabulary=tokenizer.get_vocabulary()
)

# Lookup table: Index -> Word
index_to_word = StringLookup(
    mask_token="", vocabulary=tokenizer.get_vocabulary(), invert=True
)
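
# Round-trip check (not in the original): word -> index -> word should be
# stable for any in-vocabulary token; the exact index depends on the corpus.
start_id = word_to_index("<start>")
print(start_id.numpy(), index_to_word(start_id).numpy())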

BATCH_SIZE = 32

def create_ds_fn(data):
    img_tensor = data["image_tensor"]
    caption = tokenizer(data["caption"])
    # Shift the caption one token left to build the next-token target.
    target = tf.roll(caption, -1, 0)
    zeros = tf.zeros([1], dtype=tf.int64)
    target = tf.concat((target[:-1], zeros), axis=-1)
    return (img_tensor, caption), target
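
# Worked example with hypothetical token ids (not in the original): the target
# is the caption shifted one step left with a trailing pad, so the model
# learns to predict the NEXT token at every position (teacher forcing).
demo = tf.constant([3, 7, 12, 4, 0, 0], dtype=tf.int64)
print(tf.concat((tf.roll(demo, -1, 0)[:-1], tf.zeros([1], tf.int64)), -1).numpy())
# -> [ 7 12  4  0  0  0]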

batched_ds = (
    trainds.map(create_ds_fn)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

for (img, caption), label in batched_ds.take(2):
    print(f"Image shape: {img.shape}")
    print(f"Caption shape: {caption.shape}")
    print(f"Label shape: {label.shape}")
    print(caption[0])
    print(label[0])

FEATURE_EXTRACTOR.trainable = False

image_input = Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
image_features = FEATURE_EXTRACTOR(image_input)
# Flatten the 8x8 spatial grid into 64 locations the decoder can attend over.
x = Reshape((FEATURES_SHAPE[0] * FEATURES_SHAPE[1], FEATURES_SHAPE[2]))(
    image_features
)
encoder_output = Dense(ATTENTION_DIM, activation="relu")(x)

encoder = tf.keras.Model(inputs=image_input, outputs=encoder_output)
encoder.summary()
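
# Optional shape check (not in the original): per image, the encoder should
# emit 64 location vectors of size ATTENTION_DIM.
print(encoder.output_shape)  # expected: (None, 64, 512)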

word_input = Input(shape=(MAX_CAPTION_LEN,), name="words")
embed_x = Embedding(VOCAB_SIZE, ATTENTION_DIM)(word_input)
decoder_gru = GRU(
    ATTENTION_DIM,
    return_sequences=True,
    return_state=True,
)
gru_output, gru_state = decoder_gru(embed_x)
decoder_attention = Attention()
context_vector = decoder_attention([gru_output, encoder_output])
addition = Add()([gru_output, context_vector])
layer_norm = LayerNormalization(axis=-1)
layer_norm_out = layer_norm(addition)
decoder_output_dense = Dense(VOCAB_SIZE)
decoder_output = decoder_output_dense(layer_norm_out)

decoder = tf.keras.Model(
    inputs=[word_input, encoder_output], outputs=decoder_output
)
tf.keras.utils.plot_model(decoder)
decoder.summary()

image_caption_train_model = tf.keras.Model(
    inputs=[image_input, word_input], outputs=decoder_output
)

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)

def loss_function(real, pred):
    loss_ = loss_object(real, pred)
    # mask is 1 for real word indices and 0 for padding
    # (e.g. [1, 1, 1, 1, 1, 0, 0, ..., 0])
    mask = tf.cast(tf.math.logical_not(tf.math.equal(real, 0)), loss_.dtype)
    # Zero out the loss at padded positions and average over real tokens only.
    loss_ *= mask
    return tf.reduce_sum(loss_, axis=1) / tf.reduce_sum(mask, axis=1)
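
# Minimal sketch with toy shapes (not in the original): with two padded
# positions out of four, only the two real tokens should enter the average.
real_demo = tf.constant([[5, 9, 0, 0]], dtype=tf.int64)
pred_demo = tf.random.uniform((1, 4, VOCAB_SIZE))
print(loss_function(real_demo, pred_demo))  # shape (1,): per-example mean loss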

image_caption_train_model.compile(
    optimizer="adam",
    loss=loss_function,
)

# %%time
history = image_caption_train_model.fit(batched_ds, epochs=1)

gru_state_input = Input(shape=(ATTENTION_DIM,), name="gru_state_input")
# Reuse the trained GRU, but update it so that it can receive states.
gru_output, gru_state = decoder_gru(embed_x, initial_state=gru_state_input)

# Reuse the other trained layers as well
context_vector = decoder_attention([gru_output, encoder_output])
addition_output = Add()([gru_output, context_vector])
layer_norm_output = layer_norm(addition_output)
decoder_output = decoder_output_dense(layer_norm_output)

# Define the prediction model with state input and output
decoder_pred_model = tf.keras.Model(
    inputs=[word_input, gru_state_input, encoder_output],
    outputs=[decoder_output, gru_state],
)
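
# Optional smoke test (not in the original): a single <start> token, a zero
# GRU state, and dummy features should produce one step of vocabulary logits
# plus the updated state. Keras warns that the sequence length differs from
# MAX_CAPTION_LEN, but the GRU handles variable lengths.
demo_features = tf.zeros((1, FEATURES_SHAPE[0] * FEATURES_SHAPE[1], ATTENTION_DIM))
demo_token = tf.expand_dims([word_to_index("<start>")], 1)
logits, state = decoder_pred_model(
    [demo_token, tf.zeros((1, ATTENTION_DIM)), demo_features]
)
print(logits.shape, state.shape)  # (1, 1, VOCAB_SIZE) (1, ATTENTION_DIM)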

MINIMUM_SENTENCE_LENGTH = 5

## Probabilistic prediction using the trained model
def predict_caption(filename):
    gru_state = tf.zeros((1, ATTENTION_DIM))

    img = tf.image.decode_jpeg(tf.io.read_file(filename), channels=IMG_CHANNELS)
    img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH))
    img = img / 255

    features = encoder(tf.expand_dims(img, axis=0))
    dec_input = tf.expand_dims([word_to_index("<start>")], 1)
    result = []
    for i in range(MAX_CAPTION_LEN):
        predictions, gru_state = decoder_pred_model(
            [dec_input, gru_state, features]
        )

        # Sample the next token from the top-k logits; tf.random.categorical
        # treats them as unnormalized log-probabilities.
        top_probs, top_idxs = tf.math.top_k(
            input=predictions[0][0], k=10, sorted=False
        )
        chosen_id = tf.random.categorical([top_probs], 1)[0].numpy()
        predicted_id = top_idxs.numpy()[chosen_id][0]

        result.append(tokenizer.get_vocabulary()[predicted_id])

        if predicted_id == word_to_index("<end>"):
            return img, result

        dec_input = tf.expand_dims([predicted_id], 1)

    return img, result
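
# Hedged alternative (not in the original): greedy decoding would replace the
# top-k sampling above with an argmax over the logits, which is deterministic
# but tends to produce repetitive captions:
#     predicted_id = int(tf.argmax(predictions[0][0]).numpy())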

filename = "./baseball.jpeg"  # you can also try surf.jpeg
for i in range(5):
    image, caption = predict_caption(filename)
    print(" ".join(caption[:-1]) + ".")