Skip to content

Instantly share code, notes, and snippets.

@burnpiro burnpiro/a_train.py
Last active Dec 29, 2019

Embed
What would you like to do?
TensorFlow 2 custom dataset `Sequence`
import tensorflow as tf
from data.data_generator import DataGenerator
from config import cfg
# NOTE(review): `model` is not defined in this script; it must be built and
# compiled before this point (e.g. imported from the project's model module).

## Create train dataset
train_datagen = DataGenerator(file_path=cfg.TRAIN.DATA_PATH,
                              config_path=cfg.TRAIN.ANNOTATION_PATH)
## Create validation dataset
# The config defines a VAL section; there is no cfg.TEST.
val_generator = DataGenerator(file_path=cfg.VAL.DATA_PATH,
                              config_path=cfg.VAL.ANNOTATION_PATH,
                              debug=False)

# `fit_generator` is deprecated in TF 2.x; `fit` accepts a keras Sequence
# directly, for both the training data and `validation_data`.
model.fit(train_datagen,
          validation_data=val_generator,
          epochs=cfg.TRAIN.EPOCHS,
          callbacks=[],  # add your TF callbacks here
          shuffle=True,  # for a Sequence, Keras shuffles the batch order each epoch
          verbose=1)
from easydict import EasyDict
__C = EasyDict()
# Public name: other modules do `from config import cfg`.
cfg = __C

# Network options.
__C.NN = EasyDict()
# Side length (pixels) that images are resized to before entering the network.
__C.NN.INPUT_SIZE = 224
# Cells per side of the output grid. DataGenerator builds targets of shape
# (batch, GRID_SIZE, GRID_SIZE, 5), so this must match the model's output.
# NOTE(review): INPUT_SIZE // 32 (7 for 224) assumes a stride-32 backbone such
# as MobileNetV2 -- confirm against the actual model.
__C.NN.GRID_SIZE = __C.NN.INPUT_SIZE // 32

# Training options.
__C.TRAIN = EasyDict()
__C.TRAIN.DATA_PATH = "./data/WIDER_train/images/"
__C.TRAIN.ANNOTATION_PATH = "./data/wider_face_split/wider_face_train_bbx_gt.txt"
# DataGenerator reads this batch size for every split, validation included.
__C.TRAIN.BATCH_SIZE = 16
# Read by the training script. NOTE(review): placeholder value -- tune as needed.
__C.TRAIN.EPOCHS = 10

# Validation options.
__C.VAL = EasyDict()
__C.VAL.DATA_PATH = "./data/WIDER_val/images/"
__C.VAL.ANNOTATION_PATH = "./data/wider_face_split/wider_face_val_bbx_gt.txt"
import os
import sys
import math
import numpy as np
import tensorflow as tf
from config import cfg
def get_box(data):
    """Return the integer bounding box stored at the front of an annotation row.

    ``data`` is one already-split annotation line, laid out as
    ``[x0, y0, w, h, blur, expression, illumination, invalid, occlusion, pose]``.
    Only the first four fields are read; anything after them is ignored.

    Returns:
        A ``(x0, y0, w, h)`` tuple of ints.
    """
    # Convert the fields one at a time, in order, so a malformed or short row
    # fails at the same field it would with explicit per-field conversions.
    return tuple(int(data[field]) for field in range(4))
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(images, grid_targets)`` batches.

    Each indexed annotation entry is one sample: an image resized to
    ``cfg.NN.INPUT_SIZE`` and a ``(GRID_SIZE, GRID_SIZE, 5)`` target. The cell
    containing the box's top-left corner holds
    ``[h / img_h, w / img_w, y_offset, x_offset, 1]``; all other cells are 0.
    An image with several boxes yields several samples, one box each.
    """

    def __init__(self, file_path, config_path, debug=False):
        """Index every (image, box) pair listed in the annotation file.

        Args:
            file_path: Directory containing the images; image paths in the
                annotation file are resolved relative to it.
            config_path: Annotation file. Per image it holds: an image path
                line, an object-count line, then one line per object whose
                first four fields are ``x0 y0 w h``. A count of 0 is followed
                by a single placeholder line.
            debug: Stored on the instance; not otherwise used by this class.

        Raises:
            FileNotFoundError: If ``config_path`` is not a file or
                ``file_path`` is not a directory.
        """
        self.debug = debug
        self.data_path = file_path
        # Raise instead of sys.exit(): exiting with status 0 hid the failure
        # from callers and shell scripts.
        if not os.path.isfile(config_path):
            raise FileNotFoundError("File path {} does not exist.".format(config_path))
        if not os.path.isdir(file_path):
            raise FileNotFoundError("Images folder path {} does not exist.".format(file_path))
        self.boxes = self._read_annotations(config_path)

    @staticmethod
    def _read_annotations(config_path):
        """Parse the annotation file into a list of ``(image_path, x0, y0, w, h)``."""
        boxes = []
        with open(config_path) as fp:
            for raw_name in iter(fp.readline, ""):
                image_path = raw_name.strip()
                if not image_path:
                    continue  # tolerate blank lines, e.g. a trailing newline
                num_of_obj = int(fp.readline())
                for _ in range(num_of_obj):
                    x0, y0, w, h = get_box(fp.readline().split())
                    if w == 0 or h == 0:
                        # Drop degenerate boxes with no width or no height.
                        continue
                    boxes.append((image_path, x0, y0, w, h))
                if num_of_obj == 0:
                    # Consume the placeholder line to stay aligned, and keep
                    # the image as a sample with an explicitly empty box so
                    # __getitem__ emits an all-zero target for it.
                    fp.readline()
                    boxes.append((image_path, 0, 0, 0, 0))
        return boxes

    def __len__(self):
        """Return the number of batches per epoch (the last one may be partial)."""
        return math.ceil(len(self.boxes) / cfg.TRAIN.BATCH_SIZE)

    def __getitem__(self, idx):
        """Build batch ``idx``.

        Returns:
            ``(batch_images, batch_boxes)``: float32 arrays of shape
            ``(n, INPUT_SIZE, INPUT_SIZE, 3)`` (after MobileNetV2
            ``preprocess_input``) and ``(n, GRID_SIZE, GRID_SIZE, 5)``. ``n`` is
            ``cfg.TRAIN.BATCH_SIZE``, or fewer for the final batch.
        """
        batch_size = cfg.TRAIN.BATCH_SIZE
        input_size = cfg.NN.INPUT_SIZE
        grid_size = cfg.NN.GRID_SIZE
        rows = self.boxes[idx * batch_size:(idx + 1) * batch_size]
        batch_images = np.zeros((len(rows), input_size, input_size, 3), dtype=np.float32)
        batch_boxes = np.zeros((len(rows), grid_size, grid_size, 5), dtype=np.float32)
        for i, (path, x0, y0, w, h) in enumerate(rows):
            image_path = os.path.join(self.data_path, path)
            # Box coordinates are in original-image pixels, so read the
            # original size before loading the resized copy fed to the network.
            original = tf.keras.preprocessing.image.load_img(image_path)
            image_width, image_height = original.width, original.height
            resized = tf.keras.preprocessing.image.load_img(
                image_path, target_size=(input_size, input_size))
            pixels = tf.keras.preprocessing.image.img_to_array(resized)
            # The original wrote `-` instead of `=` here, silently discarding
            # the preprocessing result.
            batch_images[i] = tf.keras.applications.mobilenet_v2.preprocess_input(pixels)
            if w <= 0 or h <= 0:
                # No usable box (e.g. an image annotated with zero objects):
                # leave this sample's target all zeros.
                continue
            self._encode_box(batch_boxes[i], x0, y0, w, h, image_width, image_height)
        return batch_images, batch_boxes

    @staticmethod
    def _encode_box(target, x0, y0, w, h, image_width, image_height):
        """Write one box into a ``(grid, grid, 5)`` target array in place."""
        grid_size = target.shape[0]
        # Keep the anchor point inside the image.
        x0 = min(max(x0, 0), image_width)
        y0 = min(max(y0, 0), image_height)
        # NOTE(review): cells are assigned by the box's top-left corner, not its
        # center; confirm this matches how the model's output is decoded.
        x_c = (grid_size / image_width) * x0
        y_c = (grid_size / image_height) * y0
        # A point exactly on the right/bottom border maps to cell `grid_size`,
        # which is out of range; place it in the last cell instead.
        floor_x = min(math.floor(x_c), grid_size - 1)
        floor_y = min(math.floor(y_c), grid_size - 1)
        target[floor_y, floor_x] = (
            h / image_height,
            w / image_width,
            y_c - floor_y,
            x_c - floor_x,
            1,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.