"""
Usage: python run_semantic_segmentation.py
References:
* https://www.tensorflow.org/tutorials/images/segmentation#define_the_model
* https://huggingface.co/docs/transformers/main/en/tasks/semantic_segmentation
"""

from typing import Dict, Tuple

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import backend

from transformers import AutoImageProcessor, TFSegformerForSemanticSegmentation
from transformers.keras_callbacks import PushToHubCallback

IMAGE_SIZE = 512
MEAN = tf.constant([0.485, 0.456, 0.406])
STD = tf.constant([0.229, 0.224, 0.225])
BATCH_SIZE = 4
AUTO = tf.data.AUTOTUNE
MODEL_CKPT = "nvidia/mit-b0"
LR = 0.00006
EPOCHS = 10


def load_dataset():
    dataset, info = tfds.load("oxford_iiit_pet:3.*.*", with_info=True)
    return dataset, info


def normalize(input_image, input_mask) -> Tuple[tf.Tensor, tf.Tensor]:
    # `tf.image.resize` in `load_image` returns float32 values in [0, 255], so scale
    # to [0, 1] explicitly before applying the ImageNet mean/std normalization
    # (`convert_image_dtype` is a no-op on an already-float tensor).
    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_image = (input_image - MEAN) / tf.maximum(STD, backend.epsilon())
    # The pet trimap labels are {1, 2, 3}; shift them to {0, 1, 2}.
    input_mask -= 1
    return input_image, input_mask


def load_image(datapoint) -> Dict[str, tf.Tensor]:
    input_image = tf.image.resize(datapoint["image"], (IMAGE_SIZE, IMAGE_SIZE))
    # Resize the mask with nearest-neighbor interpolation so label values stay integral.
    input_mask = tf.image.resize(
        datapoint["segmentation_mask"],
        (IMAGE_SIZE, IMAGE_SIZE),
        method="nearest",
    )
    input_image, input_mask = normalize(input_image, input_mask)
    # SegFormer expects channels-first inputs: (channels, height, width).
    input_image = tf.transpose(input_image, (2, 0, 1))
    return {"pixel_values": input_image, "labels": tf.squeeze(input_mask)}


def prepare_datasets() -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    dataset, _ = load_dataset()

    train_ds = (
        dataset["train"]
        .cache()
        .shuffle(BATCH_SIZE * 10)
        .map(load_image, num_parallel_calls=AUTO)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )
    test_ds = (
        dataset["test"]
        .map(load_image, num_parallel_calls=AUTO)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )
    print("Datasets prepared.")
    return train_ds, test_ds
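

# A small sanity-check sketch (illustrative only; `inspect_batch` is an assumed
# helper, not part of the original script). One batch from `prepare_datasets()`
# should yield channels-first `pixel_values` of shape
# (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE) and `labels` of shape
# (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE).
def inspect_batch(ds: tf.data.Dataset) -> None:
    batch = next(iter(ds))
    print("pixel_values:", batch["pixel_values"].shape)
    print("labels:", batch["labels"].shape)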


def load_and_compile_model() -> tf.keras.Model:
    id2label = {0: "outer", 1: "inner", 2: "border"}
    label2id = {label: id for id, label in id2label.items()}
    num_labels = len(id2label)

    model = TFSegformerForSemanticSegmentation.from_pretrained(
        MODEL_CKPT,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,  # Ensures the segmentation-specific head is reinitialized.
    )
    # No loss is passed to `compile`: the model computes its loss internally
    # whenever `labels` are present in the inputs.
    optimizer = tf.keras.optimizers.Adam(learning_rate=LR)
    model.compile(optimizer=optimizer)
    print("Model initialized and compiled.")
    return model


def prepare_callbacks(output_dir: str, dataset=None):
    image_processor = AutoImageProcessor.from_pretrained(MODEL_CKPT)
    model_name = MODEL_CKPT.split("/")[-1]
    push_to_hub_model_id = f"{model_name}-finetuned-pets"

    push_to_hub_callback = PushToHubCallback(
        output_dir=output_dir,
        hub_model_id=push_to_hub_model_id,
        tokenizer=image_processor,
    )
    print("Callbacks prepared.")
    return [push_to_hub_callback]


def train():
    train_ds, test_ds = prepare_datasets()
    callbacks = prepare_callbacks(output_dir="finetuned-pets")
    model = load_and_compile_model()

    history = model.fit(
        train_ds,
        validation_data=test_ds,
        callbacks=callbacks,
        epochs=EPOCHS,
    )
    print("Model trained.")
    return model, history
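

# A minimal inference sketch (illustrative only; `predict_masks` is an assumed
# helper, not part of the original script). `TFSegformerForSemanticSegmentation`
# returns logits of shape (batch, num_labels, height / 4, width / 4), so they are
# upsampled back to the input resolution before taking the argmax.
def predict_masks(model: tf.keras.Model, batch: Dict[str, tf.Tensor]) -> tf.Tensor:
    outputs = model(pixel_values=batch["pixel_values"], training=False)
    # Logits are channels-first; move channels last so `tf.image.resize` can upsample them.
    logits = tf.transpose(outputs.logits, (0, 2, 3, 1))
    upsampled = tf.image.resize(logits, (IMAGE_SIZE, IMAGE_SIZE), method="bilinear")
    return tf.argmax(upsampled, axis=-1)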


if __name__ == "__main__":
    _, _ = train()