Last active
November 2, 2022 07:17
-
-
Save hizkifw/eb1f401a4181dccb09602a3edb15343f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
clipbooru.py
Go through a list of (md5).(ext) file names, download each image from
Danbooru, and compute its CLIP embedding. Images are downloaded in parallel
while inference is running, so this should run as fast as your
hardware/network can handle.
""" | |
# requirements.txt
"""
--extra-index-url https://download.pytorch.org/whl/cu113
torch==1.12.1+cu113
torchvision==0.13.1+cu113
numpy==1.23.4
pandas==1.5.1
Pillow==9.2.0
tqdm==4.64.1
transformers==4.23.1
requests==2.28.1
"""

# Number of images embedded per CLIP forward pass.
# 64 fits in my 4GB 3050Ti.
batch_size = 64
# Number of concurrent download worker threads.
parallel_downloads = 16
# Directory that receives one "<file>.npy" feature file per image.
out_dir = "./features"
# Text file with one "<md5>.<ext>" file name per line.
# Raw string: the previous literal contained the invalid escape "\w".
img_list_file = r"C:\Users\hizki\workspace\dandan\files.txt"
import os
import queue
import threading
from multiprocessing.pool import ThreadPool  # NOTE(review): appears unused in this script
from typing import Tuple

import numpy as np
import requests
import torch
from PIL import Image, ImageFile
from tqdm import tqdm
from transformers import CLIPFeatureExtractor, CLIPModel

# Let PIL decode images whose data is cut short (e.g. partially written
# files) instead of raising an error when the pixel data is loaded.
ImageFile.LOAD_TRUNCATED_IMAGES = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
print("Loading CLIP")
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")

# exist_ok avoids the exists()/mkdir() race and also creates missing parents.
os.makedirs(out_dir, exist_ok=True)

# Bounded queues give back-pressure: once 2 * batch_size items are waiting,
# put() blocks, so downloads cannot run arbitrarily far ahead of inference.
# Work items are (ident, image) pairs on imgq and (ident, url) pairs on dlq.
imgq = queue.Queue[Tuple[str, Image.Image]](2 * batch_size)
dlq = queue.Queue[Tuple[str, str]](2 * batch_size)
def dlthread(timeout: float = 60.0) -> None:
    """
    Download worker: turn (ident, url) items from dlq into (ident, image)
    items on imgq.

    The whole response body is read and decoded here, in the worker thread.
    PIL's Image.open is lazy. Handing the main thread an image backed by a
    live HTTP stream would push the network read and decode into CLIP
    preprocessing, which serialises downloads with inference. It would also
    let one broken image fail a whole batch.

    A None item on dlq means there is no more work. The worker forwards it to
    imgq, so the main loop can count finished workers, and then exits.
    If a single image fails, the error is printed and that image is skipped.

    timeout: passed to requests.get (connect/read timeout, in seconds), so a
    stalled connection cannot block a worker forever.
    """
    from io import BytesIO

    while True:
        item = dlq.get()
        if item is None:
            imgq.put(None)
            return
        ident, url = item
        try:
            resp = requests.get(url, timeout=timeout)
            resp.raise_for_status()
            # convert() forces the full decode to happen here rather than
            # lazily in the main thread, and yields the 3-channel input CLIP
            # preprocessing works on.
            img = Image.open(BytesIO(resp.content)).convert("RGB")
        except Exception as ex:
            print("Error in dlthread:", ident, ex)
            continue
        imgq.put((ident, img))


def queuethread() -> None:
    """
    Producer: read img_list_file line by line and queue downloads on dlq.

    Only .jpg/.png entries are queued. Entries whose "<file>.npy" already
    exists in out_dir are skipped, so an interrupted run can be resumed.
    When the list is exhausted, or reading it fails, the function queues one
    None sentinel per download worker so every dlthread eventually stops.
    """
    try:
        with open(img_list_file) as f:
            for line in tqdm(f, desc="Download"):
                try:
                    file = line.strip()
                    # Filter to only images
                    if not file.endswith((".jpg", ".png")):
                        continue
                    # Skip files whose features were already computed
                    if os.path.exists(os.path.join(out_dir, f"{file}.npy")):
                        continue
                    # The CDN path is built from the first two and the next
                    # two characters of the md5-based file name.
                    url = f"https://cdn.donmai.us/original/{file[0:2]}/{file[2:4]}/{file}"
                    dlq.put((file, url))
                except Exception as e:
                    print("Error in queuethread:", e)
    finally:
        # Always release the workers, even if reading the list failed.
        for _ in range(parallel_downloads):
            dlq.put(None)


def _process_batch(batch) -> None:
    """
    Embed a batch of (ident, image) pairs with CLIP and save the features.

    Writes "<ident>.npy" into out_dir for each image and advances pbar once
    per saved file. On error the message is printed and the rest of the
    batch is skipped (best effort, like the download side).
    """
    try:
        extracted = extractor(
            images=[image for (_ident, image) in batch], return_tensors="pt"
        ).to(device)
    except Exception as e:
        print("Failed to extract batch:", e)
        return
    try:
        with torch.no_grad():
            features = model.get_image_features(**extracted).cpu().numpy()
        # Row n of features is the embedding of batch[n].
        for (ident, _image), feature in zip(batch, features):
            np.save(os.path.join(out_dir, f"{ident}.npy"), feature)
            pbar.update(1)
    except Exception as e:
        print("Failed to process batch:", e)


# Start the queue thread
t_queue = threading.Thread(target=queuethread, daemon=True)
t_queue.start()
# Start the download threads
for _ in range(parallel_downloads):
    t = threading.Thread(target=dlthread, daemon=True)
    t.start()

# Process images until every download worker has reported completion (one
# None sentinel each), then flush the final, possibly partial, batch.
batch = []
pbar = tqdm(desc="Process")
workers_left = parallel_downloads
try:
    while workers_left > 0:
        try:
            # Poll with a timeout so Ctrl+C is noticed promptly even where a
            # blocking Queue.get() is not interruptible (notably on Windows).
            item = imgq.get(timeout=1.0)
        except queue.Empty:
            continue
        if item is None:
            workers_left -= 1
            continue
        batch.append(item)
        if len(batch) >= batch_size:
            _process_batch(batch)
            # Clear the batch
            batch = []
    if batch:
        _process_batch(batch)
        batch = []
except KeyboardInterrupt:
    pass
finally:
    pbar.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment