Skip to content

Instantly share code, notes, and snippets.

@e96031413
Created July 30, 2021 04:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save e96031413/3af2e097ea42ad73603cfa9183b83e75 to your computer and use it in GitHub Desktop.
Save e96031413/3af2e097ea42ad73603cfa9183b83e75 to your computer and use it in GitHub Desktop.
import os
import glob
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
# Per-channel (R, G, B) normalization using the standard ImageNet
# statistics; `denorm` applies the inverse with the same constants.
normalize = transforms.Normalize(
    std=[0.229, 0.224, 0.225],
    mean=[0.485, 0.456, 0.406],
)
def denorm(tensor, device=None):
    """Invert the per-channel normalization and clamp to [0, 1].

    Computes ``tensor * std + mean`` channel-wise with the same mean/std
    used by ``normalize``, then clamps so the result is displayable.

    Args:
        tensor: Normalized image tensor. Mean/std are reshaped to (3, 1, 1),
            so the channel axis must be third-from-last, e.g. ``(3, H, W)``
            or ``(N, 3, H, W)``.
        device: Device on which to build the mean/std constants. Defaults to
            ``tensor.device`` so callers no longer have to pass it.

    Returns:
        Tensor of the broadcast shape with every value in [0, 1].
    """
    if device is None:
        device = tensor.device
    # Build the constants directly on the target device instead of
    # creating them on CPU and copying.
    std = torch.tensor([0.229, 0.224, 0.225], device=device).reshape(-1, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(-1, 1, 1)
    return torch.clamp(tensor * std + mean, 0, 1)
class PreprocessDataset(Dataset):
    """Dataset yielding paired (content, style) images (presumably for style transfer).

    Row ``i`` of ``content_dataframe`` is paired with row ``i`` of
    ``style_dataframe``. Each image is opened from the row's ``"file_path"``
    column, converted to RGB, resized so its shorter side equals
    ``image_size`` (aspect ratio preserved), and passed through ``transform``.

    Args:
        content_dataframe: DataFrame with a ``"file_path"`` column of content images.
        style_dataframe: DataFrame with a ``"file_path"`` column of style images.
        transform: Callable applied to each resized PIL image.
        image_size: Target length, in pixels, of each image's shorter side.
            Defaults to 512.
    """

    def __init__(self, content_dataframe, style_dataframe, transform, image_size=512):
        self.content_dataframe = content_dataframe
        self.style_dataframe = style_dataframe
        self.transform = transform
        self.image_size = image_size

    @staticmethod
    def _resize(image, size=512):
        """Return a resized copy of *image* whose shorter side is *size* pixels.

        The longer side is scaled by the same ratio and truncated to int.
        """
        # PIL reports size as (width, height).
        width, height = image.size
        if width < height:
            ratio = height / width
            new_width = size
            new_height = int(ratio * size)
        else:
            ratio = width / height
            new_height = size
            new_width = int(ratio * size)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        # Return the resized image (the original discarded it).
        return image.resize((new_width, new_height), Image.LANCZOS)

    def _load(self, path):
        """Open *path* as an RGB PIL image resized to ``self.image_size``."""
        # The context manager closes the file handle; convert() loads the
        # pixel data and returns an independent image, so it survives the close.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self._resize(rgb, self.image_size)

    def __len__(self):
        """Return the number of pairs, i.e. the length of the shorter frame."""
        # __getitem__ indexes both frames, so indices past the shorter
        # frame's end would raise IndexError.
        return min(len(self.content_dataframe), len(self.style_dataframe))

    def __getitem__(self, index):
        """Return the transformed ``(content_image, style_image)`` pair at *index*."""
        content_path = self.content_dataframe.iloc[index]["file_path"]
        style_path = self.style_dataframe.iloc[index]["file_path"]
        content_image = self._load(content_path)
        style_image = self._load(style_path)
        return self.transform(content_image), self.transform(style_image)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.