Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@metal3d
Created June 28, 2021 12:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save metal3d/62368f7358af42420d0325b0a58c2a4c to your computer and use it in GitHub Desktop.
Save metal3d/62368f7358af42420d0325b0a58c2a4c to your computer and use it in GitHub Desktop.
Script to train a binary classifier that helps with building a dataset
""" Train model and find bad image from a directory """
import argparse
import os
import shutil
import time
from glob import glob
from typing import Any

import numpy as np

# Must be set before TensorFlow is imported for the first time, otherwise
# the C++ log level has already been read and this assignment is ignored.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

from tensorflow import keras  # noqa: E402
from tensorflow.python.keras.layers.pooling import MaxPool2D  # noqa: E402,F401
def gen_model(input_dim=64) -> keras.Model:
    """Build the (uncompiled) binary image classifier.

    The network takes square RGB images of side ``input_dim`` and ends in a
    single sigmoid unit, so its output is one score per image.

    Args:
        input_dim: Width and height, in pixels, of the input images.

    Returns:
        A ``keras.Sequential`` model; the caller is expected to compile it.
    """
    image_shape = (input_dim, input_dim, 3)

    # First convolution keeps the spatial size ("same" padding, stride 1).
    stem = keras.layers.Conv2D(
        filters=64,
        kernel_size=3,
        strides=1,
        activation="relu",
        input_shape=image_shape,
        padding="same",
        name="input_image",
    )
    pooling = keras.layers.MaxPool2D()

    # Two strided convolutions shrink the feature maps before flattening.
    reducers = [
        keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")
        for _ in range(2)
    ]

    head = keras.layers.Dense(
        units=1,
        activation="sigmoid",
        name="output_classification",
    )

    stack = [stem, pooling, *reducers, keras.layers.Flatten(), head]
    return keras.Sequential(stack)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by ``train``."""
    parser = argparse.ArgumentParser()
    # --class-dir and --fix-dir carry defaults, so they must not be flagged
    # as required (argparse would otherwise never use the default).
    parser.add_argument(
        "--class-dir",
        help="Where you've created bad and good directories",
        default="./data",
    )
    parser.add_argument(
        "--base-dir",
        help="Where the classifier will check your images",
        required=True,
    )
    parser.add_argument(
        "--fix-dir",
        help="Where bad images are copied (warning, this directory will be pruned)",
        default="./fixed",
    )
    parser.add_argument(
        "--dim",
        help="Model input dimension",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--epochs",
        help="Number of epoch to train",
        default=20,
        type=int,
    )
    parser.add_argument(
        "--trigger",
        help="Minimum score for the target label",
        default=0.5,
        type=float,
    )
    parser.add_argument(
        "--continue",
        # "continue" is a Python keyword, so store it under a usable name.
        dest="resume",
        help="Continue to train the same model",
        action="store_true",
    )
    parser.add_argument(
        "--no-train",
        help='Do not train the model (has got sense if "--continue" is set)',
        action="store_true",
    )
    parser.add_argument(
        "--tensorboard",
        help="Active tensorboard logs",
        action="store_true",
    )
    return parser


def _remove_file(path: str) -> None:
    """Delete the file at *path*, ignoring a missing file like ``rm -f``."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as err:
        # Best effort: one undeletable entry must not abort the whole run.
        print(f"==> could not remove {path}: {err}")


def _prune_directory(directory: str) -> None:
    """Delete every non-hidden entry found directly inside *directory*."""
    for entry in glob(os.path.join(directory, "*")):
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry, ignore_errors=True)
        else:
            _remove_file(entry)


def _copy_into(src: str, directory: str) -> None:
    """Copy the file *src* into *directory*, keeping its base name."""
    shutil.copy(src, os.path.join(directory, os.path.basename(src)))


def _load_or_create_model(model_name: str, dim: int, resume: bool) -> Any:
    """Return the previously saved model when resuming, else a fresh one."""
    if resume:
        try:
            print("==> Trying to open previous model")
            return keras.models.load_model(model_name)
        # Deliberately broad: any loading failure falls back to a new model.
        except Exception:  # pylint: disable=broad-except
            print("==> No previous model found, creating one...")
    else:
        print("==> Creating new model")
    return gen_model(dim)


def train():
    """Train the good/bad classifier and sort the images of a directory.

    Steps, driven by command-line arguments:

    1. Remove from ``--base-dir`` every file whose name also exists in
       ``<class-dir>/bad``.
    2. Create the model (or reload it with ``--continue``), compile it and,
       unless ``--no-train`` is given, fit it on the ``--class-dir`` images
       and save it under ``<base-dir name>_model``.
    3. Empty ``--fix-dir``, then move into it every image of ``--base-dir``
       whose "bad" score is above ``--trigger``; images already present in
       ``<class-dir>/good`` are skipped.
    4. Copy the files of ``<class-dir>/good`` back into ``--base-dir``.

    File handling uses ``os``/``shutil`` rather than shell commands so that
    paths containing spaces or shell metacharacters are handled correctly.
    """
    args = _build_arg_parser().parse_args()
    dim = args.dim

    # get the data name
    model_name = os.path.dirname(args.base_dir + os.path.sep).split(os.path.sep)[-1]
    model_name += "_model"

    # remove bad images from the data directory
    print(f"==> removing bad images in {args.base_dir}")
    for bad_path in glob(os.path.join(args.class_dir, "bad", "*")):
        _remove_file(os.path.join(args.base_dir, os.path.basename(bad_path)))

    # build/open the model
    model: Any = _load_or_create_model(model_name, dim, args.resume)
    model.compile(
        keras.optimizers.Adam(learning_rate=0.001),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=[keras.metrics.BinaryAccuracy()],
    )
    model.summary()

    # Pixel scaling shared by training and prediction: the model must see
    # the same value range in both phases.
    rescale = 1.0 / 255

    # data augmentation
    gen = keras.preprocessing.image.ImageDataGenerator(
        rescale=rescale,
        zoom_range=0.2,
        horizontal_flip=True,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rotation_range=15,
        fill_mode="nearest",
    )
    trainer = gen.flow_from_directory(
        args.class_dir,
        target_size=(dim, dim),
        batch_size=4,
        shuffle=True,
        class_mode="binary",
    )

    callbacks = []
    if args.tensorboard:
        callbacks.append(keras.callbacks.TensorBoard())

    stopped_by_user = False
    if not args.no_train:
        try:
            # train !
            model.fit(trainer, epochs=args.epochs, callbacks=callbacks)
        except KeyboardInterrupt:
            # CTRL+C only interrupts training; the model is still saved below
            stopped_by_user = True
        finally:
            print()
            print("===> Saving the model")
            model.save(model_name)

    if stopped_by_user:
        print(
            "====== Stopped by user, press CTRL+C again to completly stop the script ======="
        )
        # Countdown giving the user a chance to abort before files are moved.
        for remaining in range(10, -1, -1):
            print("\r" * 5, end="")
            # flush: there is no newline, so the text would otherwise stay
            # in the output buffer and the countdown would not be visible
            print(f"{remaining:02d}/10", end="", flush=True)
            time.sleep(1)
        print()

    # for information
    print(trainer.class_indices)

    # now, get bad images from the given directory
    os.makedirs(args.fix_dir, exist_ok=True)
    # delete the "fixes" content
    _prune_directory(args.fix_dir)

    # get the "good" label index
    limit = int(trainer.class_indices["good"])

    for fname in glob(os.path.join(args.base_dir, "*")):
        if not os.path.isfile(fname):
            # sub-directories cannot be loaded as images
            continue
        good_image = os.path.join(args.class_dir, "good", os.path.basename(fname))
        if os.path.exists(good_image):
            # image is already classified as good
            continue

        image = keras.preprocessing.image.load_img(
            fname,
            target_size=(dim, dim),
        )
        # apply the same scaling as the training generator
        image = keras.preprocessing.image.img_to_array(image) * rescale
        image = np.expand_dims(image, axis=0)

        score = model.predict(image)
        # distance between the sigmoid output and the "good" label index:
        # the further from "good", the more likely the image is bad
        res = np.abs(limit - score[0][0])
        if res > args.trigger:
            print(fname, "is bad", res, score)
            _copy_into(fname, args.fix_dir)  # copy to fixes
            _remove_file(fname)  # remove it from base directory

    print(f"==> copying back good images in {args.base_dir}")
    for good_path in glob(os.path.join(args.class_dir, "good", "*")):
        if os.path.isfile(good_path):
            _copy_into(good_path, args.base_dir)
# Run the training / sorting pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment