Last active
April 9, 2023 11:33
-
-
Save A2va/fd96f5d1cdf7972de22483bab1995a04 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import glob | |
import os | |
import random | |
from urllib.parse import urlparse | |
import numpy as np | |
import torch | |
from PIL import Image | |
from torchvision import transforms | |
from torchvision.transforms.functional import InterpolationMode | |
from tqdm import tqdm | |
from transformers import OFAModel, OFATokenizer | |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
DEVICE = torch.device('cuda' if use_cuda else 'cpu')

# File suffixes that are treated as images when scanning a directory.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]

# Edge length, in pixels, of the square image produced by IMAGE_TRANSFORM.
IMAGE_SIZE = 480

# Per-channel statistics handed to transforms.Normalize below.
mean = [0.5] * 3
std = [0.5] * 3


def _to_rgb(image):
    """Return ``image`` converted to RGB mode."""
    return image.convert("RGB")


# Preprocessing pipeline: force RGB, resize to IMAGE_SIZE x IMAGE_SIZE with
# bicubic interpolation, convert to a tensor, then normalize each channel.
IMAGE_TRANSFORM = transforms.Compose([
    _to_rgb,
    transforms.Resize(
        (IMAGE_SIZE, IMAGE_SIZE),
        interpolation=InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
    """Dataset that loads an image from disk and applies ``IMAGE_TRANSFORM``.

    Each item is a ``(tensor, img_path)`` tuple. When an image cannot be
    opened or transformed, the error is printed and the item is ``None`` so
    that a collate function can drop it instead of aborting the whole run.
    """

    def __init__(self, image_paths):
        """Keep the sequence of image file paths this dataset serves."""
        self.images = image_paths

    def __len__(self):
        """Return the number of image paths."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(tensor, img_path)`` for item ``idx``, or ``None`` on failure."""
        img_path = self.images[idx]
        try:
            # The context manager closes the underlying file even when
            # decoding or transforming the image raises.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
                # convert to tensor so the dataloader will accept it
                tensor = IMAGE_TRANSFORM(image)
        except Exception as e:
            # Deliberately best-effort: report the problem and skip the image.
            print(
                f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
            return None
        return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
    """Collate a batch while discarding corrupted samples.

    The dataset is expected to yield ``None`` for a sample it could not
    load. Those entries are removed; every other sample is kept, in its
    original order, and the result is returned as a plain list.
    """
    return [sample for sample in batch if sample is not None]
# Cast FP32 tensors to FP16; anything else passes through untouched.
def apply_half(t):
    """Return ``t`` as half precision if it is float32, otherwise ``t`` itself."""
    return t.to(dtype=torch.half) if t.dtype is torch.float32 else t
def glob_images(directory, base="*"):
    """Collect image file paths in ``directory`` for every known extension.

    With the default ``base="*"`` only the directory part is escaped, so the
    wildcard stays active and every file with a suffix from
    ``IMAGE_EXTENSIONS`` matches. For any other ``base`` the complete path
    (directory + base + extension) is escaped, i.e. ``base`` is matched
    literally.

    The result is returned in the order produced by ``glob``; it is neither
    de-duplicated nor sorted.
    """
    matches = []
    match_any_name = base == '*'
    for suffix in IMAGE_EXTENSIONS:
        if match_any_name:
            pattern = os.path.join(glob.escape(directory), base + suffix)
        else:
            pattern = glob.escape(os.path.join(directory, base + suffix))
        matches += glob.glob(pattern)
    return matches
def main(args):
    """Caption every image found in ``args.train_data_dir`` with an OFA model.

    For each image a caption is generated and written to a text file next to
    the image, using the image's base name plus ``args.caption_extension``.
    Images that cannot be loaded are reported and skipped.

    Attributes read from ``args``: ``seed``, ``fp16``, ``ckpt_dir``,
    ``train_data_dir``, ``caption_extension``, ``debug``, ``batch_size`` and
    ``max_data_loader_n_workers``.
    """
    # fix the seed for reproducibility
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Half precision is only enabled on CUDA. On CPU the flag used to be a
    # no-op, so keep running in fp32 there instead of risking unsupported ops.
    use_fp16 = bool(args.fp16) and DEVICE.type == 'cuda'
    if args.fp16 and not use_fp16:
        print("fp16 requested but CUDA is not available; running in fp32.")

    tokenizer = OFATokenizer.from_pretrained(args.ckpt_dir)
    txt = " what does the image describe?"
    # The prompt is identical for every image, so tokenize it once and keep
    # it on the same device as the model.
    inputs = tokenizer([txt], return_tensors="pt").input_ids.to(DEVICE)

    model = OFAModel.from_pretrained(args.ckpt_dir, use_cache=True)
    model = model.to(DEVICE)
    if use_fp16:
        model = model.half()
    model.eval()

    print(f"load images from {args.train_data_dir}")
    image_paths = glob_images(args.train_data_dir)
    print(f"found {len(image_paths)} images.")

    def run_batch(path_imgs):
        """Generate and save a caption for each ``(image_path, tensor)`` pair."""
        for image_path, patch_img in path_imgs:
            patch_img = patch_img.unsqueeze(0).to(DEVICE)
            if use_fp16:
                patch_img = apply_half(patch_img)
            # NOTE: generation settings are fixed here; args.num_beams,
            # args.no_repeat_ngram_size, args.max_length, args.min_length and
            # args.temperature are not consulted.
            with torch.no_grad():
                gen = model.generate(inputs, patch_images=patch_img,
                                     num_beams=5, no_repeat_ngram_size=3)
            caption = tokenizer.batch_decode(gen, skip_special_tokens=True)
            caption_path = os.path.splitext(image_path)[0] + args.caption_extension
            with open(caption_path, "wt", encoding='utf-8') as f:
                f.write(caption[0] + "\n")
            if args.debug:
                print(image_path, caption)

    if args.max_data_loader_n_workers is not None:
        # Load and preprocess images in DataLoader worker processes.
        dataset = ImageLoadingTransformDataset(image_paths)
        data = torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.max_data_loader_n_workers,
            collate_fn=collate_fn_remove_corrupted, drop_last=False)
    else:
        # No DataLoader: emit one-element "batches" without a tensor so the
        # image is loaded lazily in the loop below.
        data = [[(None, ip)] for ip in image_paths]

    b_imgs = []
    for data_entry in tqdm(data, smoothing=0.0):
        for entry in data_entry:
            if entry is None:
                continue
            img_tensor, image_path = entry
            if img_tensor is None:
                try:
                    # Close the file handle even if decoding fails.
                    with Image.open(image_path) as raw_image:
                        if raw_image.mode != 'RGB':
                            raw_image = raw_image.convert("RGB")
                        img_tensor = IMAGE_TRANSFORM(raw_image)
                except Exception as e:
                    print(
                        f"Could not load image path: {image_path}, error: {e}")
                    continue
            b_imgs.append((image_path, img_tensor))
            if len(b_imgs) >= args.batch_size:
                run_batch(b_imgs)
                b_imgs.clear()

    # Flush the final, possibly partial, batch.
    if b_imgs:
        run_batch(b_imgs)
    print("done!")
def _build_arg_parser():
    """Create the command-line parser for the captioning script."""
    cli = argparse.ArgumentParser()

    # Input / output locations.
    cli.add_argument(
        "train_data_dir", type=str, help="directory for train images")
    cli.add_argument(
        "--ckpt_dir", type=str, default="ofa-large",
        help="OFA caption weights (caption_huge_best.pth)")
    cli.add_argument(
        "--caption_extension", type=str, default=".caption",
        help="extension of caption file")

    # Batching and data loading.
    cli.add_argument(
        "--batch_size", type=int, default=1, help="batch size in inference")
    cli.add_argument(
        "--max_data_loader_n_workers", type=int, default=None,
        help="enable image reading by DataLoader with this number of workers (faster)")

    # Generation options.
    cli.add_argument(
        "--num_beams", type=int, default=1, help="num of beams in beam search")
    cli.add_argument(
        "--temperature", type=float, default=1, help="temperature")
    cli.add_argument(
        "--max_length", type=int, default=16, help="max length of caption")
    cli.add_argument(
        "--min_length", type=int, default=5, help="min length of caption")
    cli.add_argument(
        '--seed', default=42, type=int, help='seed for reproducibility')
    cli.add_argument('--no_repeat_ngram_size', default=3, type=int, help='')

    # Runtime switches.
    cli.add_argument(
        "--fp16", action="store_true", help="inference with fp16")
    cli.add_argument("--debug", action="store_true", help="debug mode")
    return cli


if __name__ == '__main__':
    args = _build_arg_parser().parse_args()
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment