# Source: GitHub gist by @ryanwebster90, created December 27, 2022:
# https://gist.github.com/ryanwebster90/5a645aceb519e60499089b24813e5ec5
import numpy as np
import torch
import fire
import glob
def abs_ind_to_feat_file(abs_ind, cum_sz, feat_files):
    """Map an absolute row index over concatenated feature files to its file.

    Args:
        abs_ind: Absolute (global) row index, e.g. an id returned by a faiss
            search over all feature files stacked in order.
        cum_sz: Non-decreasing cumulative row counts starting at 0, so that
            ``cum_sz[j]`` is the global index of the first row of
            ``feat_files[j]`` (length ``len(feat_files) + 1``). A list or a
            numpy array is accepted.
        feat_files: Feature file names, in the same order as ``cum_sz``.

    Returns:
        ``(feat_file, file_index, local_index)``: the file holding the row,
        its position in ``feat_files`` and the row index inside that file.

    Raises:
        IndexError: If ``abs_ind`` is negative (faiss reports missing results
            as -1) or not smaller than the total number of rows.
    """
    cum_sz = np.asarray(cum_sz)
    # Last j with cum_sz[j] <= abs_ind. With side='right', repeated offsets
    # (zero-length files) resolve to the last of them, i.e. the file that
    # actually starts at that offset.
    last_ind = int(np.searchsorted(cum_sz, abs_ind, side='right')) - 1
    if last_ind < 0 or last_ind >= len(feat_files):
        total = cum_sz[-1] if cum_sz.size else 0
        raise IndexError(f"absolute index {abs_ind} is outside [0, {total})")
    local_ind = abs_ind - cum_sz[last_ind]
    return feat_files[last_ind], last_ind, local_ind
def pil_to_torch(img, none_to_black=True):
    """Convert an image to a float tensor of shape ``(1, 3, H, W)``.

    Args:
        img: A PIL image, anything ``np.array`` can turn into an ``(H, W)`` or
            ``(H, W, C)`` array, or ``None``.
        none_to_black: Currently unused; ``None`` always becomes a black
            ``(1, 3, 256, 256)`` image. Kept for backward compatibility.

    Returns:
        A ``torch.float`` tensor with exactly three channels, so results can
        be concatenated along dim 0. Pixel values are not rescaled.
    """
    if img is None:
        return torch.zeros(1, 3, 256, 256, dtype=torch.float)
    # PIL images: let PIL handle every mode (L, P, RGBA, CMYK, ...) so that
    # palette images are not mistaken for grayscale and alpha is dropped.
    if hasattr(img, "convert"):
        img = img.convert("RGB")
    arr = np.array(img)
    if arr.ndim == 2:
        arr = arr[:, :, None]
    channels = arr.shape[2]
    if channels in (1, 2):
        # Grayscale (optionally + alpha): replicate the intensity channel.
        arr = np.repeat(arr[:, :, :1], 3, axis=2)
    elif channels > 3:
        # RGBA and similar: keep the colour channels only.
        arr = arr[:, :, :3]
    tensor = torch.from_numpy(np.ascontiguousarray(arr)).to(torch.float)
    return tensor.permute(2, 0, 1).unsqueeze(0)
def feat_key_to_file_loc(feat_path):
    """Return the root data directory associated with a feature file.

    The directory is selected by the first tag in the table below that occurs
    as a substring of ``feat_path``. Tags are tried in table order, so an
    earlier entry wins if several match.

    Raises:
        ValueError: If no known tag occurs in ``feat_path``.
    """
    known_locations = (
        ('24_26', '/home/ryan/dw3/hdd5tb/laion_24_26/'),
        ('16_20', '/home/ryan/hdd1/laion-data-16-20/'),
        ('27_31', '/home/ryan/dw3/hdd/laion_data_27_31/'),
        ('1_4_1', '/home/ryan/dw0/laion400m-data-old/'),
        ('21_23', '/home/ryan/dw3/hdd5tb/laion_21_23_data/'),
        ('9_12', '/home/ryan/dw0/hdd1/laion400m-p9_12/'),
        ('1_4_0', '/home/ryan/dw0/hdd1/laion_data_old_first/'),
    )
    for key, abs_path in known_locations:
        if key in feat_path:
            return abs_path
    # Fail loudly instead of falling through to an unbound return value.
    raise ValueError(f'ERROR, no file path found for {feat_path!r}')
def vis_nns_from_nn_inds(nn_file, chunk_size=4, num_chunks=2, query_image_folder=None,
                         query_wds_folder=None, query_captions=None):
    """Fetch the nearest-neighbour images of the first query and save them.

    Args:
        nn_file: Array of absolute neighbour indices, one row per query (e.g.
            the index array of a faiss search). Despite the name this is the
            array itself, not a path. A 1-D array is treated as one query.
            Only row 0 is visualised, and at most its first 12 neighbours.
        chunk_size, num_chunks, query_image_folder, query_wds_folder,
        query_captions: Currently unused; kept for backward compatibility.

    Side effects:
        Writes ``text_query_demo.jpg`` (a grid with 4 images per row) and
        ``text_query_demo.txt`` (one caption line per neighbour) to the
        current directory, and prints progress information.
    """
    import torchvision
    import wds_utils

    feat_files = sorted(glob.glob('../hdd14tb/vitb32_overlap_feats/img_emb/*.npy'))
    # Assumes metadata files pair one-to-one, in sorted order, with feat_files.
    pq_files = sorted(glob.glob('../hdd14tb/vitb32_overlap_feats/metadata/*.npy'))
    # cum_sz[j] = global index of the first row of feat_files[j]; mmap_mode
    # reads only the header to get the row count.
    row_counts = [np.load(feat, mmap_mode='r').shape[0] for feat in feat_files]
    cum_sz = np.cumsum([0] + row_counts).astype('int')

    nn_inds = np.atleast_2d(np.asarray(nn_file))
    num_neighbours = min(12, nn_inds.shape[1])
    i = 0  # only the first query row is visualised
    metadata_cache = {}  # feat_file_ind -> loaded metadata array
    all_imgs = []
    text_strs = ' \n'
    for k in range(num_neighbours):
        abs_ind = nn_inds[i, k]
        print("abs ind = ", abs_ind)
        if abs_ind < 0:
            # faiss pads rows with -1 when fewer results than requested exist.
            print(f"skipping invalid neighbour index {abs_ind}")
            continue
        feat_file, feat_file_ind, local_ind = abs_ind_to_feat_file(abs_ind, cum_sz, feat_files)
        file_path = feat_key_to_file_loc(feat_file)
        print(f"ff={feat_file}, ffi={feat_file_ind}, file_path={file_path}, local_ind={local_ind}")
        # Load each metadata file at most once per call.
        if feat_file_ind not in metadata_cache:
            metadata_cache[feat_file_ind] = np.load(pq_files[feat_file_ind])
        nn_path = metadata_cache[feat_file_ind][local_ind]
        tar_size = 1000 if 'laion_4_8' in file_path else 10000
        print('file path and nn file', file_path, nn_path)
        img, caption = wds_utils.retrieve_image(file_path, nn_path, tar_size=tar_size, verb=True)
        all_imgs.append(pil_to_torch(img))
        text_strs += f"{i},{k+1} {caption} \n"

    if not all_imgs:
        print('no valid neighbours to visualise')
        return
    # NOTE(review): torch.cat requires all retrieved images to share H and W;
    # presumably the dataset images are pre-resized - confirm.
    torchvision.utils.save_image(torch.cat(all_imgs, dim=0), 'text_query_demo.jpg', nrow=4, normalize=True)
    with open('text_query_demo.txt', 'w') as text_file:
        text_file.write(text_strs)
import faiss
import torch
from PIL import Image
import open_clip
import torch
import torch.nn.functional as F
from tqdm import tqdm
from open_clip import tokenize
import torchvision
import numpy as np
import os
# Interactive demo: embed a text query with CLIP ViT-B/32, search the faiss
# image index for its 16 nearest neighbours and save them with
# vis_nns_from_nn_inds. An empty line (or EOF) ends the session.
with torch.no_grad():
    index = faiss.read_index('vitb32_overlap_index/image.index')
    # model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14',pretrained = 'laion2b_s32b_b79k')
    import clip
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the model once; only the query changes between iterations.
    model, _ = clip.load("ViT-B/32", device=device, jit=True)
    while True:
        try:
            s = input('type your text query...')
        except EOFError:
            break
        if not s.strip():
            break
        # open_clip's tokenize uses the same BPE tokenizer / context length
        # as OpenAI CLIP's ViT-B/32.
        texts = tokenize([s]).to(device)
        text_embeddings = model.encode_text(texts)
        text_embedding = F.normalize(text_embeddings, dim=-1).mean(dim=0)
        text_embedding /= text_embedding.norm()
        print("embeddings finished!")
        print(text_embedding.size())
        # faiss expects a float32 (n, d) numpy array.
        _, nns = index.search(text_embedding.reshape(1, -1).cpu().numpy().astype('float32'), 16)
        vis_nns_from_nn_inds(nns)
        print('done!')
# (end of gist)