Skip to content

Instantly share code, notes, and snippets.

@ayush714
Last active January 28, 2022 13:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ayush714/855c4a0000e6f33ee8efb867f2dc29b3 to your computer and use it in GitHub Desktop.
Save ayush714/855c4a0000e6f33ee8efb867f2dc29b3 to your computer and use it in GitHub Desktop.
import os
import glob
from ArgsHandler import parse_arguments
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision.io import read_image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from os.path import join as pjoin
# Directory that pjoin() is applied to when building the validation image
# paths in TestTinyImageNetDataset.load_annotations (relative to the CWD).
IMAGES_DIR = (
    "../data/tiny-imagenet-200/tiny-imagenet-200/val/images"
)
class TrainTinyImageNetDataset(Dataset):
    """Training images discovered with a glob, served as ``(image, label)``.

    Args:
        id: Mapping from class directory name (WordNet id) to integer label.
        transform: Optional callable applied once to the image tensor. It
            receives the tensor exactly as ``read_image`` returned it (after
            grayscale expansion); no dtype conversion is done here.
        file_pattern: Optional glob pattern for the image files. Defaults to
            ``DEFAULT_FILE_PATTERN`` so existing callers keep their behavior.
    """

    # Historical hard-coded location; kept as the default for compatibility.
    DEFAULT_FILE_PATTERN = (
        r"E:\Production Projects\ZenML\data\tiny-imagenet-200\tiny-imagenet-200\train\*\*\*.JPEG"
    )
    LOG_PATH = "../logs/DataIngestionLogs.txt"

    def __init__(self, id, transform=None, file_pattern=None):
        pattern = self.DEFAULT_FILE_PATTERN if file_pattern is None else file_pattern
        # Sorted so that index -> file is stable between runs.
        self.filenames = sorted(glob.glob(pattern))
        # Previously the caller's transform was silently thrown away.
        self.transform = transform
        self.id_dict = id
        # The logger API takes the open file on every call, so the handle is
        # kept for the lifetime of the dataset.
        self.file_object = open(self.LOG_PATH, "a+")
        # NOTE(review): CustomApplicationLogger is not imported in the visible
        # part of this file - confirm where it comes from.
        self.logger = CustomApplicationLogger()

    @staticmethod
    def _wnid_from_path(img_path):
        """Return the class directory name encoded in ``img_path``.

        Accepts both ``/`` and backslash separators. Assumes the layout
        implied by the glob pattern, ``.../CLASS/SUBDIR/FILE.JPEG`` (class
        directory two levels above the file) - TODO confirm against the data.

        Raises:
            IndexError: if the path has fewer than three components.
        """
        return img_path.replace("\\", "/").split("/")[-3]

    def __len__(self):
        """Return the number of image files matched by the glob."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image, label)`` where ``label`` is the integer from ``id``.

        Raises:
            Exception: any failure is logged and then re-raised, instead of
                silently returning ``None`` and failing later in collation.
        """
        img_path = self.filenames[index]
        try:
            image = read_image(img_path)
            if image.shape[0] == 1:
                # Replicate the single grayscale channel into three channels.
                image = torch.cat((image, image, image), 0)
            label = self.id_dict[self._wnid_from_path(img_path)]
            if self.transform:
                # Applied exactly once, to the image only. No float cast:
                # ToPILImage rescales floating-point input by 255.
                image = self.transform(image)
            return image, label
        except Exception as e:
            self.logger.log(
                self.file_object, f"Exception occured while getting data: {e}"
            )
            self.logger.log(self.file_object, f"Index: {index}")
            self.logger.log(self.file_object, f"Image name: {img_path}")
            raise
# Validation annotation file read by TestTinyImageNetDataset.load_annotations.
# Forward slashes are used (as in IMAGES_DIR) because they are accepted by
# open() on Windows as well as POSIX, whereas backslashes only work on Windows.
ANNOTATIONS_PATH = (
    "../data/tiny-imagenet-200/tiny-imagenet-200/val/val_annotations.txt"
)
class TestTinyImageNetDataset(Dataset):
    """Validation images listed in an annotation file, as ``(image, label)``.

    Args:
        id: Mapping from class identifier (second column of the annotation
            file) to integer label.
        transform: Optional callable applied once to the image tensor. It
            receives the tensor exactly as ``read_image`` returned it (after
            grayscale expansion); no dtype conversion is done here.
    """

    LOG_PATH = "../logs/DataIngestionLogs.txt"

    def __init__(self, id, transform=None):
        self.transform = transform
        self.id_dict = id
        self.cls_dic = {}
        # Fills self.cls_dic and self.filenames. (An earlier glob of the
        # validation tree was dropped: its result was always overwritten.)
        self.filenames = []
        self.load_annotations()
        self.file_object = open(self.LOG_PATH, "a+")
        # NOTE(review): CustomApplicationLogger is not imported in the visible
        # part of this file - confirm where it comes from.
        self.logger = CustomApplicationLogger()

    def load_annotations(self, annotations_path=None, images_dir=None):
        """Build ``cls_dic`` (file name -> label) and ``filenames``.

        Each useful line is tab-separated; only the first two columns (image
        file name, class identifier) are read. Lines with fewer than two
        columns are skipped.

        Args:
            annotations_path: File to parse; defaults to ``ANNOTATIONS_PATH``.
            images_dir: Directory joined to each file name; defaults to
                ``IMAGES_DIR``.

        Raises:
            KeyError: if a class identifier is missing from ``id``.
        """
        if annotations_path is None:
            annotations_path = ANNOTATIONS_PATH
        if images_dir is None:
            images_dir = IMAGES_DIR
        mapping = {}
        with open(annotations_path, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.rstrip("\r\n").split("\t")
                if len(fields) < 2:
                    continue
                mapping[fields[0]] = self.id_dict[fields[1]]
        self.cls_dic = mapping
        self.filenames = [pjoin(images_dir, img) for img in mapping]

    def __len__(self):
        """Return the number of annotated validation images."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image, label)`` where ``label`` comes from ``cls_dic``.

        Raises:
            Exception: any failure is logged and then re-raised.
        """
        img_path = self.filenames[index]
        try:
            image = read_image(img_path)
            if image.shape[0] == 1:
                # Replicate the single grayscale channel into three channels.
                image = torch.cat((image, image, image), 0)
            # basename() copes with whichever separator pjoin() inserted;
            # splitting on "/" alone failed for backslash-joined paths.
            label = self.cls_dic[os.path.basename(img_path)]
            if self.transform:
                # No float cast: ToPILImage rescales floating-point input
                # by 255.
                image = self.transform(image)
            return image, label
        except Exception as e:
            self.logger.log(
                self.file_object, f"Exception occured while getting data: {e}"
            )
            self.logger.log(self.file_object, f"Index: {index}")
            self.logger.log(self.file_object, f"Image name: {img_path}")
            raise
class DataLoaders:
    """Build the train/test datasets and hand out ``DataLoader`` objects.

    Args:
        wnids_path: Text file with one class identifier per line; the line
            order defines the integer labels. Defaults to
            ``DEFAULT_WNIDS_PATH`` so ``DataLoaders()`` behaves as before.
    """

    # Historical hard-coded location; kept as the default for compatibility.
    DEFAULT_WNIDS_PATH = (
        r"E:\Production Projects\ZenML\data\tiny-imagenet-200\tiny-imagenet-200\wnids.txt"
    )
    LOG_PATH = "../logs/DataIngestionLogs.txt"

    def __init__(self, wnids_path=None) -> None:
        self.file_object = open(self.LOG_PATH, "a+")
        # NOTE(review): CustomApplicationLogger is not imported in the visible
        # part of this file - confirm where it comes from.
        self.logger = CustomApplicationLogger()
        self.train_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((200, 200)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=45),
                transforms.ToTensor(),
            ]
        )
        # Same pipeline without the random augmentations.
        self.test_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((200, 200)),
                transforms.ToTensor(),
            ]
        )
        if wnids_path is None:
            wnids_path = self.DEFAULT_WNIDS_PATH
        self.id_dict = self._load_wnids(wnids_path)
        self.train_set = TrainTinyImageNetDataset(
            id=self.id_dict, transform=self.train_transform
        )
        self.test_set = TestTinyImageNetDataset(
            id=self.id_dict, transform=self.test_transform
        )

    @staticmethod
    def _load_wnids(path):
        """Read ``path`` and map each non-blank line to its 0-based position.

        The file is opened in a ``with`` block so the handle is closed
        (the previous inline ``open`` was never closed). Blank lines are
        ignored and do not consume a label.
        """
        with open(path, "r", encoding="utf-8") as f:
            names = [line.strip() for line in f]
        return {name: idx for idx, name in enumerate(n for n in names if n)}

    def get_train_loader(self, batch_size=150, shuffle=True):
        """Return a ``DataLoader`` over the training set.

        Args:
            batch_size: Samples per batch (default 150, as before).
            shuffle: Whether to reshuffle every epoch (default True, as before).
        """
        return DataLoader(self.train_set, batch_size=batch_size, shuffle=shuffle)

    def get_test_loader(self, batch_size=150, shuffle=True):
        """Return a ``DataLoader`` over the test set.

        Args:
            batch_size: Samples per batch (default 150, as before).
            shuffle: Default True only to preserve existing behavior; pass
                ``False`` for a fixed evaluation order.
        """
        return DataLoader(self.test_set, batch_size=batch_size, shuffle=shuffle)
if __name__ == "__main__":
    # Script entry point: construct the datasets and request the training
    # loader. Nothing is iterated here; this only exercises construction.
    data_loader = DataLoaders()
    train_loader = data_loader.get_train_loader()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment