Skip to content

Instantly share code, notes, and snippets.

View thigm85's full-sized avatar

Thiago G. Martins thigm85

View GitHub Profile
# NOTE(review): these notebook cells appear out of execution order in this file:
# `image_dataset` is used on the next line before it is assigned below, and
# `os` / `ImageDataset` are imported/defined further down. Run in dependency order.
# Shape of the first dataset item; `.shape` assumes items are array-like, i.e. the
# dataset was built with `transform=preprocess` (presumably CLIP preprocessing — confirm).
next(iter(image_dataset)).shape
# Dataset over the directory named by the IMG_DIR environment variable
# (os.environ[...] raises KeyError if it is unset); items go through `preprocess`.
image_dataset = ImageDataset(
img_dir=os.environ["IMG_DIR"],
transform=preprocess
)
# Peek at the first item of the dataset.
next(iter(image_dataset))
# Same directory, but with no transform applied to the items.
image_dataset = ImageDataset(img_dir=os.environ["IMG_DIR"])
import os
import glob
from torch.utils.data import Dataset
class ImageDataset(Dataset):
    """Dataset over the ``*.jpg`` files found directly inside ``img_dir``.

    Args:
        img_dir: Directory searched (non-recursively) for ``*.jpg`` files.
        transform: Optional callable, stored as ``self.transform``
            (presumably applied to each loaded image — confirm).

    Attributes:
        image_file_names: Sorted list of matching paths (each joined with
            ``img_dir``), so a given index always refers to the same file.

    NOTE(review): only ``__init__`` is visible here. ``torch.utils.data.Dataset``
    subclasses need ``__getitem__`` (and usually ``__len__``) for
    ``next(iter(dataset))`` and ``DataLoader`` to work — confirm they exist.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        # glob's result order depends on the filesystem; sort it so that
        # index -> file (and hence index -> embedding) is stable across runs.
        self.image_file_names = sorted(
            glob.glob(os.path.join(img_dir, "*.jpg"))
        )
        self.transform = transform
# Encode one preprocessed image to check the size of a single embedding.
# The input gets a leading batch dimension and is moved to `device`, like the
# batched `image_input`. Otherwise a CPU tensor meets a GPU model when CUDA is used.
# The output is cast to float32 to match the batch-encoding cell.
with torch.no_grad():
    image_features = model.encode_image(
        processed_images[0].unsqueeze(0).to(device)
    ).float()
image_features.shape
# Embed the whole preprocessed batch without recording autograd history,
# converting the embeddings to 32-bit floats.
with torch.no_grad():
    image_features = model.encode_image(image_input).to(torch.float32)
image_features.size()
import numpy as np
import torch
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Stack the per-image tensors into one batch (adds a leading batch dimension)
# directly in torch. This avoids the NumPy round-trip and its extra copy, and it
# also works when the tensors already live on a GPU.
image_input = torch.stack(processed_images).to(device)
image_input.shape
from torchvision.transforms import ToPILImage
# Show every processed image (tensor -> PIL conversion for plotting), then
# run each one through step index 4 of the `preprocess` pipeline
# (presumably its Normalize step — confirm against `preprocess.transforms`).
plot_pil_images(list(map(ToPILImage(), processed_images)))
processed_images = [
    preprocess.transforms[4](img)
    for img in processed_images
]