Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
Last active April 9, 2022 17:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krsnewwave/e180193f2007b5cefa248eff8987f6d7 to your computer and use it in GitHub Desktop.
Save krsnewwave/e180193f2007b5cefa248eff8987f6d7 to your computer and use it in GitHub Desktop.
class SquarePad:
    """Pad an image with a constant fill of 0 so that it becomes square.

    The shorter side is padded up to the length of the longer side, with the
    padding split as evenly as possible between the two opposite edges; the
    image content itself is not scaled.
    """

    @staticmethod
    def _padding_for(w, h):
        """Return ``(left, top, right, bottom)`` padding that makes a w x h image square."""
        side = max(w, h)
        extra_w = side - w
        extra_h = side - h
        left = extra_w // 2
        top = extra_h // 2
        # Give the odd leftover pixel (if any) to the right/bottom edge so the
        # output is exactly side x side; halving both edges with int() would
        # leave the image one pixel short whenever the difference is odd.
        return (left, top, extra_w - left, extra_h - top)

    def __call__(self, image):
        """Return *image* padded to a square.

        Args:
            image: an image whose ``size`` attribute is ``(width, height)``
                (e.g. a PIL image) and which
                ``torchvision.transforms.functional.pad`` accepts.

        Returns:
            The padded image, as produced by
            ``torchvision.transforms.functional.pad`` with fill 0 and
            constant padding mode.
        """
        w, h = image.size
        padding = self._padding_for(w, h)
        return torchvision.transforms.functional.pad(
            image, padding, fill=0, padding_mode='constant')
class ArtPeriodDataSet(Dataset):
    """Image dataset whose labels are art periods.

    Each row of *dataframe* describes one sample. Its ``"ID"`` column names
    the image file ``{img_dir}/{ID}.jpg`` and its ``"period"`` column holds
    the class label. Rows whose image file does not exist are dropped at
    construction time. The period labels are integer-encoded with a
    ``LabelEncoder``; ``classes`` holds the encoder's ``classes_``.

    Args:
        dataframe: table with at least ``"ID"`` and ``"period"`` columns. It
            is copied, never modified in place.
        img_dir: directory containing the ``.jpg`` images.
        transform: optional callable applied to each RGB image.
        target_transform: optional callable applied to each encoded label.
    """

    def __init__(self, dataframe, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

        # Eliminate non-existing files. Build a keep-mask in one pass and
        # filter once: dropping rows one at a time is quadratic, and
        # drop(label) removes *every* row sharing a duplicated index label.
        # Reading the "ID" column directly (rather than via iterrows) also
        # avoids row-wise dtype upcasting (e.g. 123 -> 123.0 in the path).
        print("Eliminate non-existing files")
        keep = []
        for image_id in tqdm(dataframe["ID"], total=len(dataframe)):
            path = self._image_path(image_id)
            exists = os.path.exists(path)
            if not exists:
                print("Dropping", path)
            keep.append(exists)
        # .loc with a boolean list filters rows (and handles an empty frame);
        # .copy() makes the column assignment below safe.
        self.dataframe = dataframe.loc[keep].copy()

        # Encode art periods as integer class ids.
        self.label_encoder = LabelEncoder()
        self.dataframe["period_encoded"] = self.label_encoder.fit_transform(
            self.dataframe["period"])
        self.classes = self.label_encoder.classes_

    def _image_path(self, image_id):
        """Return the path of the ``.jpg`` file for *image_id* under ``img_dir``."""
        return f'{self.img_dir}/{image_id}.jpg'

    def __len__(self):
        """Return the number of samples whose image file exists."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position *idx*.

        The image is loaded as RGB and passed through ``transform`` if set;
        the encoded label is passed through ``target_transform`` if set.
        """
        img_path = self._image_path(self.dataframe["ID"].iloc[idx])
        # Open inside a context manager so the file handle is released right
        # away; convert() returns a new, fully loaded image.
        with PIL.Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.dataframe["period_encoded"].iloc[idx]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment