Skip to content

Instantly share code, notes, and snippets.

@abhishekkrthakur
Last active June 11, 2019 19:52
Show Gist options
  • Save abhishekkrthakur/1e8c4621500b0ad809f5ce0f502ecc96 to your computer and use it in GitHub Desktop.
Save abhishekkrthakur/1e8c4621500b0ad809f5ce0f502ecc96 to your computer and use it in GitHub Desktop.
finetuning_collections_dataset
import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
class CollectionsDataset(Dataset):
    """Multi-label image dataset indexed by a CSV file.

    Each CSV row describes one image:

    * ``id`` -- image file stem; the image is read from
      ``<root_dir>/<id>.png``.
    * ``attribute_ids`` -- space-separated integer class indices, e.g.
      ``"13 405"``. A blank cell yields an all-zero target.

    Items are dicts ``{'image': ..., 'labels': ...}`` where ``labels`` is a
    multi-hot tensor (``torch.zeros`` default dtype) of length
    ``num_classes``.

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.
        root_dir: Directory containing the ``.png`` images.
        num_classes: Length of the multi-hot label vector. Every index in
            ``attribute_ids`` must lie in ``[0, num_classes)``, otherwise
            indexing the label tensor raises ``IndexError``.
        transform: Optional callable applied to the loaded PIL image.
        image_mode: Optional PIL mode (e.g. ``'RGB'``) every image is
            converted to before ``transform``; ``None`` keeps each file's
            own mode.
    """

    def __init__(self,
                 csv_file,
                 root_dir,
                 num_classes,
                 transform=None,
                 image_mode=None):
        # Force both columns to str: otherwise pandas may infer purely
        # numeric ids as integers (dropping leading zeros and breaking the
        # filename construction) and single-label rows as integers (which
        # have no ``.split()``).
        self.data = pd.read_csv(csv_file,
                                dtype={'id': str, 'attribute_ids': str})
        self.root_dir = root_dir
        self.transform = transform
        self.num_classes = num_classes
        self.image_mode = image_mode

    def __len__(self):
        """Return the number of rows (images) in the CSV."""
        return len(self.data)

    def _labels_to_tensor(self, attribute_ids):
        """Multi-hot encode a space-separated string of class indices."""
        label_tensor = torch.zeros(self.num_classes)
        # read_csv turns a blank cell into NaN: treat it as "no labels".
        if pd.isna(attribute_ids):
            return label_tensor
        for class_idx in str(attribute_ids).split():
            label_tensor[int(class_idx)] = 1
        return label_tensor

    def __getitem__(self, idx):
        """Load image ``idx`` and its multi-hot label vector.

        ``idx`` is looked up by index label via ``DataFrame.loc``. The frame
        comes straight from ``read_csv`` (default ``RangeIndex``), so labels
        coincide with positions ``0 .. len(self) - 1``.
        """
        img_name = os.path.join(self.root_dir,
                                f"{self.data.loc[idx, 'id']}.png")
        # Image.open is lazy and keeps the file open until the pixels are
        # read. Loading inside ``with`` releases the handle here instead of
        # whenever the image object happens to be garbage-collected.
        with Image.open(img_name) as img:
            img.load()
            image = img.convert(self.image_mode) if self.image_mode else img

        label_tensor = self._labels_to_tensor(
            self.data.loc[idx, 'attribute_ids'])

        if self.transform:
            image = self.transform(image)

        return {'image': image,
                'labels': label_tensor
                }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment