"""
clipbooru.py
Goes through a list of (md5).(ext) filenames, downloads the images from
Danbooru, and computes CLIP embeddings. Images are downloaded in parallel
while inference is running, so this should run as fast as your
hardware/network can handle.
"""
# requirements.txt
"""
--extra-index-url https://download.pytorch.org/whl/cu113
torch==1.12.1+cu113
torchvision==0.13.1+cu113
numpy==1.23.4
pandas==1.5.1
Pillow==9.2.0
tqdm==4.64.1
transformers==4.23.1
requests==2.28.1
"""
# 64 fits in my 4GB 3050Ti
batch_size = 64
parallel_downloads = 16
out_dir = "./features"
img_list_file = "C:\\Users\\hizki\\workspace\\dandan\\files.txt"
# Allow Pillow to load truncated images instead of raising an error
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from PIL import Image
from tqdm import tqdm
import requests
import numpy as np
import torch
import os
import threading
import queue
from transformers import CLIPFeatureExtractor, CLIPModel
from typing import Tuple
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
print("Loading CLIP")
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
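# The feature extractor handles resizing and normalization for ViT-L/14;
# model.get_image_features() below returns one 768-dimensional vector per image.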
if not os.path.exists(out_dir):
    os.mkdir(out_dir)
imgq = queue.Queue[Tuple[str, Image.Image]](2 * batch_size)
dlq = queue.Queue[Tuple[str, str]](2 * batch_size)
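# Bounded queues (2 * batch_size entries) provide backpressure: downloads pause
# when inference falls behind, keeping memory use roughly constant. Note that
# subscripting queue.Queue[...] like this requires Python 3.9+.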
def dlthread():
    """
    Download each image queued in dlq and put it into imgq
    """
    while True:
        try:
            ident, url = dlq.get()
            img = Image.open(requests.get(url, stream=True).raw)
            imgq.put((ident, img))
        except KeyboardInterrupt:
            raise
        except Exception as ex:
            print("Error in dlthread:", ex)
def queuethread():
    """
    Read each line of the image list file and queue valid images into dlq
    """
    with open(img_list_file) as f:
        for line in tqdm(f, desc="Download"):
            try:
                file = line.strip()
                url = f"https://cdn.donmai.us/original/{file[0:2]}/{file[2:4]}/{file}"
                # Filter to only images
                if not file.endswith(".jpg") and not file.endswith(".png"):
                    continue
                # Skip existing files
                if os.path.exists(os.path.join(out_dir, f"{file}.npy")):
                    continue
                # Queue the download
                dlq.put((file, url))
            except KeyboardInterrupt:
                return
            except Exception as e:
                print("Error in queuethread:", e)
                continue
# Start the queue thread
t_queue = threading.Thread(target=queuethread, daemon=True)
t_queue.start()
# Start the download threads
for _ in range(parallel_downloads):
    t = threading.Thread(target=dlthread, daemon=True)
    t.start()
# Process the queue
batch = []
pbar = tqdm(desc="Process")
with torch.no_grad():
    while True:
        try:
            ident, img = imgq.get()
            batch.append((ident, img))
            # NOTE: a final partial batch (< batch_size) is never flushed,
            # since imgq.get() blocks forever once the list is exhausted.
            if len(batch) < batch_size:
                continue
        except KeyboardInterrupt:
            break
        except Exception as e:
            print("Failed to get image from queue:", e)
            continue
        # Process the batch
        try:
            extracted = extractor(
                images=[image for (_ident, image) in batch], return_tensors="pt"
            ).to(device)
        except KeyboardInterrupt:
            break
        except Exception as e:
            print("Failed to extract batch:", e)
            batch = []
            continue
        try:
            features = model.get_image_features(**extracted).cpu().numpy()
            for n, result in enumerate(batch):
                feature = features[n]
                np.save(os.path.join(out_dir, f"{result[0]}.npy"), feature)
                pbar.update(1)
        except KeyboardInterrupt:
            break
        except Exception as e:
            print("Failed to process batch:", e)
            batch = []
            continue
        # Clear the batch
        batch = []
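# A minimal, hypothetical sketch of how the saved .npy features might be used
# afterwards (as a separate run; the loop above never exits on its own). The
# tag query and file name below are made-up examples, and the same checkpoint
# is assumed so text and image embeddings share the same 768-dim space:
#
#   from transformers import CLIPTokenizer
#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   with torch.no_grad():
#       tokens = tokenizer(["1girl solo"], padding=True, return_tensors="pt").to(device)
#       text_feat = model.get_text_features(**tokens)[0].cpu().numpy()
#   img_feat = np.load("./features/0a1b2c3d4e5f60718293a4b5c6d7e8f9.jpg.npy")
#   similarity = np.dot(img_feat, text_feat) / (
#       np.linalg.norm(img_feat) * np.linalg.norm(text_feat)
#   )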