Last active
October 22, 2015 13:30
-
-
Save lazykyama/f586419cd72d5312288e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8-unix -*- | |
import os | |
import logging | |
import re | |
from PIL import Image | |
INFILENAME_PATTERN=re.compile(r'([^_]+)_img.png') | |
def load_rawdata(indirs):
    """Load character images named '<charname>_img.png' from the given dirs.

    Args:
        indirs: iterable of directory names to scan (non-matching files
            are skipped).

    Returns:
        (raw_data_list, classid_name_mapping):
            raw_data_list - list of [pixel_rows, class_id] pairs, where
                pixel_rows is a height-long list of width-long pixel rows;
            classid_name_mapping - dict mapping class id -> character name.
    """
    import glob
    raw_data_list = []
    classid_name_mapping = {}
    name_to_class_id = {}  # reverse lookup so a name keeps one class id
    for indir in indirs:
        files = glob.glob(os.path.join(indir, '*'))
        logging.debug(files)
        for fn in files:
            matched_list = INFILENAME_PATTERN.findall(fn)
            if not matched_list:
                continue
            char_name = matched_list[-1]
            # BUG FIX: reuse the existing class id when the same character
            # name appears again; the original allocated a fresh id per
            # file, inflating the class count.
            if char_name in name_to_class_id:
                class_id = name_to_class_id[char_name]
            else:
                class_id = len(classid_name_mapping)
                name_to_class_id[char_name] = class_id
                classid_name_mapping[class_id] = char_name
            with Image.open(fn) as img:
                # Image.size is documented as (width, height); getdata()
                # yields pixels flattened row-major.
                width, height = img.size
                raw_data = list(img.getdata())
            # BUG FIX: the flat pixel list holds `height` rows of `width`
            # pixels each; the original iterated range(width), which is
            # only correct for square images.
            data = [raw_data[width * i:width * (i + 1)]
                    for i in range(height)]
            raw_data_list.append([data, class_id])
    return raw_data_list, classid_name_mapping
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8-unix -*- | |
import logging | |
import numpy as np | |
import chainer | |
import chainer.functions as F | |
import chainer.optimizers | |
class Network(object):
    """Convolutional single-character classifier.

    NOTE(review): built on the pre-1.5 chainer API (FunctionSet,
    optimizer.zero_grads, dropout train flag) — confirm the installed
    chainer version still supports it.
    """

    def __init__(self, id_mapping, class_num):
        # class id -> character name; kept only for get_charname().
        self.__id_mapping = id_mapping
        # Two conv stages followed by a fully-connected layer and a
        # class_num-way output layer.
        self.__model = chainer.FunctionSet(
            conv1=F.Convolution2D(1, 16, 5),
            conv2=F.Convolution2D(16, 16, 5),
            l3=F.Linear(784, 784),
            softmax4=F.Linear(784, class_num))
        self.__optimizer = chainer.optimizers.Adam()
        self.__optimizer.setup(self.__model)

    def train(self, source, teach, epoch_max=15, batch_size=50):
        """Train on shuffled minibatches.

        Args:
            source: float32 array of input images.
            teach: int32 label array, same length as source.
            epoch_max: number of passes over the data.
            batch_size: minibatch size.

        Returns:
            (loss, acc): per-epoch mean loss and accuracy, each averaged
            over all epochs.
        """
        n_samples = len(teach)
        loss_total = 0.0
        acc_total = 0.0
        for epoch in range(epoch_max):
            logging.info('epoch {}.'.format(epoch))
            order = np.random.permutation(n_samples)
            loss_epoch = 0.0
            acc_epoch = 0.0
            for offset in range(0, n_samples, batch_size):
                chosen = order[offset:offset + batch_size]
                x_batch = source[chosen]
                t_batch = teach[chosen]
                self.__optimizer.zero_grads()
                y = self.__forward(x_batch)
                t = chainer.Variable(t_batch)
                loss = F.softmax_cross_entropy(y, t)
                acc = F.accuracy(y, t)
                loss.backward()
                self.__optimizer.update()
                # Weight each minibatch by its actual size: the last one
                # may be smaller than batch_size.
                n_batch = len(t_batch)
                loss_epoch += float(chainer.cuda.to_cpu(loss.data)) * n_batch
                acc_epoch += float(chainer.cuda.to_cpu(acc.data)) * n_batch
            logging.info('mean loss: {}, mean acc: {}'.format(
                loss_epoch / n_samples, acc_epoch / n_samples))
            loss_total += loss_epoch / n_samples
            acc_total += acc_epoch / n_samples
        return loss_total / epoch_max, acc_total / epoch_max

    def predict(self, data):
        """Return the most probable class id for each input row."""
        probabilities = F.softmax(self.__forward(data, train=False)).data
        return np.argmax(probabilities, axis=1)

    def get_charname(self, char_class_id):
        """Map a predicted class id back to its character name."""
        return self.__id_mapping[char_class_id]

    def __forward(self, source, train=True):
        """Forward pass; dropout is active only while training."""
        x = chainer.Variable(source)
        h = F.max_pooling_2d(F.relu(self.__model.conv1(x)), ksize=2, stride=2)
        h = F.max_pooling_2d(F.relu(self.__model.conv2(h)), ksize=3, stride=4)
        h = F.dropout(F.relu(self.__model.l3(h)), train=train)
        return self.__model.softmax4(h)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8-unix -*- | |
"""tests single character classifier. | |
Usage: | |
test.py [--model MODEL] [-v] <INDIR> | |
test.py -h|--help | |
test.py --version | |
Options: | |
<INDIR> test dataset dirname. | |
--model MODEL trained model filename [default: ./model.pickle]. | |
-h --help show this help message. | |
--version show this script version. | |
-v --verbose logging level [default: False]. | |
""" | |
import os | |
import sys | |
import logging | |
import docopt | |
import schema | |
import numpy as np | |
import dataloader | |
import network | |
def test_model(data, model): | |
source, teach = list(zip(*data)) | |
source = np.array(source, dtype=np.float32) | |
logging.info('#data: {}'.format(len(teach))) | |
answers = model.predict(source) | |
return answers, teach | |
def __convert_data(rawdata):
    """Min-max scale each image into [0, 1] and wrap it for the network.

    Args:
        rawdata: iterable of [img, class_id] pairs; img is a 2-D nested
            list/array of raw pixel values.

    Returns:
        list of [np.float32 array shaped (1, H, W), class_id] pairs.
    """
    data = []
    for img, class_id in rawdata:
        # Cast up front so the normalization below is true division even
        # for integer pixel values.
        img = np.asarray(img, dtype=np.float32)
        min_val = np.min(img)
        val_range = np.max(img) - min_val
        if val_range == 0:
            # Constant image: avoid division by zero; result is all zeros.
            val_range = 1
        # BUG FIX: subtract (not add) the minimum — the original
        # `(img + min_val) / val_range` is not min-max normalization.
        scaled_img = (img - min_val) / val_range
        data.append(
            [np.array([scaled_img], dtype=np.float32),
             class_id])
    return data
def __init_args():
    """Parse and validate CLI arguments, then configure logging.

    Returns:
        the validated docopt argument dict; '--model' is replaced by an
        open binary file handle.

    Exits:
        status 1 on schema validation failure.
    """
    args = docopt.docopt(__doc__, version='0.0.1')
    try:
        s = schema.Schema({
            '<INDIR>': os.path.exists,
            # BUG FIX: open the model file in binary mode — the handle is
            # later passed to pickle.load, which requires binary input on
            # Python 3.
            '--model': schema.Use(lambda fn: open(fn, 'rb')),
            '--help': bool,
            '--version': bool,
            '--verbose': bool
        })
        args = s.validate(args)
    except schema.SchemaError as e:
        # BUG FIX: write() needs a string; passing the exception object
        # raised TypeError instead of printing the message.
        sys.stderr.write(str(e) + '\n')
        sys.exit(1)
    if args['--verbose']:
        # debug.
        logging.basicConfig(level=logging.DEBUG,
            format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s')
    else:
        # release.
        logging.basicConfig(level=logging.INFO,
            format='%(asctime)s [%(levelname)s] %(message)s')
    logging.debug('{}'.format(args))
    return args
def __same_name(ans, correct):
    """Return True when *ans*, lowercased, ends with *correct*."""
    lowered = ans.lower()
    return lowered.endswith(correct)
def __main():
    """Load a pickled model, classify the test dataset, and log accuracy."""
    import pickle
    args = __init_args()
    raw_data, id_mapping = dataloader.load_rawdata([args['<INDIR>']])
    data = __convert_data(raw_data)
    # NOTE(review): pickle.load executes arbitrary code from the file;
    # only load model files you produced yourself.
    model = pickle.load(args['--model'])
    answers, charnames = test_model(data, model)
    results = list(zip(answers, charnames))
    correct_count = 0
    for predicted_id, true_id in results:
        ans_name = model.get_charname(predicted_id)
        correct_name = id_mapping[true_id]
        is_right = __same_name(ans_name, correct_name)
        if is_right:
            correct_count += 1
        logging.info('correct: {}, answer: {} => {}'.format(
            correct_name, ans_name, 'RIGHT' if is_right else 'WRONG'))
    logging.info('test accuracy: {} ({} / {})'.format(
        (correct_count / float(len(results))), correct_count, len(results)))
    return True
if __name__ == '__main__':
    # Exit 0 on success, 1 otherwise.
    sys.exit(0 if __main() else 1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8-unix -*- | |
"""trains single character classifier. | |
Usage: | |
train.py [-o OUTPUT] [--epoch EPOCH] [--batch-size SIZE] [-v] <INDIR>... | |
train.py -h|--help | |
train.py --version | |
Options: | |
<INDIR...> input dirname. | |
-o --output OUTPUT output filename [default: ./model.pickle]. | |
--epoch EPOCH epoch size [default: 15]. | |
--batch-size SIZE mini-batch size [default: 50]. | |
-h --help show this help message. | |
--version show this script version. | |
-v --verbose logging level [default: False]. | |
""" | |
import os | |
import sys | |
import logging | |
import docopt | |
import schema | |
import numpy as np | |
import dataloader | |
import network | |
def train_model(data, id_mapping,
                epoch=15, batch_size=50):
    """Build a Network and train it on the prepared dataset.

    Args:
        data: iterable of (image, class_id) pairs.
        id_mapping: dict of class id -> character name.
        epoch: number of training epochs.
        batch_size: minibatch size.

    Returns:
        the trained network.Network instance.
    """
    images, labels = zip(*data)
    images = np.array(images, dtype=np.float32)
    labels = np.array(labels, dtype=np.int32)
    logging.info('#data: {}'.format(len(labels)))
    logging.info('#class: {}'.format(len(id_mapping)))
    nn = network.Network(id_mapping, len(id_mapping))
    loss, acc = nn.train(images, labels,
                         epoch_max=epoch, batch_size=batch_size)
    logging.info(
        'average loss: {}, accuracy: {}'.format(loss, acc))
    return nn
def __boost_trainingdata(raw_data, noise_loop=30):
    """Augment images with Gaussian noise and min-max scale them to [0, 1].

    Args:
        raw_data: iterable of [img, class_id] pairs; img is a 2-D nested
            list/array of raw pixel values.
        noise_loop: number of noisy copies generated per source image.

    Returns:
        list of [np.float32 array shaped (1, H, W), class_id] pairs.
    """
    data = []
    for img, class_id in raw_data:
        img = np.asarray(img, dtype=np.float32)
        for _ in range(noise_loop):
            noise = np.random.normal(0, 1, img.shape) * 32
            noised_img = img + noise
            min_val = np.min(noised_img)
            val_range = np.max(noised_img) - min_val
            if val_range == 0:
                # Constant image: avoid division by zero.
                val_range = 1
            # BUG FIX: subtract (not add) the minimum — the original
            # `(noised_img + min_val) / val_range` is not min-max scaling.
            noised_img = (noised_img - min_val) / val_range
            data.append(
                [np.array([noised_img], dtype=np.float32),
                 class_id])
    return data
def __init_args():
    """Parse and validate CLI arguments, then configure logging.

    Returns:
        the validated docopt argument dict ('--epoch' and '--batch-size'
        are converted to int).

    Exits:
        status 1 on schema validation failure.
    """
    args = docopt.docopt(__doc__, version='0.0.1')
    try:
        s = schema.Schema({
            '<INDIR>': [os.path.exists],
            '--output': str,
            '--epoch': schema.And(schema.Use(int), lambda e: e >= 0,
                error='negative epoch is invalid: {}'.format(args['--epoch'])),
            '--batch-size': schema.And(schema.Use(int), lambda b: b >= 0,
                error='negative mini-batch size is invalid: {}'.format(
                    args['--batch-size'])),
            '--help': bool,
            '--version': bool,
            '--verbose': bool
        })
        args = s.validate(args)
    except schema.SchemaError as e:
        # BUG FIX: write() needs a string; passing the exception object
        # raised TypeError instead of printing the message.
        sys.stderr.write(str(e) + '\n')
        sys.exit(1)
    if args['--verbose']:
        # debug.
        logging.basicConfig(level=logging.DEBUG,
            format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s')
    else:
        # release.
        logging.basicConfig(level=logging.INFO,
            format='%(asctime)s [%(levelname)s] %(message)s')
    logging.debug('{}'.format(args))
    return args
def __main():
    """Train a classifier on the input dirs and pickle it to --output."""
    import pickle
    args = __init_args()
    raw_data, id_mapping = dataloader.load_rawdata(args['<INDIR>'])
    training_data = __boost_trainingdata(raw_data)
    model = train_model(training_data, id_mapping,
                        epoch=args['--epoch'],
                        batch_size=args['--batch-size'])
    # Highest pickle protocol (-1) for a compact model file.
    with open(args['--output'], 'wb') as f:
        pickle.dump(model, f, -1)
    return True
if __name__ == '__main__':
    # Exit 0 on success, 1 otherwise.
    sys.exit(0 if __main() else 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment