# (WIP) Point at a folder of images, get box labels with probs in a folder
# %%
# Run this script from the root of a Detic checkout:
# cd Detic/
# %%
from detectron2.utils.logger import setup_logger

setup_logger()

# common libraries
import os
import sys
import time
from pathlib import Path
from random import randint

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

# common detectron2 utilities
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer

# Detic libraries (the CenterNet2 project ships inside the Detic repo)
sys.path.insert(0, 'third_party/CenterNet2/projects/CenterNet2/')
from centernet.config import add_centernet_config
from detic.config import add_detic_config
from detic.modeling.utils import reset_cls_test
# %%
# Build the detector and download our pretrained weights
cfg = get_cfg()
add_centernet_config(cfg)
add_detic_config(cfg)
cfg.merge_from_file("configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml")
cfg.MODEL.WEIGHTS = 'https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth'
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.3 # set threshold for this model
cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand' # placeholder; replaced below by reset_cls_test
cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = False # True keeps only the top class per box (nicer visualizations); False keeps all classes.
predictor = DefaultPredictor(cfg)
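# Note: DefaultPredictor wraps the model with its test-time preprocessing. It takes a
# single HWC uint8 image in BGR channel order (detectron2's default INPUT.FORMAT)
# and returns a dict with an "instances" field.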
# %%
# Set up the model's vocabulary using built-in datasets
BUILDIN_CLASSIFIER = {
    'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
    'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy',
    'openimages': 'datasets/metadata/oid_clip_a+cname.npy',
    'coco': 'datasets/metadata/coco_clip_a+cname.npy',
}
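# These .npy files ship with the Detic repo; each holds precomputed CLIP text
# embeddings for the class names of the corresponding dataset, which Detic uses
# as its open-vocabulary classifier weights.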
BUILDIN_METADATA_PATH = {
    'lvis': 'lvis_v1_val',
    'objects365': 'objects365_v2_val',
    'openimages': 'oid_val_expanded',
    'coco': 'coco_2017_val',
}
vocabulary = 'lvis' # one of: 'lvis', 'objects365', 'openimages', 'coco'
metadata = MetadataCatalog.get(BUILDIN_METADATA_PATH[vocabulary])
classifier = BUILDIN_CLASSIFIER[vocabulary]
num_classes = len(metadata.thing_classes)
reset_cls_test(predictor.model, classifier, num_classes)
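# %%
# Optional sanity check (a sketch, not part of the original gist): run the detector
# on a single image and draw its boxes with detectron2's Visualizer. "sample.jpg"
# is a hypothetical placeholder path; point it at any local test image.
if os.path.exists('sample.jpg'):
    im = cv2.imread('sample.jpg')  # cv2 loads BGR, which DefaultPredictor expects
    outputs = predictor(im)
    v = Visualizer(im[:, :, ::-1], metadata)  # Visualizer expects RGB
    out = v.draw_instance_predictions(outputs['instances'].to('cpu'))
    cv2.imwrite('sample_labeled.jpg', out.get_image()[:, :, ::-1])  # back to BGR for imwrite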
# %%
class ImageDataset(Dataset):
    def __init__(self, folder, shuffle=False):
        """
        @param folder: Folder to glob (recursively) for images, keyed by each path's "stem".
        @param shuffle: If True, unreadable images are skipped by sampling a random
            index instead of the next sequential one.
        """
        super().__init__()
        self.shuffle = shuffle
        path = Path(folder)
        image_files = [
            *path.glob('**/*.png'), *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
        ]
        self.image_files = {image_file.stem: image_file for image_file in image_files}
        self.keys = list(self.image_files.keys())

    def __len__(self):
        return len(self.keys)

    def random_sample(self):
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def __getitem__(self, ind):
        key = self.keys[ind]
        image_file = self.image_files[key]
        resize_value = 512
        try:
            # Fixed square resize so the default collate can stack the batch.
            pil_image = Image.open(image_file).resize((resize_value, resize_value)).convert('RGB')
        except OSError:
            return self.skip_sample(ind)
        np_output = np.array(pil_image)
        return np_output, key
def log_class_prediction(instances, key, outdir='./caption/'):
    """Write one "class_name score" line per detected box to <outdir>/<key>.txt."""
    os.makedirs(outdir, exist_ok=True)
    names = metadata.thing_classes
    lines = [f"{names[c]} {s:.4f}"
             for c, s in zip(instances.pred_classes.tolist(), instances.scores.tolist())]
    with open(os.path.join(outdir, f'{key}.txt'), 'w') as text_file:
        text_file.write('\n'.join(lines))
batch_size = 64
dataset = ImageDataset(folder='/home/samsepiol/DatasetWorkspace/CurrentDatasets/WIKIART/')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
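# Note: default collation stacks the uint8 HWC arrays into a (B, H, W, C) tensor,
# which works here because every image was resized to the same 512x512 shape.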
# %%
with torch.no_grad():
    for batch, keys in dataloader:
        batch_start = time.time()
        batch_np = batch.numpy()
        # DefaultPredictor wants BGR but the dataset yields RGB: flip the channel order.
        predictions = [predictor(np.ascontiguousarray(x_np[:, :, ::-1]))['instances']
                       for x_np in batch_np]
        for i, prediction in enumerate(predictions):
            print(f"{keys[i]} {prediction.pred_classes}")
            log_class_prediction(prediction, keys[i])
        elapsed_time = time.time() - batch_start
        print(f"Elapsed time: {elapsed_time:.2f}s for {len(batch_np)} images")
# %%
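# A sketch (not in the original gist) for spot-checking a few of the saved label files:
for txt in sorted(Path('./caption/').glob('*.txt'))[:5]:
    print(txt.stem, '->', txt.read_text().replace('\n', ', '))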