Skip to content

Instantly share code, notes, and snippets.

@avijit9
Created February 14, 2018 05:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save avijit9/772aef4eb030c0a3cd20f0bf35b26843 to your computer and use it in GitHub Desktop.
Save avijit9/772aef4eb030c0a3cd20f0bf35b26843 to your computer and use it in GitHub Desktop.
class LandMarkRecognition(Dataset):
    """Image/label dataset for landmark recognition.

    The samples are the ``*.jpg`` files directly inside ``root_dir``, in
    sorted file-name order. Each image's label is looked up in ``csv_file``:
    the image's file name without its extension is matched against the CSV's
    ``id`` column, and the row's ``landmark_id`` value is returned.
    """

    def __init__(self, root_dir, csv_file, transform=None):
        """
        Args:
            root_dir: Directory containing the ``.jpg`` images.
            csv_file: Path to a CSV with at least ``id`` and ``landmark_id``
                columns.
            transform: Optional callable applied to the RGB PIL image in
                ``__getitem__``; when falsy the PIL image is returned as is.
        """
        self.landmarks_csv = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.image_names = [
            name
            for name in sorted(os.listdir(self.root_dir))
            if name.endswith('.jpg')
        ]
        # Build the id -> landmark_id mapping once so each __getitem__ is an
        # O(1) dict lookup instead of a scan over the whole CSV. setdefault
        # keeps the FIRST row for a duplicated id (first-match semantics).
        self._label_by_id = {}
        ids = self.landmarks_csv['id']
        labels = self.landmarks_csv['landmark_id']
        for image_id, landmark_id in zip(ids, labels):
            self._label_by_id.setdefault(str(image_id), landmark_id)

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``root_dir``."""
        return len(self.image_names)

    def _image_id(self, idx):
        """Return the CSV ``id`` of image ``idx``: its file name minus extension."""
        # Split the bare file name, not the joined path, so dots in root_dir
        # (e.g. './data') cannot truncate the id.
        return os.path.splitext(self.image_names[idx])[0]

    def _label_for(self, image_id):
        """Return the ``landmark_id`` recorded for ``image_id`` as a float.

        Raises:
            KeyError: if ``image_id`` has no row in the CSV. KeyError rather
                than IndexError, because sequence-style iteration over a
                Dataset treats IndexError as "end of data".
        """
        try:
            label = self._label_by_id[image_id]
        except KeyError:
            raise KeyError(
                "image id %r not found in the 'id' column of the landmarks CSV"
                % image_id
            ) from None
        return float(label)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th image.

        ``image`` is the image converted to RGB and passed through
        ``transform`` when one was given. ``label`` is the image's
        ``landmark_id`` as a float.
        """
        img_name = os.path.join(self.root_dir, self.image_names[idx])
        # The context manager releases the file handle; convert() returns a
        # new, fully loaded image, so it stays valid after the file closes.
        with Image.open(img_name) as raw:
            image = raw.convert('RGB')
        landmarks = self._label_for(self._image_id(idx))
        # Without a transform, return the untransformed PIL image.
        img = self.transform(image) if self.transform else image
        return img, landmarks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment