Created
June 28, 2021 12:16
-
-
Save metal3d/62368f7358af42420d0325b0a58c2a4c to your computer and use it in GitHub Desktop.
Script to train a binary (good/bad) image classifier that helps with building an image dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Train model and find bad image from a directory """ | |
import argparse
import os
import time
from glob import glob
from typing import Any

import numpy as np

# Hide TensorFlow's C++ INFO and WARNING log lines ("2" keeps only errors).
# This must be set BEFORE tensorflow is imported for the first time:
# importing any tensorflow submodule (including the one just below) loads
# the native runtime, so setting the variable afterwards is too late.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# NOTE(review): MaxPool2D is not referenced in the rest of this file (the
# model uses keras.layers.MaxPool2D) and it comes from TensorFlow's private
# `tensorflow.python` package, which newer releases no longer ship.
# It is kept here only to avoid dropping a module-level name; consider removing it.
from tensorflow.python.keras.layers.pooling import MaxPool2D  # noqa: E402,F401  pylint: disable=unused-import,wrong-import-position
from tensorflow import keras  # noqa: E402  pylint: disable=wrong-import-position
def gen_model(input_dim=64) -> keras.Model:
    """Build the (not yet compiled) CNN used as a good/bad image classifier.

    Args:
        input_dim: side length in pixels of the square RGB images the model
            accepts; each sample is shaped (input_dim, input_dim, 3).

    Returns:
        A ``keras.Sequential`` model whose last layer is a single sigmoid
        unit, i.e. one score in [0, 1] per image.
    """
    layers = keras.layers
    # Full-resolution feature extractor; "same" padding keeps the spatial size.
    stem = layers.Conv2D(
        64,
        3,
        1,
        activation="relu",
        input_shape=(input_dim, input_dim, 3),
        padding="same",
        name="input_image",
    )
    pool = layers.MaxPool2D()
    # Two stride-2 convolutions shrink the feature map further.
    downsample = [layers.Conv2D(64, 3, 2, activation="relu") for _ in range(2)]
    head = [
        layers.Flatten(),
        layers.Dense(1, activation="sigmoid", name="output_classification"),
    ]
    return keras.Sequential([stem, pool, *downsample, *head])
def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments."""
    parser = argparse.ArgumentParser()
    # --class-dir and --fix-dir previously combined `default=` with
    # `required=True`, which made their defaults unreachable. They are now
    # optional and fall back to those defaults.
    parser.add_argument(
        "--class-dir",
        help="Where you've created bad and good directories",
        default="./data",
    )
    parser.add_argument(
        "--base-dir",
        help="Where the classifier will check your images",
        required=True,
    )
    parser.add_argument(
        "--fix-dir",
        help="Where bad images are copied (warning, this directory will be pruned)",
        default="./fixed",
    )
    parser.add_argument(
        "--dim",
        help="Model input dimension",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--epochs",
        help="Number of epoch to train",
        default=20,
        type=int,
    )
    parser.add_argument(
        "--trigger",
        help="Minimum score for the target label",
        default=0.5,
        type=float,
    )
    parser.add_argument(
        "--continue",
        help="Continue to train the same model",
        action="store_true",
    )
    parser.add_argument(
        "--no-train",
        help='Do not train the model (has got sense if "--continue" is set)',
        action="store_true",
    )
    parser.add_argument(
        "--tensorboard",
        help="Active tensorboard logs",
        action="store_true",
    )
    return parser.parse_args()


def _remove_quietly(path: str) -> None:
    """Delete the file ``path``; a missing file is not an error (like ``rm -f``)."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as err:
        # e.g. path is a directory or not writable: report it and keep going,
        # as the previous `rm -f` shell call did.
        print(f"==> could not remove {path}: {err}")


def _load_or_create_model(model_name: str, dim: int, resume: bool):
    """Return the model saved at ``model_name`` when resuming, else a fresh one.

    Loading is best-effort: any failure falls back to ``gen_model(dim)``.
    """
    if not resume:
        print("==> Creating new model")
        return gen_model(dim)
    try:
        print("==> Trying to open previous model")
        return keras.models.load_model(model_name)
    except Exception as err:  # pylint: disable=broad-except
        print("==> No previous model found, creating one...")
        print(f"    ({err})")
        return gen_model(dim)


def _countdown(seconds: int = 10) -> None:
    """Show a ``NN/seconds`` countdown on one line, ticking once per second."""
    for remaining in range(seconds, -1, -1):
        # flush=True: without a newline, line-buffered stdout would hold the
        # text back and the countdown would never be visible.
        print("\r" * 5, end="")
        print(f"{remaining:02d}/{seconds}", end="", flush=True)
        time.sleep(1)
    print()


def _prune_directory(path: str) -> None:
    """Create ``path`` if needed, then delete its non-hidden entries.

    This matches the previous ``rm -rf path/*``: names starting with a dot are
    not matched by the ``*`` glob and are left in place.
    """
    import shutil

    os.makedirs(path, exist_ok=True)
    for entry in glob(os.path.join(path, "*")):
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry)
        else:
            os.remove(entry)


def _bad_score(model, fname: str, dim: int, good_index: int):
    """Score one image file with ``model``.

    Returns:
        ``(score, raw)``. ``raw`` is the array returned by ``model.predict``.
        ``score`` is ``|good_index - sigmoid output|``, which is the
        probability of the non-"good" class when the classes are exactly
        "bad" and "good" (assumed from the directory layout — verify).
    """
    image = keras.preprocessing.image.load_img(fname, target_size=(dim, dim))
    pixels = keras.preprocessing.image.img_to_array(image)
    # Training images go through ImageDataGenerator(rescale=1.0 / 255), so
    # inference must use the same [0, 1] pixel range as training.
    batch = np.expand_dims(pixels / 255.0, axis=0)
    raw = model.predict(batch)
    return np.abs(good_index - raw[0][0]), raw


def train():
    """Train the good/bad classifier, then quarantine images it rates as bad.

    Steps:
      1. Delete from --base-dir every file whose name is in <class-dir>/bad.
      2. Create a model, or reload "<base-dir name>_model" from the current
         working directory when --continue is given.
      3. Train on <class-dir> and save the model (skipped with --no-train).
      4. Empty --fix-dir, then move every --base-dir image that is not in
         <class-dir>/good and whose score exceeds --trigger into --fix-dir.
      5. Copy the files of <class-dir>/good back into --base-dir.
    """
    import shutil

    args = _parse_args()
    dim = args.dim
    # The model is saved under the last path component of --base-dir.
    model_name = os.path.dirname(args.base_dir + os.path.sep).split(os.path.sep)[-1]
    model_name += "_model"

    # Remove the images already hand-labelled as bad from the checked directory.
    print(f"==> removing bad images in {args.base_dir}")
    for bad_path in glob(os.path.join(args.class_dir, "bad", "*")):
        _remove_quietly(os.path.join(args.base_dir, os.path.basename(bad_path)))

    # "--continue" is a Python keyword, hence getattr instead of args.continue.
    model = _load_or_create_model(model_name, dim, getattr(args, "continue"))
    model.compile(
        keras.optimizers.Adam(learning_rate=0.001),
        loss=keras.losses.BinaryCrossentropy(),
        # Keras documents `metrics` as a list (or dict) of metrics.
        metrics=[keras.metrics.BinaryAccuracy()],
    )
    model.summary()

    # Data augmentation; pixels are rescaled to [0, 1] (_bad_score mirrors it).
    gen = keras.preprocessing.image.ImageDataGenerator(
        rescale=1.0 / 255,
        zoom_range=0.2,
        horizontal_flip=True,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rotation_range=15,
        fill_mode="nearest",
    )
    trainer = gen.flow_from_directory(
        args.class_dir,
        target_size=(dim, dim),
        batch_size=4,
        shuffle=True,
        class_mode="binary",
    )

    callbacks = []
    if args.tensorboard:
        callbacks.append(keras.callbacks.TensorBoard())

    stopped_by_user = False
    if not args.no_train:
        try:
            model.fit(trainer, epochs=args.epochs, callbacks=callbacks)
        except KeyboardInterrupt:
            # CTRL+C ends training early; the model is still saved below.
            stopped_by_user = True
        finally:
            print()
            print("===> Saving the model")
            model.save(model_name)

    if stopped_by_user:
        print(
            "====== Stopped by user, press CTRL+C again to completly stop the script ======="
        )
        _countdown(10)

    # For information: the label index assigned to each class directory.
    print(trainer.class_indices)

    _prune_directory(args.fix_dir)
    good_index = int(trainer.class_indices["good"])
    good_dir = os.path.join(args.class_dir, "good")

    for fname in glob(os.path.join(args.base_dir, "*")):
        bname = os.path.basename(fname)
        if os.path.exists(os.path.join(good_dir, bname)):
            continue  # already hand-labelled as good
        if not os.path.isfile(fname):
            continue  # sub-directories etc. cannot be loaded as images
        score, raw = _bad_score(model, fname, dim, good_index)
        if score > args.trigger:
            print(fname, "is bad", score, raw)
            # Move into the fix directory (previously `cp` then `rm -f`).
            shutil.move(fname, os.path.join(args.fix_dir, bname))

    print(f"==> copying back good images in {args.base_dir}")
    for good_path in glob(os.path.join(good_dir, "*")):
        # `cp dir/* dest/` skipped sub-directories; do the same.
        if os.path.isfile(good_path):
            shutil.copy(good_path, args.base_dir)
if __name__ == "__main__":
    # Script entry point: argument parsing, training and image sorting all
    # happen inside train().
    train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment