Skip to content

Instantly share code, notes, and snippets.

@mitmul
Created June 19, 2018 16:10
Show Gist options
  • Save mitmul/32f6e9b000f1acb8aa118c41afec1a14 to your computer and use it in GitHub Desktop.
Save mitmul/32f6e9b000f1acb8aa118c41afec1a14 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import collections
import os
import shutil
import numpy as np
import pandas as pd
import tabulate
# Japanese catalogue strings -> ASCII identifiers written to the label table.
_PRODUCT_NAME_EN = {
    'ポテトチップス': 'PotatoChips',
}
_SIZE_EN = {
    'BIGBAG': 'BIGBAG',
    'Lサイズ': 'L_Size',
    'レギュラー': 'Regular',
}
_FLAVOR_EN = {
    'ウスシオ': 'LightSalt',
    'コンソメパンチ': 'Consomme',
    'キュウシュウショウユ': 'KyushuSoySauce',
}
# Value of the ``jan`` column (presumably the JAN barcode) -> class index.
_CLASS_ID = {
    4901330502911: 0,
    4901330502928: 1,
    4901330503284: 2,
    4901330523121: 3,
    4901330523176: 4,
    4901330523183: 5,
    4901330532734: 6,
    4901330532918: 7,
    4901330534516: 8,
}


def _parse_vertical_angle(token):
    """Convert the vertical-angle filename token to a signed integer.

    A token containing ``D`` yields a negative angle and one containing ``U``
    a positive angle.  A token with neither marker is converted with ``int``
    when it is numeric and returned unchanged otherwise, so an unexpected
    token never aborts the whole run.
    """
    if 'D' in token:
        return -int(token.replace('D', ''))
    if 'U' in token:
        return int(token.replace('U', ''))
    try:
        return int(token)
    except ValueError:
        return token


def create_dataset(original_excel_fn):
    """Build the label table from the file-information spreadsheet.

    Each spreadsheet row must hold exactly six values, in this order:
    file path, jan, product name, an ignored column, size, flavor.  The file
    name (without directory and extension) must consist of four
    ``_``-separated fields: ``<jan>_<shape>_<vertical>_<horizontal>``.

    Rows whose file-name prefix does not match the ``jan`` column are
    reported on stdout and skipped.

    Args:
        original_excel_fn: Path of the Excel file to read with
            ``pandas.read_excel``.  An already-loaded ``pandas.DataFrame``
            with the same column layout is also accepted and used as is.

    Returns:
        pandas.DataFrame: One row per accepted image with the columns
        ``filename``, ``jan``, ``class_id``, ``shape_id`` (kept as the string
        found in the file name), ``vertical_angle``, ``horizontal_angle``,
        ``product_name``, ``size`` and ``flavor``.

    Raises:
        KeyError: If a jan, product name, size or flavor is not present in
            the translation tables.
        ValueError: If a row or a file name does not have the expected number
            of fields, or an angle with a ``D``/``U`` marker or the
            horizontal angle is not an integer.
    """
    if isinstance(original_excel_fn, pd.DataFrame):
        info = original_excel_fn
    else:
        info = pd.read_excel(original_excel_fn)

    data = []
    for _, row in info.iterrows():
        fn, jan, product, _ignored, size, flavor = row
        fn = os.path.basename(fn)
        stem = os.path.splitext(fn)[0]
        head, shape, vertical_token, horizontal_token = stem.split('_')

        # Validate before converting anything else, so a row that is going
        # to be skipped cannot fail on its angle fields first.
        if str(head) != str(jan):
            print(head, '!=', jan)
            continue

        data.append({
            'filename': fn,
            'jan': jan,
            'class_id': _CLASS_ID[jan],
            'shape_id': shape,
            'vertical_angle': _parse_vertical_angle(vertical_token),
            'horizontal_angle': int(horizontal_token),
            'product_name': _PRODUCT_NAME_EN[product],
            'size': _SIZE_EN[size],
            'flavor': _FLAVOR_EN[flavor],
        })

    return pd.DataFrame(data)
def mkdir(dname):
    """Create directory ``dname`` and any missing parents.

    Does nothing when the directory already exists.  ``exist_ok=True`` is
    used instead of a separate existence check so that the call cannot fail
    if the directory is created between the check and the creation.

    Args:
        dname: Path of the directory to create.

    Raises:
        FileExistsError: If ``dname`` exists but is not a directory.
    """
    os.makedirs(dname, exist_ok=True)
def _print_class_balance(title, balance):
    """Print ``title`` followed by a class_id/frequency table of ``balance``."""
    print(title)
    print(tabulate.tabulate(
        balance.items(), headers=('class_id', 'frequency')))


def main():
    """Split the labelled images into train, valid and test sets.

    Images whose ``shape_id`` falls in a randomly chosen subset of shape ids
    go to the train set.  The remaining images are shuffled; the first
    ``--n-valid-examples`` of them form the valid set and the rest the test
    set.  For every split an ``images`` directory and a
    ``<split>_labels.txt`` file (one ``<filename> <class_id>`` pair per line)
    are written under ``--output-dir``, and the per-class counts are printed.

    Train images are copied unconditionally, so a missing train image raises.
    Missing valid/test images are skipped silently but are still counted in
    the printed class balance.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--n-shapes-for-train', type=int, default=27)
    parser.add_argument('--n-valid-examples', type=int, default=2000)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--output-dir', type=str, default='data')
    parser.add_argument('--image-dir', type=str, default='data/images')
    parser.add_argument('--original-excel-filename', type=str, default='hackathon data/ファイル情報.xlsx')
    args = parser.parse_args()

    # Parse the given excel file
    labels = create_dataset(args.original_excel_filename)

    # Choose the shape ids used for training.  This assumes the shape ids are
    # the integers 1..n_shapes; the order of the random calls is kept so a
    # given seed keeps producing the same split.
    np.random.seed(args.seed)
    n_shapes = len(np.unique(labels.loc[:, 'shape_id']))
    shape_ids = np.arange(1, n_shapes + 1)
    np.random.shuffle(shape_ids)
    # A set gives constant-time membership tests in the per-row loop below.
    shapes_train = set(shape_ids[:args.n_shapes_for_train].tolist())

    train_dir = os.path.join(args.output_dir, 'train')
    mkdir(train_dir)
    train_img_dir = os.path.join(train_dir, 'images')
    mkdir(train_img_dir)

    valtest_labels = []
    train_class_balance = collections.defaultdict(int)
    train_labels_fn = os.path.join(train_dir, 'train_labels.txt')
    # The context manager closes the file even if a copy below fails.
    with open(train_labels_fn, 'w') as fp_train_labels:
        for _, row in labels.iterrows():
            if int(row['shape_id']) in shapes_train:
                print('{} {}'.format(row['filename'], row['class_id']),
                      file=fp_train_labels)
                train_class_balance[row['class_id']] += 1
                img_fn = os.path.join(args.image_dir, row['filename'])
                shutil.copy(img_fn, train_img_dir)
            else:
                valtest_labels.append([row['filename'], row['class_id']])

    _print_class_balance('Train class balance:', train_class_balance)

    # Split valtest into val and test
    valid_dir = os.path.join(args.output_dir, 'valid')
    mkdir(valid_dir)
    valid_img_dir = os.path.join(valid_dir, 'images')
    mkdir(valid_img_dir)
    test_dir = os.path.join(args.output_dir, 'test')
    mkdir(test_dir)
    test_img_dir = os.path.join(test_dir, 'images')
    mkdir(test_img_dir)

    np.random.shuffle(valtest_labels)
    valid_class_balance = collections.defaultdict(int)
    test_class_balance = collections.defaultdict(int)
    valid_labels_fn = os.path.join(valid_dir, 'valid_labels.txt')
    test_labels_fn = os.path.join(test_dir, 'test_labels.txt')
    with open(valid_labels_fn, 'w') as fp_valid_labels, \
            open(test_labels_fn, 'w') as fp_test_labels:
        for i, (fn, class_id) in enumerate(valtest_labels):
            img_fn = os.path.join(args.image_dir, fn)
            if i < args.n_valid_examples:
                balance = valid_class_balance
                fp_labels = fp_valid_labels
                img_dir = valid_img_dir
            else:
                balance = test_class_balance
                fp_labels = fp_test_labels
                img_dir = test_img_dir
            # Counted before the existence check, so the printed balance
            # includes entries whose image file is missing.
            balance[class_id] += 1
            if os.path.exists(img_fn):
                print('{} {}'.format(fn, class_id), file=fp_labels)
                shutil.copy(img_fn, img_dir)

    _print_class_balance('\nValid class balance:', valid_class_balance)
    _print_class_balance('\nTest class balance:', test_class_balance)
# Run the dataset split only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment