import json
import os.path
import pickle
from functools import lru_cache
from typing import Dict, Tuple, List, Union, Any

import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from maskrcnn_benchmark.data.datasets.vg_tsv import VGTSVDataset

@lru_cache(maxsize=5000)
def load_image(vg_img_id: int) -> Image.Image:
    """Load a Visual Genome image by id.

    VG distributes its images across two directories (VG_100K and
    VG_100K_2), so fall back to the second when the first misses.
    """
    try:
        img = Image.open(f'VG_100K/{vg_img_id}.jpg')
    except FileNotFoundError:
        img = Image.open(f'VG_100K_2/{vg_img_id}.jpg')
    return img.convert("RGB")

class LoadImageDataset(Dataset):
    """
    We use torch.utils.data.Dataset to load images and extract features
    with multi-processing (via DataLoader workers).
    """

    def __init__(self, all_vg_img_ids: List[int]) -> None:
        self.all_vg_img_ids = all_vg_img_ids

    def __len__(self) -> int:
        return len(self.all_vg_img_ids)

    def __getitem__(self, idx: int) -> Dict[str, Union[int, np.ndarray]]:
        vg_img_id = self.all_vg_img_ids[idx]
        img = load_image(vg_img_id)
        img = np.array(img)
        return {
            "feature": get_image_feature(img),
            "vg_img_id": vg_img_id,
        }

def get_image_feature(image: np.ndarray) -> np.ndarray:
    """
    To save memory, we summarise an image as an 8-dimensional feature:
    per-channel mean, per-channel std, and the image height and width.

    Args:
        image: (H, W, C)
    Returns:
        image_feature: (8,)
    """
    channel_mean = np.mean(image, axis=(0, 1))  # shape: (3,)
    channel_std = np.std(image, axis=(0, 1))    # shape: (3,)
    height, width = image.shape[:2]
    return np.concatenate([
        channel_mean,
        channel_std,
        np.array([height, width]),
    ])  # shape: (8,)

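# A hedged sanity check (not part of the original gist): on a solid-gray
# image the channel means equal the gray value, the stds are 0, and the
# last two entries are the spatial size.
def _demo_get_image_feature() -> None:
    img = np.full((2, 3, 3), 128, dtype=np.uint8)  # H=2, W=3, all gray
    feat = get_image_feature(img)
    assert feat.shape == (8,)
    assert np.allclose(feat, [128, 128, 128, 0, 0, 0, 2, 3])
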
def collate_load_fn(batch):
    # Merge a batch of per-image dicts into one {vg_img_id: feature} dict.
    return {x["vg_img_id"]: x["feature"] for x in batch}

class FindImageDataset(Dataset):
    """
    We use torch.utils.data.Dataset so the matching runs in parallel across
    DataLoader workers, which greatly reduces processing time.
    """

    def __init__(
        self,
        vg_tsv_dataset: VGTSVDataset,
        feature_array: np.ndarray,
        vgid_list: List[int],
        split: str,
    ):
        self.vg_tsv_dataset = vg_tsv_dataset
        self.feature_array = feature_array  # shape: (num_images, 8)
        self.vgid_list = vgid_list
        self.split = split

    def __len__(self):
        return len(self.vg_tsv_dataset)

    def __getitem__(self, idx):
        sg_img = np.array(self.vg_tsv_dataset.get_image(idx))
        sg_img_lineno = int(self.vg_tsv_dataset.get_line_no(idx))
        sg_feat = get_image_feature(sg_img)  # shape: (8,)
        # L2 distance between the TSV image's feature and every original VG
        # image feature; the nearest VG image is taken as the match.
        dist = np.linalg.norm(self.feature_array - sg_feat, axis=1)
        min_dist_idx = int(np.argmin(dist))
        return {
            "vg_img_id": self.vgid_list[min_dist_idx],
            "sg_id": idx,
            "sg_img_lineno": sg_img_lineno,
            "split": self.split,
            "dist": dist[min_dist_idx],
        }

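# A minimal sketch (not in the original gist) of the nearest-neighbour step
# above, on synthetic 8-d features: the query is closest to row 0, so row 0
# is reported as the match.
def _demo_feature_matching() -> None:
    feature_array = np.stack([np.zeros(8), np.full(8, 100.0)])
    query = np.ones(8)
    dist = np.linalg.norm(feature_array - query, axis=1)
    assert int(np.argmin(dist)) == 0
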
def collate_fn(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Identity collate: keep the list of per-item result dicts as-is.
    return batch

def find_same_image(
    dataset: VGTSVDataset,
    vgid_to_sgid: Dict[int, int],
    vgid_to_shape: Dict[int, Tuple[int, int]],
) -> Dict[int, int]:
    """Brute-force fallback: match images pixel-by-pixel instead of by feature.

    Only VG images with the same (H, W) as the TSV image are compared.
    """
    for sg_idx in tqdm(range(len(dataset)), ncols=60):
        sg_img = np.array(dataset.get_image(sg_idx))
        for each_vgid in vgid_to_shape:
            if each_vgid in vgid_to_sgid:
                continue  # already matched
            if sg_img.shape[:2] != vgid_to_shape[each_vgid]:
                continue
            try:
                vg_img = np.array(load_image(each_vgid))
                # Cast to a signed dtype first: subtracting uint8 arrays
                # wraps around and gives a meaningless distance.
                dist = np.linalg.norm(
                    sg_img.astype(np.int64) - vg_img.astype(np.int64)
                )
                if dist < 0.1:  # effectively an exact pixel match
                    vgid_to_sgid[each_vgid] = sg_idx
                    break
            except Exception as e:
                print(f"{sg_idx} - {each_vgid} is corrupted! {e}")
    return vgid_to_sgid

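# A quick illustration (hypothetical, not in the original gist) of why the
# signed cast above matters: with raw uint8 arithmetic, 0 - 1 wraps to 255.
def _demo_uint8_wraparound() -> None:
    a = np.zeros((1, 1, 3), dtype=np.uint8)
    b = np.ones((1, 1, 3), dtype=np.uint8)
    wrapped = a - b                                   # entries wrap to 255
    assert wrapped.max() == 255
    signed = a.astype(np.int64) - b.astype(np.int64)  # entries are -1
    assert np.linalg.norm(signed) > 0.1               # correctly "different"
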
def main():
    train_dataset = VGTSVDataset(
        'visualgenome/train_danfeiX_relation_nm.yaml'
    )
    test_dataset = VGTSVDataset(
        'visualgenome/test_danfeiX_relation.yaml'
    )

    print("loading VG image metadata")
    with open("VisualGenome/image_data.json", 'r') as f:
        vg_img_data = json.load(f)
    all_vg_img_ids = [int(x['image_id']) for x in vg_img_data]

    # extract an 8-d feature for every original VG image
    print("extracting features")
    extract_feature_dataset = LoadImageDataset(
        all_vg_img_ids=all_vg_img_ids
    )
    extract_feature_loader = DataLoader(
        dataset=extract_feature_dataset,
        batch_size=32, shuffle=False, num_workers=32,
        collate_fn=collate_load_fn
    )
    vgid_to_feature = {}
    for load_batch in tqdm(extract_feature_loader, ncols=60):
        vgid_to_feature.update(load_batch)
    print("extracting features done")

    # feature_array: shape (num_images, 8); its row order matches vgid_list
    # because dicts preserve insertion order in Python 3.7+
    feature_array = np.stack(list(vgid_to_feature.values()))
    vgid_list = list(vgid_to_feature.keys())

    # match every TSV image in both splits to its original VG image
    convert_train = FindImageDataset(
        train_dataset,
        feature_array=feature_array, vgid_list=vgid_list,
        split="train"
    )
    convert_test = FindImageDataset(
        test_dataset,
        feature_array=feature_array, vgid_list=vgid_list,
        split="test"
    )
    convert_train_loader = DataLoader(
        convert_train, batch_size=32, shuffle=False, num_workers=32,
        drop_last=False, collate_fn=collate_fn
    )
    convert_test_loader = DataLoader(
        convert_test, batch_size=32, shuffle=False, num_workers=32,
        drop_last=False, collate_fn=collate_fn
    )

    vgid_to_sginfo = []
    for each_batch in tqdm(convert_train_loader, ncols=60):
        vgid_to_sginfo.extend(each_batch)
    for each_batch in tqdm(convert_test_loader, ncols=60):
        vgid_to_sginfo.extend(each_batch)

    with open('vgid_to_sgid.pickle', 'wb') as f:
        pickle.dump(vgid_to_sginfo, f)
    print(f"converted result saved at {os.path.abspath('vgid_to_sgid.pickle')}")


if __name__ == '__main__':
    main()
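
# Expected result (illustrative, assuming the yaml/json paths above exist):
# vgid_to_sgid.pickle holds a list of dicts, one per TSV image, e.g.
#   {"vg_img_id": 2345678, "sg_id": 0, "sg_img_lineno": 0,
#    "split": "train", "dist": 0.0}
# (the id shown is made up). Entries with a large "dist" had no confident
# match and may need the pixel-level fallback in find_same_image().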