Skip to content

Instantly share code, notes, and snippets.

@tori29umai0123
Last active June 5, 2024 01:47
Show Gist options
  • Save tori29umai0123/a0776b5c40d43b901f51aa1b17957d48 to your computer and use it in GitHub Desktop.
Save tori29umai0123/a0776b5c40d43b901f51aa1b17957d48 to your computer and use it in GitHub Desktop.
nsfw_filter_with_tagger.py
import csv
import glob
import os
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import shutil
# Side length, in pixels, of the square image that preprocess_image() resizes
# every input to.
IMAGE_SIZE = 448
def preprocess_image(image):
    """Turn an image into a square IMAGE_SIZE x IMAGE_SIZE float32 array.

    The channel axis is reversed, the image is padded with 255 (white for
    8-bit data) until it is square, and the result is resized.

    Args:
        image: Anything ``np.array`` turns into a 3-D height x width x
            channel array (``main`` passes an RGB PIL image).

    Returns:
        A float32 array of shape (IMAGE_SIZE, IMAGE_SIZE, channels).
    """
    # Reverse the channel order: an RGB input (as produced by PIL) becomes BGR.
    flipped = np.array(image)[:, :, ::-1]

    height, width = flipped.shape[0:2]
    side = max(height, width)

    # Split the missing rows/columns as evenly as possible between both sides
    # so the original content stays centered.
    extra_w = side - width
    extra_h = side - height
    left = extra_w // 2
    top = extra_h // 2
    padding = ((top, extra_h - top), (left, extra_w - left), (0, 0))
    squared = np.pad(flipped, padding, mode="constant", constant_values=255)

    # INTER_AREA when shrinking, INTER_LANCZOS4 when enlarging (or unchanged).
    if side > IMAGE_SIZE:
        interp = cv2.INTER_AREA
    else:
        interp = cv2.INTER_LANCZOS4
    resized = cv2.resize(squared, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

    return resized.astype(np.float32)
def check_if_exists(image_path, directories):
    """Report whether a file named like ``image_path`` is in any directory.

    Only the base name of ``image_path`` is compared; its parent folder is
    ignored.

    Args:
        image_path: Path whose file name is looked up.
        directories: Iterable of directory paths to search.

    Returns:
        True if at least one directory contains an entry with that file name,
        False otherwise (including when ``directories`` is empty).
    """
    wanted = os.path.basename(image_path)
    candidates = (os.path.join(folder, wanted) for folder in directories)
    return any(os.path.exists(candidate) for candidate in candidates)
def run_batch(path_imgs, input_name, ort_sess, rating_tags, general_tags, thresh, nsfw_dir, sfw_dir):
    """Tag one batch of images, write their tag files and copy them to a folder.

    For every image the rating scores decide the destination: ``nsfw_dir``
    when the larger of "questionable"/"explicit" beats "general", otherwise
    ``sfw_dir``. A ``<image stem>.txt`` file holding the comma-separated
    general tags is written there, and the image is copied next to it.

    Args:
        path_imgs: List of ``(image_path, preprocessed_array)`` pairs.
        input_name: Name of the model input the stacked batch is fed to.
        ort_sess: Object exposing ``run(output_names, feed_dict)``; the first
            returned output is read as one score row per image.
        rating_tags: Rating tag names, paired with the leading scores of a row.
        general_tags: General tag names, paired with the scores that follow
            the rating scores.
        thresh: Minimum score for a general tag to be written.
        nsfw_dir: Existing directory receiving images rated NSFW.
        sfw_dir: Existing directory receiving all other images.
    """
    # Nothing to infer on; avoid feeding an empty array to the session.
    if not path_imgs:
        return

    batch = np.array([pixels for _, pixels in path_imgs])
    probs = ort_sess.run(None, {input_name: batch})[0]  # ONNX output
    probs = probs[: len(path_imgs)]

    # NOTE(review): assumes each score row follows the selected_tags.csv row
    # order with the rating rows first and the general rows directly after
    # them, so general tag k sits at index len(rating_tags) + k. Indexing the
    # general tags from 0 would read the rating scores instead. Confirm
    # against the model's CSV.
    general_offset = len(rating_tags)

    for (image_path, _), prob in zip(path_imgs, probs):
        rating_scores = dict(zip(rating_tags, prob))
        max_nsfw_score = max(rating_scores.get("questionable", 0), rating_scores.get("explicit", 0))
        max_sfw_score = rating_scores.get("general", 0)
        destination = nsfw_dir if max_nsfw_score > max_sfw_score else sfw_dir

        filename = os.path.basename(image_path)
        tag_file_path = os.path.join(destination, os.path.splitext(filename)[0] + ".txt")

        # Save tags in a single line
        tag_list = [
            tag
            for tag, score in zip(general_tags, prob[general_offset:])
            if score >= thresh
        ]
        # Explicit encoding so the output does not depend on the OS locale.
        with open(tag_file_path, "w", encoding="utf-8") as f:
            f.write(", ".join(tag_list))

        # Copy image to the appropriate folder; a failed copy is reported and
        # the rest of the batch is still processed.
        try:
            shutil.copy(image_path, os.path.join(destination, filename))
            print(f"{image_path} copied to {destination}.")
        except OSError as e:
            print(f"Failed to copy {image_path} to {destination}. Error: {e}")
def main():
    """Tag every not-yet-sorted file in ``input_dir`` and sort it by rating.

    Downloads ``model.onnx`` and ``selected_tags.csv`` for ``MODEL_ID`` from
    Hugging Face, then preprocesses the files matched by ``input_dir/*.*`` and
    hands them to ``run_batch`` in groups of ``batch_size``. Files whose name
    already exists in ``sfw_dir`` or ``nsfw_dir`` are skipped.

    Reads the module-level names ``MODEL_ID``, ``input_dir``, ``sfw_dir``,
    ``nsfw_dir``, ``batch_size`` and ``thresh``; they must be defined before
    this function is called.

    Raises:
        ValueError: If the header of ``selected_tags.csv`` is not
            ``tag_id,name,category,count``.
    """
    print("Loading wd14 tagger from Hugging Face")
    onnx_path = hf_hub_download(MODEL_ID, "model.onnx")
    csv_path = hf_hub_download(MODEL_ID, "selected_tags.csv")

    print("Running wd14 tagger ONNX")
    print(f"Loading ONNX model: {onnx_path}")
    ort_sess = ort.InferenceSession(onnx_path)
    # Loop-invariant: resolve the model's input name once.
    input_name = ort_sess.get_inputs()[0].name

    with open(csv_path, "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        header = next(reader)  # Read header row
        rows = list(reader)
    # Raise explicitly: an assert would be stripped when running with -O.
    if header != ["tag_id", "name", "category", "count"]:
        raise ValueError(f"Unexpected CSV format: {header}")

    # Category "9" rows are rating tags, category "0" rows are general tags.
    rating_tags = [row[1] for row in rows if row[2] == "9"]
    general_tags = [row[1] for row in rows if row[2] == "0"]

    image_paths = glob.glob(os.path.join(input_dir, "*.*"))
    b_imgs = []
    for image_path in tqdm(image_paths, smoothing=0.0):
        # Already sorted in an earlier run.
        if check_if_exists(image_path, [sfw_dir, nsfw_dir]):
            continue

        # The glob also matches non-image files, so any failure to load or
        # preprocess is reported and the file is skipped.
        try:
            # Context manager closes the underlying file handle; the pixels
            # are read inside the block, before the handle is released.
            with Image.open(image_path) as opened:
                rgb = opened if opened.mode == "RGB" else opened.convert("RGB")
                pixels = preprocess_image(rgb)
        except Exception as e:
            print(f"Failed to load image: {image_path}, Error: {e}")
            continue
        b_imgs.append((image_path, pixels))

        if len(b_imgs) >= batch_size:
            run_batch(b_imgs, input_name, ort_sess, rating_tags, general_tags, thresh, nsfw_dir, sfw_dir)
            b_imgs = []

    # Flush the final, partially filled batch.
    if b_imgs:
        run_batch(b_imgs, input_name, ort_sess, rating_tags, general_tags, thresh, nsfw_dir, sfw_dir)

    print("Processing complete!")
if __name__ == "__main__":
    # Settings read by main() as module-level names.
    MODEL_ID = "SmilingWolf/wd-swinv2-tagger-v3"
    input_dir = "E:/desktop/dart"
    sfw_dir = "E:/desktop/sfw"
    nsfw_dir = "E:/desktop/nsfw"
    batch_size = 16
    thresh = 0.35

    # Create the two output folders if they are missing.
    for out_dir in (sfw_dir, nsfw_dir):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment