Last active
March 2, 2023 02:00
-
-
Save csiki/1925f6a16a7000a3e47cb320fbb90c52 to your computer and use it in GitHub Desktop.
Rewrite of rom1504/clip-retrieval/clip_retrieval/clip_filter.py to handle webdataset filtering locally or on s3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""clip filter is a tool to use a knn index and a image/text collection to extract interesting subsets""" | |
# Updated to handle webdataset tar files, stored locally or on S3.
import fire | |
import boto3 | |
import os, sys | |
from botocore.exceptions import ClientError | |
def break_up_s3_path(path):
    """Split an S3 URI into its bucket name and object key.

    Args:
        path: URI of the form ``s3://bucket/some/key``. The ``s3://`` scheme
            is stripped only when it is actually present.

    Returns:
        A ``(bucket_name, file_key)`` tuple. One trailing ``/`` is removed
        from the key. When the URI names only a bucket (``s3://bucket`` or
        ``s3://bucket/``) the key is the empty string.
    """
    scheme = 's3://'
    s3_path = path[len(scheme):] if path.startswith(scheme) else path
    # partition() copes with a missing '/', unlike find(), whose -1 result
    # used to chop the last character off the bucket name.
    bucket_name, _, file_key = s3_path.partition('/')
    if file_key.endswith('/'):
        file_key = file_key[:-1]
    return bucket_name, file_key
def download_file_from_s3(s3client, bucket: str, file_name: str, to_file: str, mode='wb', chunk_size=1024 ** 3):
    """Stream an S3 object to a local file.

    Args:
        s3client: Client object exposing ``get_object(Bucket=..., Key=...)``.
        bucket: Name of the bucket holding the object.
        file_name: Key of the object inside the bucket.
        to_file: Local path the object is written to.
        mode: Mode used to open ``to_file``.
        chunk_size: Maximum number of bytes read from the response body per
            iteration; must be positive.

    Returns:
        ``to_file``, so the call can be used inline.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # read(0) returns b'' immediately and would silently produce an empty file.
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    in_file = s3client.get_object(Bucket=bucket, Key=file_name)['Body']
    try:
        with open(to_file, mode) as out_file:
            while True:
                chunk = in_file.read(chunk_size)
                if not chunk:
                    break
                out_file.write(chunk)
    finally:
        # Always release the response stream, even if the local write fails.
        in_file.close()
    return to_file
def upload_file_to_s3(s3client, file_name, bucket, object_name=None):
    """Upload a local file to S3 and report whether it succeeded.

    Args:
        s3client: Client object exposing ``upload_file(file_name, bucket, key)``.
        file_name: Path of the local file to upload.
        bucket: Name of the destination bucket.
        object_name: Destination key; defaults to the base name of ``file_name``.

    Returns:
        ``True`` when the upload call completed, ``False`` when it failed
        with a client/upload error (the error is printed to stderr).
    """
    # If S3 object_name was not specified, use the file's base name
    if object_name is None:
        object_name = os.path.basename(file_name)

    try:
        s3client.upload_file(file_name, bucket, object_name)
    # boto3's managed transfer re-raises client errors as S3UploadFailedError,
    # which does not derive from ClientError, so both must be caught.
    except (ClientError, boto3.exceptions.S3UploadFailedError) as e:
        print(f'could not upload file {file_name} to: s3://{bucket}/{object_name}\n{e}', file=sys.stderr)
        return False
    return True
def _group_ids_by_tar(img_ids, tar_id_len=5):
    """Group image keys by source tar id, i.e. the first ``tar_id_len`` characters of each key."""
    img_by_tar = {}
    # Sorting first makes the tar ids appear (and get inserted) in ascending order.
    for iid in sorted(img_ids):
        img_by_tar.setdefault(iid[:tar_id_len], set()).add(iid)
    return img_by_tar


def _plan_output_tars(img_by_tar, num_filt_imgs_per_tar):
    """Map each output tar id to the list of source tar ids whose matches it will hold."""
    filt_to_src_tar = {}
    filt_tid = 0
    img_count = 0
    for tid, iids in img_by_tar.items():
        filt_to_src_tar.setdefault(f'{filt_tid:05d}', []).append(tid)
        img_count += len(iids)
        # A source tar is never split: the output tar is closed only after
        # the running count has exceeded the limit.
        if img_count > num_filt_imgs_per_tar:
            filt_tid += 1
            img_count = 0
    return filt_to_src_tar


def clip_filter_webdataset(query, input_folder, output_folder, indice_folder, num_results=1000, threshold=None,
                           num_filt_imgs_per_tar=1000):
    """Entry point of clip filter: copy the samples matching a query into new webdataset tars.

    The query is embedded with CLIP (ViT-B/32) and searched in the faiss image
    index found in ``indice_folder``. The matching keys are grouped by source
    tar (the first 5 characters of a key are taken to be the tar name), the
    source tars are read from ``input_folder`` and the matching samples are
    written to numbered tars in ``output_folder``.

    Args:
        query: Either the path of an existing image file (.png/.jpg/.jpeg/.bmp)
            or a text query.
        input_folder: Local directory or ``s3://`` prefix holding ``<id>.tar`` files.
        output_folder: Local directory or ``s3://`` prefix receiving the filtered tars.
        indice_folder: Directory containing ``image.index`` and ``metadata/*.parquet``
            (the parquet files must have an ``image_path`` column).
        num_results: Number of nearest neighbours requested when ``threshold`` is None.
        threshold: If given, a range search with this threshold is used instead
            of a fixed-size nearest-neighbour search.
        num_filt_imgs_per_tar: Image count after which a new output tar is started.

    Returns:
        List of the output tar paths that received at least one sample
        (empty when the query matched nothing).

    Raises:
        RuntimeError: If a filtered tar could not be uploaded to S3.
    """
    import os  # pylint: disable=import-outside-toplevel
    import sys  # pylint: disable=import-outside-toplevel
    import tarfile  # pylint: disable=import-outside-toplevel
    from pathlib import Path  # pylint: disable=import-outside-toplevel

    import clip  # pylint: disable=import-outside-toplevel
    import faiss  # pylint: disable=import-outside-toplevel
    import pandas as pd  # pylint: disable=import-outside-toplevel
    import torch  # pylint: disable=import-outside-toplevel
    import webdataset as wds  # pylint: disable=import-outside-toplevel
    from PIL import Image  # pylint: disable=import-outside-toplevel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

    # load the key list (row order must match the index) and the index itself
    data_dir = Path(indice_folder + "/metadata")
    df = pd.concat(pd.read_parquet(parquet_file) for parquet_file in sorted(data_dir.glob("*.parquet")))
    image_list = df["image_path"].tolist()
    image_index = faiss.read_index(indice_folder + "/image.index")

    # embed the query: an existing image file is encoded as an image, anything else as text
    if query.endswith((".png", ".jpg", ".jpeg", ".bmp")) and os.path.isfile(query):
        with Image.open(query) as im:
            query_features = model.encode_image(preprocess(im).unsqueeze(0).to(device))
    else:
        text = clip.tokenize([query]).to(device)
        query_features = model.encode_text(text)
    query_features /= query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    if threshold is not None:
        _, d, i = image_index.range_search(query_features, threshold)
        print(f"Found {i.shape} items with query '{query}' and threshold {threshold}")
    else:
        d, i = image_index.search(query_features, num_results)
        print(f"Found {num_results} items with query '{query}'")
        # search() returns one row per query vector; there is a single query
        i = i[0]
        d = d[0]

    # faiss pads the labels with -1 when fewer results than requested exist;
    # indexing image_list with -1 would wrongly select the last image.
    valid = i >= 0
    i = i[valid]
    d = d[valid]
    if len(i) == 0:
        print(f"No items found for query '{query}', nothing to write", file=sys.stderr)
        return []

    min_d = d.min()
    max_d = d.max()
    print(f"The minimum distance is {min_d:.2f} and the maximum is {max_d:.2f}")
    print("You may want to use these numbers to increase your --num_results parameter. "
          "Or use the --threshold parameter.")

    img_ids = [image_list[ei] for ei in i]

    # order paths, group them by corresponding tar files,
    # assuming that each tar has 5 digits identifiers
    img_by_tar = _group_ids_by_tar(img_ids)

    # map old tar files to new wds tar files
    filt_to_src_tar = _plan_output_tars(img_by_tar, num_filt_imgs_per_tar)

    # iterate over tars to load, copy them from s3 to local if needed, take stuff out of them,
    # write stuff to local tar file, and upload to s3 if needed
    s3_input = input_folder.startswith('s3://')
    s3_output = output_folder.startswith('s3://')
    s3client = boto3.client('s3') if s3_input or s3_output else None
    total_items_written = 0

    if not s3_output:
        os.makedirs(output_folder, exist_ok=True)

    filtered_tars = []
    for ftid, stids in filt_to_src_tar.items():
        dst_path = f'{output_folder}/{ftid}.tar'
        # s3 output is first assembled in a temporary tar in the working directory
        local_dst = f'_tmp_{ftid}.tar' if s3_output else dst_path
        dst_items_written = 0
        local_src = None

        try:
            with wds.TarWriter(local_dst) as dst:
                for src_tid in stids:
                    wanted = img_by_tar[src_tid]
                    src_items_written = 0
                    src_path = f'{input_folder}/{src_tid}.tar'

                    # if input_dir is on s3, first copy file to local
                    if s3_input:
                        bucket_name, file_key = break_up_s3_path(src_path)
                        local_src = download_file_from_s3(s3client, bucket_name, file_key, f'_tmp_{src_tid}.tar')
                    else:  # local
                        local_src = src_path

                    try:
                        # take items from src tar
                        src = wds.WebDataset(local_src, cache_size=10 ** 10)
                        for item in src:  # key order is not guaranteed
                            if item['__key__'] in wanted:
                                dst.write(item)
                                src_items_written += 1
                    finally:
                        # count what actually reached dst, even if the read failed midway
                        dst_items_written += src_items_written
                        # delete the downloaded copy, also when the tar was unreadable
                        if s3_input:
                            try:
                                os.remove(local_src)
                            except OSError:
                                pass

                    print(f'|-- items written from {src_path}: {src_items_written}/{len(wanted)}')

        except tarfile.ReadError:
            # the remaining source tars of this output tar are skipped
            print('Unable to read tar file at:', local_src, file=sys.stderr)

        print(f'--> items written to {local_dst}: {dst_items_written}')
        total_items_written += dst_items_written

        if dst_items_written == 0:
            # nothing matched: drop the empty tar and do not report it
            os.remove(local_dst)
            continue

        # upload filtered tar file to s3
        if s3_output:
            print('|----- uploading to:', os.path.basename(dst_path), flush=True)
            bucket_name, file_key = break_up_s3_path(dst_path)
            upload_success = upload_file_to_s3(s3client, local_dst, bucket_name, file_key)
            if not upload_success:
                # the temporary tar is left in place so the upload can be retried by hand
                raise RuntimeError(f'Could not upload: {local_dst} -> {dst_path}')
            print('---->| uploading done to:', os.path.basename(dst_path), flush=True)

            # delete local filtered tar file
            try:
                os.remove(local_dst)
            except OSError:
                pass

        filtered_tars.append(dst_path)

    print('total items written:', total_items_written)
    print('filtered tar files:', filtered_tars)
    return filtered_tars
if __name__ == "__main__":
    # Expose clip_filter_webdataset as a command-line tool: python-fire maps
    # the function's parameters to CLI arguments and flags.
    fire.Fire(clip_filter_webdataset)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment