Skip to content

Instantly share code, notes, and snippets.

@libfun
Last active July 18, 2016 21:20
Show Gist options
  • Save libfun/5e921008b9179f3ffd20f9e6e44f8b0d to your computer and use it in GitHub Desktop.
Save libfun/5e921008b9179f3ffd20f9e6e44f8b0d to your computer and use it in GitHub Desktop.
# Clothing categories; the loading loops build one metadata file name per
# category from these strings.
categories = [
    'bags', 'belts', 'dresses', 'eyewear',
    'footwear', 'hats', 'leggings', 'outerwear',
    'pants', 'skirts', 'tops',
]

# Not written to by the loading loops visible in this part of the file.
photodict = {}
retr_ids = []

# '<category>@<product id>' -> list of record lists. Item 0 collects the
# accepted pair records; a second item is appended later for accepted
# retrieval records of the same product.
train_ids = {}
test_ids = {}

# category -> parsed JSON content of the corresponding metadata file.
train_data = {}
test_data = {}
retr_data = {}
# Load the train and test pair lists of every category and group the records
# by product. A record is kept only if its photo exists under /data/photoset/
# and the file is larger than 5000 bytes (presumably to drop broken or
# placeholder downloads -- confirm).
#
# Afterwards train_ids / test_ids map '<category>@<product id>' to a
# one-element list whose item 0 is the list of accepted records.
for cat in categories:
    # Both splits are processed identically; only the input file prefix and
    # the destination containers differ.
    for pairs_prefix, split_data, split_ids in (
            ('meta/json/train_pairs_', train_data, train_ids),
            ('meta/json/test_pairs_', test_data, test_ids)):
        with open(pairs_prefix + cat + '.json') as fh:
            split_data[cat] = json.load(fh)

        for tr in tqdm(split_data[cat]):
            prod_id = tr['product']
            photo_id = tr['photo']

            # Photo files are named by the photo id left-padded with zeros
            # to 9 characters.
            photo_path = '/data/photoset/' + str(photo_id).zfill(9) + '.jpg'
            if not os.path.isfile(photo_path):
                continue
            if os.stat(photo_path).st_size <= 5000:
                continue

            product_key = cat + '@' + str(prod_id)
            # Test membership on the dict itself: a hash lookup, whereas
            # `.keys()` builds and scans a list on Python 2.
            if product_key in split_ids:
                split_ids[product_key][0].append(tr)
            else:
                split_ids[product_key] = [[tr]]
# Load the retrieval records of every category and attach each accepted record
# to the product entry already present in train_ids or test_ids. Records whose
# product is in neither mapping are ignored.
#
# Accepted records are collected in item 1 of the product's entry; that second
# list is created when the first retrieval record for the product is accepted.
for cat in categories:
    with open('meta/json/retrieval_' + cat + '.json') as fh:
        retr_data[cat] = json.load(fh)

    for tr in tqdm(retr_data[cat]):
        prod_id = tr['product']
        photo_id = tr['photo']
        product_key = cat + '@' + str(prod_id)

        # Train takes precedence: a product found in train_ids is never
        # looked up in test_ids, even if its photo is rejected below.
        # Membership is tested on the dicts directly (hash lookup) instead
        # of on `.keys()`, which is a list scan on Python 2.
        if product_key in train_ids:
            groups = train_ids[product_key]
        elif product_key in test_ids:
            groups = test_ids[product_key]
        else:
            continue

        # Same photo filter as for the pair records: the file must exist and
        # be larger than 5000 bytes.
        photo_path = '/data/photoset/' + str(photo_id).zfill(9) + '.jpg'
        if not os.path.isfile(photo_path):
            continue
        if os.stat(photo_path).st_size <= 5000:
            continue

        if len(groups) > 1:
            groups[1].append(tr)
        else:
            groups.append([tr])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment