Skip to content

Instantly share code, notes, and snippets.

@khrlimam
Forked from nunenuh/siamese_dataset_example.py
Last active April 4, 2019 08:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save khrlimam/ac0966b974cd1d00f4b1adbc17c4166a to your computer and use it in GitHub Desktop.
Save khrlimam/ac0966b974cd1d00f4b1adbc17c4166a to your computer and use it in GitHub Desktop.
import random
import torch
import pathlib
import os
from torch.utils import data
import PIL
import PIL.Image
from collections import defaultdict
from bisect import insort_right
class SiameseDataset(data.Dataset):
    """Image-pair dataset for Siamese training, built from a folder tree.

    Files matching ``<root>/<class_dir>/*.<ext>`` are grouped by the name of
    their parent directory. Every ordered pair of files inside one directory
    (a file paired with itself included) becomes a "similar" sample labelled
    ``0``; randomly sampled pairs whose two files come from different
    directories become "different" samples labelled ``1``.

    Args:
        root: Directory that contains one sub-directory per class.
        ext: File extension to collect, without the leading dot.
        transform: Optional callable applied to each image separately.
        pair_transform: Optional callable that takes ``(im1, im2)`` and
            returns the transformed pair.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, root, ext, transform=None, pair_transform=None, target_transform=None):
        super(SiameseDataset, self).__init__()
        self.transform = transform
        self.pair_transform = pair_transform
        self.target_transform = target_transform
        self.root = root
        self.base_path = pathlib.Path(root)
        self.files = sorted(self.base_path.glob("*/*." + ext))
        self.files_map = self._files_mapping()
        self.pair_files = self._pair_files()

    def __len__(self):
        """Return the number of pairs in the dataset."""
        return len(self.pair_files)

    def __getitem__(self, idx):
        """Return ``(im1, im2, sim)`` for the pair stored at position ``idx``.

        ``sim`` is ``0`` for a same-directory pair and ``1`` for a pair drawn
        from two different directories (before ``target_transform``).
        """
        (path1, path2), sim = self.pair_files[idx]
        im1 = PIL.Image.open(path1)
        im2 = PIL.Image.open(path2)
        if self.transform is not None:
            im1 = self.transform(im1)
            im2 = self.transform(im2)
        if self.pair_transform is not None:
            # The attribute set in __init__ is ``pair_transform``; calling a
            # differently named attribute here would raise AttributeError.
            im1, im2 = self.pair_transform(im1, im2)
        if self.target_transform is not None:
            sim = self.target_transform(sim)
        return im1, im2, sim

    def _files_mapping(self):
        """Map each class directory name to its sorted list of file names."""
        mapping = defaultdict(list)
        for path in self.files:
            insort_right(mapping[path.parent.name], path.name)
        return mapping

    def _similar_pair(self):
        """Build every ordered same-directory pair, labelled ``0``.

        Returns:
            A mapping ``dir_name -> [((rel_path_a, rel_path_b), 0), ...]``
            where the paths are relative to ``root``.
        """
        pairs = defaultdict(list)
        for key, names in self.files_map.items():
            rel_paths = [os.path.join(key, name) for name in names]
            pairs[key] = [((first, second), 0) for first in rel_paths for second in rel_paths]
        return pairs

    def _len_similar_pair(self):
        """Return ``dir_name -> number of similar pairs`` for every directory."""
        # _similar_pair() emits the full n x n product, so the count is n * n;
        # there is no need to materialise the pairs just to measure them.
        return {key: len(names) ** 2 for key, names in self.files_map.items()}

    def _diff_pair_dircomp(self):
        """Pair each directory name with the list of all other directory names."""
        dirnames = list(self.files_map)
        return [
            (current, dirnames[:idx] + dirnames[idx + 1:])
            for idx, current in enumerate(dirnames)
        ]

    def _different_pair(self):
        """Sample cross-directory pairs, labelled ``1``, for every directory.

        For each file and each other directory, up to ``max(1, n // 4)``
        partner files are drawn at random (``n`` is the number of files in the
        file's own directory). The pooled candidates of a directory are then
        reduced to at most ``n * n`` pairs so that they match the number of
        similar pairs whenever enough candidates exist.

        Returns:
            A mapping ``dir_name -> [((rel_path_a, rel_path_b), 1), ...]``.
            Directories without any other directory to pair with are absent.
        """
        fmap = self.files_map
        targets = self._len_similar_pair()

        # Phase 1: per (file, other directory) sub-sampling into a pool.
        pools = defaultdict(list)
        for primary_dir, other_dirs in self._diff_pair_dircomp():
            primary_names = fmap[primary_dir]
            # At least one partner per directory, otherwise classes with fewer
            # than four files would never contribute a different pair.
            per_dir = max(1, len(primary_names) // 4)
            for primary_name in primary_names:
                primary = os.path.join(primary_dir, primary_name)
                for other_dir in other_dirs:
                    candidates = [
                        ((primary, os.path.join(other_dir, other_name)), 1)
                        for other_name in fmap[other_dir]
                    ]
                    # random.sample raises ValueError when asked for more
                    # items than the population holds, so clamp the size.
                    take = min(per_dir, len(candidates))
                    pools[primary_dir].extend(random.sample(candidates, take))

        # Phase 2: shrink each pool to the number of similar pairs (or keep
        # the whole pool, shuffled, when it is smaller than that).
        sampled = defaultdict(list)
        for key, pool in pools.items():
            sampled[key] = random.sample(pool, min(targets[key], len(pool)))
        return sampled

    def _pair_files(self):
        """Build the final sample list with paths joined onto ``root``.

        Similar and different pairs of a directory are interleaved
        (similar, different, similar, ...). When a directory has fewer
        different pairs than similar pairs, the remaining similar pairs are
        appended on their own.

        Returns:
            A list of ``((path_a, path_b), label)`` tuples.
        """
        base_path = self.root
        sim_pair = self._similar_pair()
        diff_pair = self._different_pair()
        files_list = []
        for key in self.files_map:
            different = diff_pair.get(key, [])
            for pos, ((sim_a, sim_b), _) in enumerate(sim_pair[key]):
                files_list.append(
                    ((os.path.join(base_path, sim_a), os.path.join(base_path, sim_b)), 0)
                )
                if pos < len(different):
                    (diff_a, diff_b), _ = different[pos]
                    files_list.append(
                        ((os.path.join(base_path, diff_a), os.path.join(base_path, diff_b)), 1)
                    )
        return files_list
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment