@burnpiro
Last active May 13, 2022 13:48
TensorFlow 2 custom dataset Sequence
# training snippet ("model" is assumed to be an already compiled tf.keras.Model)
import tensorflow as tf

from data.data_generator import DataGenerator
from config import cfg

## Create train dataset
train_datagen = DataGenerator(file_path=cfg.TRAIN.DATA_PATH, config_path=cfg.TRAIN.ANNOTATION_PATH)

## Create validation dataset
val_generator = DataGenerator(file_path=cfg.VAL.DATA_PATH, config_path=cfg.VAL.ANNOTATION_PATH, debug=False)

model.fit_generator(generator=train_datagen,
                    validation_data=val_generator,
                    epochs=cfg.TRAIN.EPOCHS,
                    callbacks=[],  # your callbacks for TF
                    shuffle=True,
                    verbose=1)
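Note: fit_generator is deprecated in TF 2.x; model.fit accepts a tf.keras.utils.Sequence directly on recent versions. An equivalent call with the modern API (same arguments, just a sketch) would be:

model.fit(train_datagen,
          validation_data=val_generator,
          epochs=cfg.TRAIN.EPOCHS,
          callbacks=[],  # your callbacks for TF
          shuffle=True,
          verbose=1)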
# config.py
from easydict import EasyDict

__C = EasyDict()
cfg = __C

# create NN dict
__C.NN = EasyDict()
__C.NN.INPUT_SIZE = 224
__C.NN.GRID_SIZE = 7  # label grid resolution, referenced by the generator below; value assumed (224 / 32 = 7 for a MobileNetV2 backbone)

# create Train options dict
__C.TRAIN = EasyDict()
__C.TRAIN.DATA_PATH = "./data/WIDER_train/images/"
__C.TRAIN.ANNOTATION_PATH = "./data/wider_face_split/wider_face_train_bbx_gt.txt"
__C.TRAIN.BATCH_SIZE = 16
__C.TRAIN.EPOCHS = 10  # referenced by the training snippet above; value assumed

# create Val options dict
__C.VAL = EasyDict()
__C.VAL.DATA_PATH = "./data/WIDER_val/images/"
__C.VAL.ANNOTATION_PATH = "./data/wider_face_split/wider_face_val_bbx_gt.txt"
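GRID_SIZE above is an assumed value tied to the backbone: MobileNetV2 (whose preprocess_input the generator uses) downsamples a 224x224 input by a total stride of 32, giving a 7x7 feature map. A quick sanity check of that assumption:

from config import cfg

# backbone stride of 32 assumed (MobileNetV2)
assert cfg.NN.INPUT_SIZE // 32 == cfg.NN.GRID_SIZE  # 224 // 32 == 7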
# data/data_generator.py
import os
import sys
import math

import numpy as np
import tensorflow as tf

from config import cfg


# Input: [x0, y0, w, h, blur, expression, illumination, invalid, occlusion, pose]
# Output: x0, y0, w, h
def get_box(data):
    x0 = int(data[0])
    y0 = int(data[1])
    w = int(data[2])
    h = int(data[3])
    return x0, y0, w, h
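# For reference, each record in wider_face_*_bbx_gt.txt looks like this
# (file name, number of faces, then one line per face in the format parsed
# above; an image with no faces gets a count of 0 followed by a single
# all-zero line, which is why __init__ below reads one extra line in that case):
#
#   0--Parade/0_Parade_marchingband_1_849.jpg
#   1
#   449 330 122 149 0 0 0 0 0 0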
class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, file_path, config_path, debug=False):
        self.boxes = []
        self.debug = debug
        self.data_path = file_path
        if not os.path.isfile(config_path):
            print("File path {} does not exist. Exiting...".format(config_path))
            sys.exit()
        if not os.path.isdir(file_path):
            print("Images folder path {} does not exist. Exiting...".format(file_path))
            sys.exit()
        with open(config_path) as fp:
            image_name = fp.readline()
            cnt = 1
            while image_name:
                num_of_obj = int(fp.readline())
                for i in range(num_of_obj):
                    obj_box = fp.readline().split(' ')
                    x0, y0, w, h = get_box(obj_box)
                    if w == 0:
                        # skip boxes with no width
                        continue
                    if h == 0:
                        # skip boxes with no height
                        continue
                    self.boxes.append((image_name.strip(), x0, y0, w, h))
                if num_of_obj == 0:
                    # images without faces still carry one all-zero annotation line
                    obj_box = fp.readline().split(' ')
                    x0, y0, w, h = get_box(obj_box)
                    self.boxes.append((image_name.strip(), x0, y0, w, h))
                image_name = fp.readline()
                cnt += 1

    def __len__(self):
        return math.ceil(len(self.boxes) / cfg.TRAIN.BATCH_SIZE)

    def __getitem__(self, idx):
        boxes = self.boxes[idx * cfg.TRAIN.BATCH_SIZE:(idx + 1) * cfg.TRAIN.BATCH_SIZE]

        batch_images = np.zeros((len(boxes), cfg.NN.INPUT_SIZE, cfg.NN.INPUT_SIZE, 3), dtype=np.float32)
        batch_boxes = np.zeros((len(boxes), cfg.NN.GRID_SIZE, cfg.NN.GRID_SIZE, 5), dtype=np.float32)
        for i, row in enumerate(boxes):
            path, x0, y0, w, h = row

            # load once to read the original size, then reload resized to the network input
            proc_image = tf.keras.preprocessing.image.load_img(self.data_path + path)
            image_width = proc_image.width
            image_height = proc_image.height
            proc_image = tf.keras.preprocessing.image.load_img(self.data_path + path,
                                                               target_size=(cfg.NN.INPUT_SIZE, cfg.NN.INPUT_SIZE))

            proc_image = tf.keras.preprocessing.image.img_to_array(proc_image)
            proc_image = np.expand_dims(proc_image, axis=0)
            proc_image = tf.keras.applications.mobilenet_v2.preprocess_input(proc_image)
            batch_images[i] = proc_image

            # make sure none of the points is out of the image border
            x0 = max(x0, 0)
            y0 = max(y0, 0)
            x0 = min(x0, image_width - 1)   # -1 keeps the grid index below in range
            y0 = min(y0, image_height - 1)

            x_c = (cfg.NN.GRID_SIZE / image_width) * x0
            y_c = (cfg.NN.GRID_SIZE / image_height) * y0

            floor_x = math.floor(x_c)  # handle the case when x is on the border
            floor_y = math.floor(y_c)  # handle the case when y is on the border

            # cell layout: [h, w, y offset, x offset, objectness]
            batch_boxes[i, floor_y, floor_x, 0] = h / image_height
            batch_boxes[i, floor_y, floor_x, 1] = w / image_width
            batch_boxes[i, floor_y, floor_x, 2] = y_c - floor_y
            batch_boxes[i, floor_y, floor_x, 3] = x_c - floor_x
            batch_boxes[i, floor_y, floor_x, 4] = 1
        return batch_images, batch_boxes
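A minimal smoke test for the generator, assuming the WIDER files from config.py are present locally:

if __name__ == "__main__":
    gen = DataGenerator(file_path=cfg.TRAIN.DATA_PATH, config_path=cfg.TRAIN.ANNOTATION_PATH)
    print(len(gen))      # number of batches per epoch
    images, labels = gen[0]
    print(images.shape)  # (16, 224, 224, 3)
    print(labels.shape)  # (16, 7, 7, 5) -> [h, w, y offset, x offset, objectness] per cell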
@felattaoui commented Nov 22, 2020

Hello, I tried to implement this but I do not understand why you add "line.strip()" inside the append:
self.boxes.append((line.strip(), x0, y0, w, h))
It seems that this "line" variable does not exist.
Moreover, I think the variable image_name is not updated at the end of the while loop.

Thank you for your help

@burnpiro (Author)

> Hello, I tried to implement this but I do not understand why you add "line.strip()" inside the append:
> self.boxes.append((line.strip(), x0, y0, w, h))
> It seems that this "line" variable does not exist.

Hi, you're completely right, this variable was supposed to be image_name. I was copying the code from my generator and decided to replace line with something more meaningful. .strip() is there just in case someone left an empty space, but you can probably just remove the stripping.

Thanks for pointing that out 👍

@felattaoui

Thank you very much for your quick answer.
Can you confirm that the value of image_name should also be updated in the code?
Do you have another article on transfer learning with MobileNetV2 based on the data generator created in this tutorial?
I followed this one step by step:
https://erdem.pl/2019/12/how-to-create-tensorflow-2-sequence-dataset-from-scratch

@burnpiro (Author)

Yes, image_name should be updated at the end of each loop. The updated code should work fine now. I don't know of any more articles, but if you want a working example (hopefully TF didn't change the API) you can check the same code with the rest of the network:
https://github.com/burnpiro/tiny-face-detection-tensorflow2
I'm working with a similar structure right now, using a siamese network for image embeddings:
https://github.com/burnpiro/farm-animal-tracking

It might take you a while to start this project because the dataset has to be generated from the original dataset, but if you don't want to train it and just use it, you can get the weights from Firestore and run the siamese network. Just remember that the network is designed to be an embedder, not a detector, so it only maps images into a 2048-dimensional space at the end. That's useful for our purposes (see the embedding visualization) but it's not exactly object detection (for that we're using the object_detection lib and resnet52).

@felattaoui

Thank you very much for your help!

@felattaoui

Hello,
I hope you are doing well.
I downloaded the app you developed (https://github.com/burnpiro/tiny-face-detection-tensorflow2), the data, and so on.
I trained a model without modifying the code and I get a very poor val_iou value, so the algorithm is not able to detect faces.

Do you have some local changes that you did not push?
