# Gist by @afiaka87 (afiaka87/493d75b2b7edbadfb9f69aee6152a2da), created April 4, 2022.
# Generates BLIP captions for every image in a folder and appends each caption
# to a .txt file alongside the image.
import pathlib
from csv import writer
import torch
import tqdm
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from data import create_loader
from data.utils import save_result
from models.blip import blip_decoder
# Side length, in pixels, of the square images: used as the resize target in
# load_image and passed as `image_size=` to blip_decoder.
image_size = 384
# Checkpoint location passed as `pretrained=` to blip_decoder.
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
# TODO: set the dataset path and batch size before running. The __main__
# section uses `dataset_path` and `batch_size`; example values are kept
# commented out below.
# dataset_path = "/home/afiaka87/datasets/pokemon_256_resize"
# dataset_path = "/home/afiaka87/datasets/artstation_384px"
# batch_size = 256
# nucleus_sampling = True
# print(f"dataset_path: {dataset_path}")
def load_image(image_path, image_size, device):
    """Load an image file and turn it into a normalized tensor.

    The image is opened with PIL, converted to RGB, resized to
    ``image_size`` x ``image_size`` with bicubic interpolation, converted to a
    tensor and normalized per channel.

    Args:
        image_path: Path of the image file to open.
        image_size: Side length, in pixels, of the square output image.
        device: Unused; kept so existing callers keep working. The returned
            tensor is not moved to any device here.

    Returns:
        The transformed image tensor, or ``None`` if anything went wrong
        while loading or transforming the image.
    """
    # Loading is deliberately best-effort: a single unreadable file must not
    # abort a long captioning run, so every failure is reported and mapped to
    # ``None`` for the caller to handle.
    try:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size),
                              interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # Per-channel (mean, std) used for normalization.
            transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)),
        ])
        # The context manager releases the underlying file handle promptly.
        with Image.open(image_path) as raw_image:
            return transform(raw_image.convert('RGB'))
    except Exception as e:
        # Include the reason so broken files can actually be diagnosed.
        print(f"Error loading image {image_path}: {e!r}")
        return None
def get_image_files(base_path):
    """Recursively collect the image files found under ``base_path``.

    Files are recognized by extension (``.png``, ``.jpg``, ``.jpeg``,
    ``.bmp``), compared case-insensitively so that e.g. ``photo.JPG`` is
    included. Directories whose name happens to end in one of these
    extensions are skipped.

    Args:
        base_path: Root directory to search (string or path-like).

    Returns:
        A sorted list of ``pathlib.Path`` objects, one per image file. The
        list is empty when nothing matches.
    """
    image_suffixes = {".png", ".jpg", ".jpeg", ".bmp"}
    # A single traversal replaces one recursive glob per extension; sorting
    # makes the dataset order reproducible between runs.
    return sorted(
        path
        for path in pathlib.Path(base_path).rglob("*")
        if path.suffix.lower() in image_suffixes and path.is_file()
    )
class TextImageDataset(Dataset):
    """Dataset of preprocessed images found recursively under a folder.

    Each item is a dict with the image produced by ``load_image`` (optionally
    passed through ``transform``) under ``"image"`` and the path of the file
    it came from, as a string, under ``"image_path"``.
    """

    def __init__(
        self,
        folder="",
        image_size=384,
        transform=None,
        device=None,
    ):
        """Index the image files under ``folder``.

        Args:
            folder: Root directory searched recursively for images.
            image_size: Side length forwarded to ``load_image``.
            transform: Optional callable applied to each loaded image.
            device: Forwarded to ``load_image``.
        """
        self.image_size = image_size
        self.transform = transform
        self.device = device
        self.image_files = get_image_files(folder)
        print(f"Found {len(self.image_files)} images")

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Return the item at ``index``, skipping unreadable images.

        If the requested image cannot be loaded, the following images are
        tried in turn, wrapping around to the start of the dataset, so an
        unreadable file is replaced by the next readable one (which is
        therefore returned for more than one index).

        Raises:
            IndexError: If ``index`` is outside the dataset.
            RuntimeError: If no image in the dataset could be loaded.
        """
        total = len(self.image_files)
        if not -total <= index < total:
            raise IndexError(
                f"index {index} out of range for dataset of size {total}")
        # Bounded, wrap-around search: unlike recursing on ``index + 1`` this
        # cannot run past the last file or recurse without limit.
        for offset in range(total):
            path = self.image_files[(index + offset) % total]
            image = load_image(path, self.image_size, self.device)
            if image is None:
                continue
            if self.transform is not None:
                image = self.transform(image)
            return {"image": image, "image_path": str(path)}
        raise RuntimeError("None of the images in the dataset could be loaded")
if __name__ == "__main__":
    # Caption every image under the dataset path once per prompt and append
    # each caption to a .txt file stored next to the image.
    import argparse

    # Settings may come from the command line; values defined at module level
    # by editing this file (dataset_path, batch_size, prompts) serve as the
    # defaults so that workflow keeps working.
    module_config = globals()
    parser = argparse.ArgumentParser(
        description="Generate BLIP captions for a folder of images.")
    parser.add_argument("--dataset_path",
                        default=module_config.get("dataset_path"),
                        help="root folder searched recursively for images")
    parser.add_argument("--batch_size", type=int,
                        default=module_config.get("batch_size", 256),
                        help="number of images captioned per batch")
    parser.add_argument("--prompts", nargs="+",
                        default=module_config.get("prompts"),
                        help="one or more caption prefixes; each image gets "
                             "one caption per prompt")
    parser.add_argument("--device", default="cuda:7",
                        help="torch device string to run the model on")
    args = parser.parse_args()
    # Fail early with a usage message instead of a NameError mid-run.
    if args.dataset_path is None:
        parser.error("--dataset_path is required")
    if not args.prompts:
        parser.error("at least one prompt must be given via --prompts")

    device = torch.device(args.device)
    dataset = TextImageDataset(
        args.dataset_path, image_size=image_size, device=device)
    dataloader = create_loader(
        datasets=[dataset], samplers=[None], batch_size=[args.batch_size],
        num_workers=[8], is_trains=[False], collate_fns=[None])[0]

    # Generate multiple captions per image using different starting words.
    for prompt in tqdm.tqdm(args.prompts):
        tqdm.tqdm.write(f"Processing prompt: {prompt}")
        # The prompt is a constructor argument, so the model is rebuilt for
        # each prompt.
        model = blip_decoder(pretrained=model_url, image_size=image_size,
                             prompt=prompt, vit='base').to(device).eval()
        for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
            images = batch["image"].to(device)
            with torch.no_grad():
                captions = model.generate(
                    images, sample=True, max_length=80, min_length=10,
                    top_p=0.9, repetition_penalty=1.1)
            for caption, image_path in zip(captions, batch["image_path"]):
                # Replace only the final extension; splitting on the first
                # "." would truncate paths that contain dots elsewhere.
                caption_file = pathlib.Path(image_path).with_suffix(".txt")
                tqdm.tqdm.write(f"{prompt}{caption}")
                with open(caption_file, "a", encoding="utf-8") as f:
                    f.write(f"{prompt}{caption}\n")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment