Skip to content

Instantly share code, notes, and snippets.

@A2va
Last active April 9, 2023 11:33
Show Gist options
  • Save A2va/fd96f5d1cdf7972de22483bab1995a04 to your computer and use it in GitHub Desktop.
Save A2va/fd96f5d1cdf7972de22483bab1995a04 to your computer and use it in GitHub Desktop.
import argparse
import glob
import os
import random
from urllib.parse import urlparse
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import OFAModel, OFATokenizer
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
DEVICE = torch.device('cuda' if use_cuda else 'cpu')

# File suffixes (lower-case, with leading dot) that are treated as images.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]

# Edge length, in pixels, of the square every image is resized to.
IMAGE_SIZE = 480

# Per-channel statistics handed to Normalize, i.e. (x - 0.5) / 0.5.
mean = [0.5] * 3
std = [0.5] * 3

# Preprocessing pipeline: force RGB, bicubic resize to IMAGE_SIZE x IMAGE_SIZE,
# convert to a tensor, then normalize each channel with `mean` / `std`.
IMAGE_TRANSFORM = transforms.Compose([
    lambda image: image.convert("RGB"),
    transforms.Resize(
        (IMAGE_SIZE, IMAGE_SIZE),
        interpolation=InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
    """Dataset that loads image files and applies ``IMAGE_TRANSFORM``.

    Each item is a ``(tensor, path)`` tuple, or ``None`` when the file could
    not be opened or transformed, so that a collate function can drop it
    instead of aborting the whole run.
    """

    def __init__(self, image_paths):
        """Remember the image paths to serve.

        Args:
            image_paths: Indexable collection of image file paths.
        """
        self.images = image_paths

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and preprocess the image at position ``idx``.

        Returns:
            ``(tensor, img_path)`` on success, ``None`` if loading or
            transforming the image raised an exception.
        """
        img_path = self.images[idx]
        try:
            # Image.open is lazy and keeps the file open; the context manager
            # releases the handle deterministically (relevant for multi-frame
            # files, which are not closed automatically after loading).
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
            # The image is converted to a tensor here so the DataLoader can
            # transport it between processes.
            tensor = IMAGE_TRANSFORM(image)
        except Exception as e:
            # Deliberately broad: one unreadable file must not stop the run.
            print(
                f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
            return None
        return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
    """Drop failed samples from a batch.

    The dataset signals an example it could not load by yielding ``None``.
    This collate function discards those entries and hands back the
    remaining samples, in their original order, as a plain list (no tensor
    stacking is performed).
    """
    return [sample for sample in batch if sample is not None]
# Cast FP32 tensors to FP16; every other dtype passes through unchanged.
def apply_half(t):
    """Return ``t`` as a half-precision tensor if it is float32.

    Tensors of any other dtype are returned as the very same object.
    """
    if t.dtype is not torch.float32:
        return t
    return t.to(dtype=torch.half)
def glob_images(directory, base="*"):
    """List image files in ``directory`` for every suffix in IMAGE_EXTENSIONS.

    With the default ``base="*"`` only the directory part is glob-escaped, so
    the ``*`` stays a wildcard and matches any file name. With any other
    ``base`` the complete path is escaped, so ``base + ext`` is matched
    literally.

    The result is returned in the order produced by ``glob.glob`` for each
    extension in turn; it is neither de-duplicated nor sorted.
    """
    match_any_name = base == '*'
    img_paths = []
    for ext in IMAGE_EXTENSIONS:
        filename = base + ext
        if match_any_name:
            pattern = os.path.join(glob.escape(directory), filename)
        else:
            pattern = glob.escape(os.path.join(directory, filename))
        img_paths += glob.glob(pattern)
    return img_paths
def main(args):
    """Caption every image in ``args.train_data_dir`` with an OFA model.

    Images are discovered with ``glob_images``; for each one a caption is
    generated and written next to the image, using the image path without
    its extension plus ``args.caption_extension``.

    The model and its inputs are placed on ``DEVICE``. Half precision is only
    enabled when ``args.fp16`` is set *and* ``DEVICE`` is a CUDA device.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``fp16``,
            ``ckpt_dir``, ``train_data_dir``, ``caption_extension``,
            ``batch_size``, ``max_data_loader_n_workers``,
            ``no_repeat_ngram_size`` and ``debug``.

    NOTE(review): ``num_beams`` stays fixed at 5. The parser default for
    ``--num_beams`` is 1, so forwarding it would silently change the captions
    produced by a default invocation. ``--temperature``, ``--max_length`` and
    ``--min_length`` are likewise not forwarded to ``generate``.
    """
    # Fix the seeds for reproducibility.
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Only honour --fp16 on CUDA so that a CPU run keeps working in fp32.
    use_fp16 = bool(args.fp16) and DEVICE.type == "cuda"
    if args.fp16 and not use_fp16:
        print("fp16 requested but CUDA is not available; running in fp32.")

    tokenizer = OFATokenizer.from_pretrained(args.ckpt_dir)
    txt = " what does the image describe?"
    inputs = tokenizer([txt], return_tensors="pt").input_ids.to(DEVICE)

    model = OFAModel.from_pretrained(args.ckpt_dir, use_cache=True)
    if use_fp16:
        model.half()
    model.to(DEVICE)
    model.eval()

    print(f"load images from {args.train_data_dir}")
    image_paths = glob_images(args.train_data_dir)
    print(f"found {len(image_paths)} images.")

    def run_batch(path_imgs):
        """Caption each ``(image_path, tensor)`` pair and write its file.

        Images are still sent through the model one at a time; the batch
        only controls how many are accumulated before captioning starts.
        """
        for image_path, patch_img in path_imgs:
            patch_img = patch_img.unsqueeze(0).to(DEVICE)
            if use_fp16:
                patch_img = apply_half(patch_img)
            with torch.no_grad():
                gen = model.generate(
                    inputs,
                    patch_images=patch_img,
                    num_beams=5,
                    no_repeat_ngram_size=args.no_repeat_ngram_size,
                )
            caption = tokenizer.batch_decode(gen, skip_special_tokens=True)
            caption_path = os.path.splitext(image_path)[0] + args.caption_extension
            with open(caption_path, "wt", encoding='utf-8') as f:
                f.write(caption[0] + "\n")
            if args.debug:
                print(image_path, caption)

    if args.max_data_loader_n_workers is not None:
        # Worker processes load and preprocess the images.
        dataset = ImageLoadingTransformDataset(image_paths)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.max_data_loader_n_workers,
            collate_fn=collate_fn_remove_corrupted,
            drop_last=False,
        )
    else:
        # No DataLoader: one-element "batches" whose tensor is loaded lazily
        # in the loop below.
        loader = [[(None, ip)] for ip in image_paths]

    b_imgs = []
    for batch in tqdm(loader, smoothing=0.0):
        for sample in batch:
            if sample is None:
                continue
            img_tensor, image_path = sample
            if img_tensor is None:
                try:
                    # Context manager so the file handle is always released.
                    with Image.open(image_path) as raw_image:
                        if raw_image.mode == 'RGB':
                            rgb_image = raw_image
                        else:
                            rgb_image = raw_image.convert("RGB")
                        img_tensor = IMAGE_TRANSFORM(rgb_image)
                except Exception as e:
                    # Deliberately broad: skip unreadable files, keep going.
                    print(
                        f"Could not load image path: {image_path}, error: {e}")
                    continue
            b_imgs.append((image_path, img_tensor))
            if len(b_imgs) >= args.batch_size:
                run_batch(b_imgs)
                b_imgs.clear()

    # Flush whatever is left over after the last full batch.
    if b_imgs:
        run_batch(b_imgs)
    print("done!")
if __name__ == '__main__':
    # Command-line entry point: build the parser, parse argv, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str,
                        help="directory for train images")
    # The value is handed to OFATokenizer/OFAModel.from_pretrained, so it is
    # a checkpoint directory rather than a single weights file.
    parser.add_argument("--ckpt_dir", type=str, default="ofa-large",
                        help="directory of the pretrained OFA checkpoint (loaded with from_pretrained)")
    parser.add_argument("--caption_extension", type=str, default=".caption",
                        help="extension of caption file")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="batch size in inference")
    parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
                        help="enable image reading by DataLoader with this number of workers (faster)")
    # NOTE: main() does not read --num_beams, --temperature, --max_length or
    # --min_length; they are accepted but have no effect on generation.
    parser.add_argument("--num_beams", type=int, default=1,
                        help="num of beams in beam search")
    parser.add_argument("--temperature", type=float,
                        default=1, help="temperature")
    parser.add_argument("--max_length", type=int, default=16,
                        help="max length of caption")
    parser.add_argument("--min_length", type=int, default=5,
                        help="min length of caption")
    parser.add_argument('--seed', default=42, type=int,
                        help='seed for reproducibility')
    parser.add_argument('--no_repeat_ngram_size', default=3, type=int,
                        help='size of n-grams that must not repeat within a generated caption')
    parser.add_argument("--fp16", action="store_true",
                        help="inference with fp16")
    parser.add_argument("--debug", action="store_true", help="debug mode")

    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment