Instantly share code, notes, and snippets.

Embed
What would you like to do?
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
import cv2, random
from io import BytesIO
from captcha.image import ImageCaptcha
class OCRBatch(object):
    """Lightweight data batch holding named input arrays and label arrays.

    ``provide_data`` / ``provide_label`` report a list of ``(name, shape)``
    pairs, built by zipping each name list with its array list.
    """

    def __init__(self, data_names, data, label_names, label):
        # Names and arrays are stored exactly as given; nothing is copied.
        self.data_names = data_names
        self.label_names = label_names
        self.data = data
        self.label = label

    @staticmethod
    def _name_shape_pairs(names, arrays):
        """Pair every name with the ``.shape`` of the array at the same position."""
        pairs = []
        for name, array in zip(names, arrays):
            pairs.append((name, array.shape))
        return pairs

    @property
    def provide_data(self):
        return self._name_shape_pairs(self.data_names, self.data)

    @property
    def provide_label(self):
        return self._name_shape_pairs(self.label_names, self.label)
def gen_rand():
    """Return a random integer in [0, 9999] as a zero-padded 4-character string."""
    return "%04d" % random.randint(0, 9999)
def get_label(buf):
    """Convert a string of digit characters into a NumPy array of ints."""
    return np.array(list(map(int, buf)))
class OCRIter(mx.io.DataIter):
    """MXNet data iterator that synthesizes captcha images on the fly.

    Each batch holds ``batch_size`` images rendered by ``ImageCaptcha`` from
    random strings produced by ``gen_rand()``. The images are resized to
    ``(height, width)``, scaled to [0, 1] and laid out as
    (batch, 3, height, width). Labels come from ``get_label`` on the same
    strings.

    Args:
        count: total number of samples per pass; ``count // batch_size``
            full batches are yielded and any remainder is dropped.
        batch_size: samples per batch.
        num_label: label width advertised in ``provide_label``
            (NOTE(review): ``gen_rand`` always yields 4 digits, so this
            should be 4 unless that helper changes).
        height, width: target image size after resizing.
    """

    def __init__(self, count, batch_size, num_label, height, width):
        super(OCRIter, self).__init__()
        self.captcha = ImageCaptcha(fonts=['./data/OpenSans-Regular.ttf'])
        self.batch_size = batch_size
        self.count = count
        self.height = height
        self.width = width
        self.provide_data = [('data', (batch_size, 3, height, width))]
        self.provide_label = [('softmax_label', (self.batch_size, num_label))]

    def __iter__(self):
        # Integer division: under Python 3 `/` yields a float and range() raises TypeError.
        for _ in range(self.count // self.batch_size):
            data = []
            label = []
            for i in range(self.batch_size):
                num = gen_rand()
                img = self.captcha.generate(num)
                # frombuffer replaces the deprecated binary mode of np.fromstring.
                img = np.frombuffer(img.getvalue(), dtype='uint8')
                img = cv2.imdecode(img, cv2.IMREAD_COLOR)
                img = cv2.resize(img, (self.width, self.height))
                # Debug dump: overwrites ./tmp0.png .. ./tmp9.png on every batch.
                cv2.imwrite("./tmp" + str(i % 10) + ".png", img)
                img = np.multiply(img, 1/255.0)
                # HWC (as decoded by OpenCV) -> CHW, matching provide_data.
                img = img.transpose(2, 0, 1)
                data.append(img)
                label.append(get_label(num))
            data_all = [mx.nd.array(data)]
            label_all = [mx.nd.array(label)]
            data_names = ['data']
            label_names = ['softmax_label']
            yield OCRBatch(data_names, data_all, label_names, label_all)

    def reset(self):
        # Nothing to rewind: every pass generates fresh random samples.
        pass
def get_ocrnet(num_label=4, num_classes=10):
    """Build the captcha OCR network symbol.

    Three conv -> pool -> ReLU stages feed a 512-unit fully connected layer.
    From it branch ``num_label`` independent ``num_classes``-way heads, one
    per character position. The heads are concatenated along axis 0, so
    output row ``j * batch + s`` is position ``j`` of sample ``s``. The
    (batch, num_label) label is transposed and flattened into the same
    position-major order before ``SoftmaxOutput``.

    Args:
        num_label: number of characters per captcha (default 4).
        num_classes: classes per character (default 10, the digits 0-9).

    Returns:
        The ``SoftmaxOutput`` symbol named "softmax".
    """
    data = mx.symbol.Variable('data')
    label = mx.symbol.Variable('softmax_label')
    conv1 = mx.symbol.Convolution(data=data, kernel=(5, 5), num_filter=32)
    pool1 = mx.symbol.Pooling(data=conv1, pool_type="max", kernel=(2, 2), stride=(1, 1))
    relu1 = mx.symbol.Activation(data=pool1, act_type="relu")
    conv2 = mx.symbol.Convolution(data=relu1, kernel=(5, 5), num_filter=32)
    pool2 = mx.symbol.Pooling(data=conv2, pool_type="avg", kernel=(2, 2), stride=(1, 1))
    relu2 = mx.symbol.Activation(data=pool2, act_type="relu")
    conv3 = mx.symbol.Convolution(data=relu2, kernel=(3, 3), num_filter=32)
    pool3 = mx.symbol.Pooling(data=conv3, pool_type="avg", kernel=(2, 2), stride=(1, 1))
    relu3 = mx.symbol.Activation(data=pool3, act_type="relu")
    flatten = mx.symbol.Flatten(data=relu3)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=512)
    # One classifier head per character position; created in order so the
    # auto-generated layer names match the original hand-unrolled version.
    heads = [mx.symbol.FullyConnected(data=fc1, num_hidden=num_classes)
             for _ in range(num_label)]
    fc2 = mx.symbol.Concat(*heads, dim=0)
    # (batch, num_label) -> (num_label, batch) -> flat, matching fc2's row order.
    label = mx.symbol.transpose(data=label)
    # NOTE(review): target_shape is a legacy Reshape argument; newer MXNet
    # spells this shape=(-1,). Kept as-is for the MXNet version this targets.
    label = mx.symbol.Reshape(data=label, target_shape=(0, ))
    return mx.symbol.SoftmaxOutput(data=fc2, label=label, name="softmax")
def Accuracy(label, pred, num_label=4):
    """Return the fraction of captchas whose every character is predicted correctly.

    ``label`` has shape (batch, num_label). ``pred`` has shape
    (num_label * batch, num_classes) in the position-major order produced by
    ``get_ocrnet`` (``Concat(dim=0)`` of the per-position heads): row
    ``j * batch + s`` is position ``j`` of sample ``s``. A sample counts as a
    hit only if all ``num_label`` positions are correct.

    Args:
        label: ground-truth labels, shape (batch, num_label).
        pred: class scores, shape (num_label * batch, num_classes).
        num_label: characters per captcha (default 4).

    Returns:
        Accuracy in [0, 1]; 0.0 when there are no complete samples.
    """
    flat_label = np.asarray(label).T.reshape((-1, ))
    pred = np.asarray(pred)
    # Integer division: under Python 3 `/` yields a float, which range/reshape reject.
    num_samples = pred.shape[0] // num_label
    if num_samples == 0:
        return 0.0
    used = num_samples * num_label
    # Reshaping position-major rows to (num_label, num_samples) puts each
    # sample's characters in one column. The old loop instead grouped 4
    # consecutive rows, which belong to 4 *different* samples.
    predicted = np.argmax(pred[:used], axis=1).reshape((num_label, num_samples))
    # astype(int) truncates like the original int(label[k]).
    expected = flat_label[:used].astype(int).reshape((num_label, num_samples))
    return float(np.mean(np.all(predicted == expected, axis=0)))
# Training entry point: build the network, the synthetic captcha iterators,
# and fit on GPU 0, logging progress at DEBUG level.
network = get_ocrnet()
devs = [mx.gpu(0)]
model = mx.model.FeedForward(ctx=devs,
                             symbol=network,
                             num_epoch=15,
                             learning_rate=0.001,
                             wd=0.00001,
                             initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
                             momentum=0.9)
# Shared so the iterators and the Speedometer agree. The Speedometer
# previously used 32 while batches held 50 samples, misreporting samples/sec.
batch_size = 50
num_label = 4
data_train = OCRIter(100000, batch_size, num_label, 30, 80)
data_test = OCRIter(1000, batch_size, num_label, 30, 80)
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
model.fit(X=data_train, eval_data=data_test, eval_metric=Accuracy,
          batch_end_callback=mx.callback.Speedometer(batch_size, 50))
@zimenglan-sysu-512

This comment has been minimized.

zimenglan-sysu-512 commented May 22, 2016

Can you show a picture of your network?

@thirdwing

This comment has been minimized.

thirdwing commented Oct 15, 2016

It should be something like this:

[Image: ocr_3 — network diagram, not preserved in this copy]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment