@csiki
Last active March 2, 2023 02:00
Rewrite of rom1504/clip-retrieval/clip_retrieval/clip_filter.py to handle webdataset filtering locally or on s3
"""clip filter is a tool to use a knn index and a image/text collection to extract interesting subsets"""
# updated it for handling webdataset tar files
import fire
import boto3
import os, sys
from botocore.exceptions import ClientError
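
# Example invocation through fire (the script name, bucket, and paths below are hypothetical):
#   python clip_filter_wds.py \
#       --query "a photo of a dog" \
#       --input_folder s3://my-bucket/shards \
#       --output_folder s3://my-bucket/filtered-shards \
#       --indice_folder /data/knn_index \
#       --num_results 10000
# indice_folder is expected to contain a metadata/ directory of parquet files
# and an image.index faiss file (see clip_filter_webdataset below).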


def break_up_s3_path(path):
    s3_path = path[len('s3://'):]
    bucket_name = s3_path[:s3_path.find('/')]
    file_key = s3_path[s3_path.find('/') + 1:]
    file_key = file_key[:-1] if file_key.endswith('/') else file_key
    return bucket_name, file_key
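
# Illustrative example of the parsing above (bucket and key are hypothetical):
#   break_up_s3_path('s3://my-bucket/shards/00042.tar') -> ('my-bucket', 'shards/00042.tar')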


def download_file_from_s3(s3client, bucket: str, file_name: str, to_file: str, mode='wb'):
    chunk_size = 1024 ** 3
    in_file = s3client.get_object(Bucket=bucket, Key=file_name)['Body']
    with open(to_file, mode) as out_file:
        while True:
            chunk = in_file.read(chunk_size)
            if chunk == b'':
                break
            out_file.write(chunk)
    return to_file


def upload_file_to_s3(s3client, file_name, bucket, object_name=None):
    # If S3 object_name was not specified, use file_name
    if object_name is None:
        object_name = os.path.basename(file_name)
    # Upload the file
    try:
        response = s3client.upload_file(file_name, bucket, object_name)
    except ClientError as e:
        print(f'could not upload file {file_name} to: s3://{bucket}/{object_name}\n{e}', file=sys.stderr)
        return False
    return True


def clip_filter_webdataset(query, input_folder, output_folder, indice_folder, num_results=1000, threshold=None,
                           num_filt_imgs_per_tar=1000):
    """Entry point of clip filter"""
    import faiss  # pylint: disable=import-outside-toplevel
    import torch  # pylint: disable=import-outside-toplevel
    import os  # pylint: disable=import-outside-toplevel
    import shutil  # pylint: disable=import-outside-toplevel
    from pathlib import Path  # pylint: disable=import-outside-toplevel
    import pandas as pd  # pylint: disable=import-outside-toplevel
    import clip  # pylint: disable=import-outside-toplevel
    from PIL import Image  # pylint: disable=import-outside-toplevel
    import numpy as np
    import webdataset as wds
    import tarfile
    import sys
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
data_dir = Path(indice_folder + "/metadata")
df = pd.concat(pd.read_parquet(parquet_file) for parquet_file in sorted(data_dir.glob("*.parquet")))
url_list = None
if "url" in df:
url_list = df["url"].tolist()
image_list = df["image_path"].tolist()
image_index = faiss.read_index(indice_folder + "/image.index")
indices_loaded = {
"image_list": image_list,
"image_index": image_index,
}
image_index = indices_loaded["image_index"]
image_list = indices_loaded["image_list"]

    if query.endswith((".png", ".jpg", ".jpeg", ".bmp")) and os.path.isfile(query):
        im = Image.open(query)
        query_features = model.encode_image(preprocess(im).unsqueeze(0).to(device))
    else:
        text = clip.tokenize([query]).to(device)
        query_features = model.encode_text(text)
    query_features /= query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    index = image_index
    if threshold is not None:
        _, d, i = index.range_search(query_features, threshold)
        print(f"Found {i.shape} items with query '{query}' and threshold {threshold}")
    else:
        d, i = index.search(query_features, num_results)
        print(f"Found {num_results} items with query '{query}'")
        i = i[0]
        d = d[0]

    min_d = min(d)
    max_d = max(d)
    print(f"The minimum distance is {min_d:.2f} and the maximum is {max_d:.2f}")
    print("You may want to use these numbers to increase your --num_results parameter. "
          "Or use the --threshold parameter.")

    img_ids = [image_list[ei] for _, ei in zip(d, i)]

    # order paths, group them by corresponding tar files,
    # assuming that each tar has a 5-digit identifier
    img_ids = sorted(img_ids)
    tar_ids = np.unique([iid[:5] for iid in img_ids])
    img_by_tar = {tid: set() for tid in tar_ids}
    for iid in img_ids:
        img_by_tar[iid[:5]].add(iid)
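    # e.g. a key like '000420031' would be grouped under tar id '00042' (illustrative;
    # the actual key format depends on how the source webdataset was written)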

    # map old tar files to new wds tar files
    filt_tid = 0
    img_count = 0
    filt_to_src_tar = {}
    for tid, iids in img_by_tar.items():
        ftid = f'{filt_tid:05d}'
        if ftid not in filt_to_src_tar:
            filt_to_src_tar[ftid] = []
        filt_to_src_tar[ftid].append(tid)
        img_count += len(iids)
        if img_count > num_filt_imgs_per_tar:
            filt_tid += 1
            img_count = 0

    # iterate over tars to load, copy them from s3 to local if needed, take stuff out of them,
    # write stuff to local tar file, and upload to s3 if needed
    s3_input = input_folder.startswith('s3://')
    s3_output = output_folder.startswith('s3://')
    s3client = boto3.client('s3') if s3_input or s3_output else None
    total_items_written = 0

    if not s3_output:
        os.makedirs(output_folder, exist_ok=True)

    filtered_tars = []
    for ftid, stids in filt_to_src_tar.items():
        local_dst = f'_tmp_{ftid}.tar' if s3_output else f'{output_folder}/{ftid}.tar'
        dst_items_written = 0
        try:
            with wds.TarWriter(local_dst) as dst:
                for src_tid in stids:
                    src_items_written = 0
                    src_path = f'{input_folder}/{src_tid}.tar'

                    # if input_folder is on s3, first copy file to local
                    if s3_input:
                        bucket_name, file_key = break_up_s3_path(src_path)
                        local_src = download_file_from_s3(s3client, bucket_name, file_key, f'_tmp_{src_tid}.tar')
                    else:  # local
                        local_src = src_path

                    # take items from src tar
                    src = wds.WebDataset(local_src, cache_size=10 ** 10)
                    for item in src:  # key order is not guaranteed
                        if item['__key__'] in img_by_tar[src_tid]:
                            dst.write(item)
                            src_items_written += 1
                    print(f'|-- items written from {src_path}: {src_items_written}/{len(img_by_tar[src_tid])}')
                    dst_items_written += src_items_written

                    # delete local src
                    if s3_input:
                        os.remove(local_src)
        except tarfile.ReadError:
            print('Unable to read tar file at:', local_src, file=sys.stderr)

        print(f'--> items written to {local_dst}: {dst_items_written}')
        total_items_written += dst_items_written

        # upload filtered tar file to s3
        dst_path = f'{output_folder}/{ftid}.tar'
        if s3_output and dst_items_written > 0:
            print('|----- uploading to:', os.path.basename(dst_path), flush=True)
            bucket_name, file_key = break_up_s3_path(dst_path)
            upload_success = upload_file_to_s3(s3client, local_dst, bucket_name, file_key)
            if not upload_success:
                raise RuntimeError(f'Could not upload: {local_dst} -> {dst_path}')
            print('---->| uploading done to:', os.path.basename(dst_path), flush=True)

            # delete local filtered tar file
            try:
                os.remove(local_dst)
            except OSError:
                pass
        else:  # local, not s3, output
            if dst_items_written == 0:
                os.remove(local_dst)
            filtered_tars.append(dst_path)

    print('total items written:', total_items_written)
    print('filtered tar files:', filtered_tars)
    return filtered_tars
if __name__ == "__main__":
fire.Fire(clip_filter_webdataset)