Skip to content

Instantly share code, notes, and snippets.

@uehara-mech
Created April 22, 2022 09:17
Show Gist options
  • Save uehara-mech/f4c850999f1c90f61851951e1183ef0a to your computer and use it in GitHub Desktop.
Save uehara-mech/f4c850999f1c90f61851951e1183ef0a to your computer and use it in GitHub Desktop.
import json
import os.path
import pickle
from functools import lru_cache
from typing import Dict, Tuple, List, Union, Any
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from maskrcnn_benchmark.data.datasets.vg_tsv import VGTSVDataset
@lru_cache(maxsize=5000)
def load_image(vg_img_id: int, image_dirs: Tuple[str, ...] = ('VG_100K', 'VG_100K_2')):
    """Load a Visual Genome image by id and return it converted to RGB.

    The file ``{image_dir}/{vg_img_id}.jpg`` is tried for each entry of
    ``image_dirs`` in order; the first one that exists is used.

    Results are memoised (up to 5000 entries), so callers must not mutate the
    returned image in place.

    Args:
        vg_img_id: Visual Genome image id (used as the file stem).
        image_dirs: directories to search, in priority order. Must be a tuple so
            the arguments stay hashable for ``lru_cache``.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode, independent of the source file.

    Raises:
        FileNotFoundError: if the image exists in none of ``image_dirs``.
    """
    for image_dir in image_dirs:
        try:
            # The context manager guarantees the file handle is closed even if
            # decoding fails; convert() hands back a new, fully loaded image.
            with Image.open(f'{image_dir}/{vg_img_id}.jpg') as img:
                return img.convert("RGB")
        except FileNotFoundError:
            continue
    raise FileNotFoundError(
        f"image {vg_img_id}.jpg not found in any of {list(image_dirs)}"
    )
class LoadImageDataset(Dataset):
    """
    Map-style dataset that loads Visual Genome images and returns their compact
    features, so feature extraction can run in DataLoader worker processes.

    Args:
        all_vg_img_ids: Visual Genome image ids to process, in index order.
    """

    def __init__(self, all_vg_img_ids: List[int]) -> None:
        self.all_vg_img_ids = all_vg_img_ids

    def __len__(self) -> int:
        return len(self.all_vg_img_ids)

    def __getitem__(self, idx: int) -> Dict[str, Union[int, np.ndarray]]:
        """Return ``{"feature": get_image_feature(image), "vg_img_id": id}`` for ``idx``."""
        vg_img_id = self.all_vg_img_ids[idx]
        # In a single pass each index is read once, so (for unique ids) routing
        # through load_image's lru_cache never produces a hit. It would only pin
        # up to 5000 decoded images in every worker process, since the cache is
        # per-process. Call the undecorated function instead.
        img = load_image.__wrapped__(vg_img_id)
        return {
            "feature": get_image_feature(np.array(img)),
            "vg_img_id": vg_img_id,
        }
def get_image_feature(image: np.ndarray) -> np.ndarray:
    """
    Compute a compact 8-dimensional descriptor of an image.

    The descriptor is [mean_c0, mean_c1, mean_c2, std_c0, std_c1, std_c2,
    height, width]. It is used instead of raw pixels to save memory when
    comparing many images.

    Args:
        image: (H, W, 3) array. Grayscale input of shape (H, W) or (H, W, 1) is
            treated as three identical channels. (H, W, 4) input has its fourth
            (alpha) channel ignored. Both match what PIL's ``convert("RGB")``
            would produce, so the output is always 8-dimensional.
    Returns:
        image_feature: float array of shape (8,).
    Raises:
        ValueError: if ``image`` is not (H, W) or (H, W, C) with C in {1, 3, 4}.
    """
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.ndim != 3 or image.shape[2] not in (1, 3, 4):
        raise ValueError(
            "expected an (H, W) or (H, W, C) image with C in {1, 3, 4}, "
            f"got shape {image.shape}"
        )
    height, width, num_channels = image.shape
    if num_channels == 4:
        image = image[:, :, :3]  # drop alpha
    channel_mean = np.mean(image, axis=(0, 1))  # shape: (C,)
    channel_std = np.std(image, axis=(0, 1))  # shape: (C,)
    if num_channels == 1:
        # Replicating a gray channel three times leaves mean/std unchanged.
        channel_mean = np.repeat(channel_mean, 3)
        channel_std = np.repeat(channel_std, 3)
    return np.concatenate([
        channel_mean,
        channel_std,
        np.array([height, width]),
    ])  # shape: (8,)
def collate_load_fn(batch):
    """Merge LoadImageDataset samples into a single ``{vg_img_id: feature}`` dict.

    If two samples share an id, the later sample's feature wins.
    """
    merged = {}
    for sample in batch:
        vg_img_id = sample["vg_img_id"]
        merged[vg_img_id] = sample["feature"]
    return merged
class FindImageDataset(Dataset):
    """
    Map-style dataset that pairs every image of a TSV dataset with the Visual
    Genome image whose precomputed feature is nearest in Euclidean distance.
    It is a Dataset so the search can be spread across DataLoader workers.

    Args:
        vg_tsv_dataset: source dataset; only ``len()``, ``get_image(idx)`` and
            ``get_line_no(idx)`` are used.
        feature_array: candidate VG features, one row per VG image.
        vgid_list: VG image ids, row-aligned with ``feature_array``.
        split: split name copied verbatim into every result record.
    """

    def __init__(
        self,
        vg_tsv_dataset: VGTSVDataset,
        feature_array: np.ndarray,
        vgid_list: List[int],
        split: str
    ):
        self.vg_tsv_dataset = vg_tsv_dataset
        self.feature_array = feature_array
        self.vgid_list = vgid_list
        self.split = split

    def __len__(self):
        return len(self.vg_tsv_dataset)

    def __getitem__(self, idx):
        """Return the nearest-VG-image match record for TSV item ``idx``."""
        source = self.vg_tsv_dataset
        query_image = np.array(source.get_image(idx))
        line_no = int(source.get_line_no(idx))
        query_feature = get_image_feature(query_image)
        # Distance from this image's feature to every candidate VG feature.
        distances = np.linalg.norm(self.feature_array - query_feature, axis=1)
        best = np.argmin(distances)
        return dict(
            vg_img_id=self.vgid_list[best],
            sg_id=idx,
            sg_img_lineno=line_no,
            split=self.split,
            dist=distances[best],
        )
def collate_fn(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Identity collate: hand the DataLoader batch back as the same list object.

    This overrides the default collate, which would try to stack the per-sample
    dicts into tensors.
    """
    samples = batch  # no stacking, no copying
    return samples
def find_same_image(
    dataset: VGTSVDataset,
    vgid_to_sgid: Dict[int, int],
    vgid_to_shape: Dict[int, Tuple[int, int]],
) -> Dict[int, int]:
    """
    Exhaustively match dataset images to pixel-identical Visual Genome images.

    For each dataset index ``sg_idx``, VG candidates whose stored shape equals
    ``image.shape[:2]`` are loaded and compared pixel-wise. The first candidate
    whose L2 pixel distance is below 0.1 is recorded as
    ``vgid_to_sgid[vg_id] = sg_idx``. VG ids already in ``vgid_to_sgid`` are
    skipped, so a partially filled dict resumes earlier work.

    Args:
        dataset: image source, accessed via ``len()`` and ``get_image(idx)``.
        vgid_to_sgid: VG id -> dataset index; updated in place.
        vgid_to_shape: VG id -> shape compared against ``image.shape[:2]``
            (presumably (height, width); confirm against the caller). Any
            sequence type works, since values are normalised with ``tuple()``.
    Returns:
        ``vgid_to_sgid`` (the same dict object) after updating.
    """
    # Group candidates by shape once, instead of scanning every VG id for every
    # dataset image. Each list keeps vgid_to_shape's iteration order.
    shape_to_vgids: Dict[Tuple[int, ...], List[int]] = {}
    for vgid, shape in vgid_to_shape.items():
        shape_to_vgids.setdefault(tuple(shape), []).append(vgid)

    for sg_idx in tqdm(range(len(dataset)), ncols=60):
        # Float arithmetic: subtracting uint8 arrays would wrap around modulo 256.
        sg_img = np.array(dataset.get_image(sg_idx)).astype(np.float64)
        for each_vgid in shape_to_vgids.get(tuple(sg_img.shape[:2]), ()):
            if each_vgid in vgid_to_sgid:
                continue
            try:
                vg_img = np.array(load_image(each_vgid))
                dist = np.linalg.norm(sg_img - vg_img)
            except Exception as e:
                # Best effort: report unreadable/mismatched images and move on.
                print(f"{sg_idx} - {each_vgid} is corrupted! {e}")
                continue
            if dist < 0.1:
                vgid_to_sgid[each_vgid] = sg_idx
                break
    return vgid_to_sgid
def _make_sequential_loader(dataset: Dataset, collate) -> DataLoader:
    """Wrap ``dataset`` in the unshuffled, batch-32, 32-worker DataLoader this script uses."""
    return DataLoader(
        dataset,
        batch_size=32, shuffle=False, num_workers=32, drop_last=False,
        collate_fn=collate,
    )


def main():
    """
    Build the VG-image <-> scene-graph-TSV correspondence and pickle it.

    Steps:
      1. Read all VG image ids from ``VisualGenome/image_data.json``.
      2. Compute an 8-d feature per VG image, using DataLoader workers.
      3. For every image of the train and test TSV datasets, find the VG image
         with the nearest feature.
      4. Pickle the list of match records to ``vgid_to_sgid.pickle`` in the
         working directory. There is one dict per TSV image, train records
         first, then test.
    """
    output_path = 'vgid_to_sgid.pickle'
    split_datasets = [
        ("train", VGTSVDataset('visualgenome/train_danfeiX_relation_nm.yaml')),
        ("test", VGTSVDataset('visualgenome/test_danfeiX_relation.yaml')),
    ]

    print("creating vgid_to_shape")
    with open("VisualGenome/image_data.json", 'r', encoding='utf-8') as f:
        vg_img_data = json.load(f)
    all_vg_img_ids = [int(x['image_id']) for x in vg_img_data]

    # extract features
    print("extracting features")
    extract_feature_loader = _make_sequential_loader(
        LoadImageDataset(all_vg_img_ids=all_vg_img_ids), collate_load_fn
    )
    vgid_to_feature = {}
    for load_batch in tqdm(extract_feature_loader, ncols=60):
        vgid_to_feature.update(load_batch)
    print("extracting features done")

    # vgid_list[i] is the id of feature_array[i]; feature_array: (num_images, 8)
    vgid_list = list(vgid_to_feature)
    feature_array = np.stack([vgid_to_feature[vgid] for vgid in vgid_list])

    vgid_to_sginfo = []
    for split, tsv_dataset in split_datasets:
        find_dataset = FindImageDataset(
            tsv_dataset,
            feature_array=feature_array, vgid_list=vgid_list,
            split=split,
        )
        for each_batch in tqdm(_make_sequential_loader(find_dataset, collate_fn), ncols=60):
            vgid_to_sginfo.extend(each_batch)

    with open(output_path, 'wb') as f:
        pickle.dump(vgid_to_sginfo, f)
    print(f"converted result saved at {os.path.abspath(output_path)}")


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment