Created
April 5, 2017 13:37
-
-
Save reyoung/21ecaa4c7bca9943352a40d0ce59f9bc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function

import gzip

import paddle.v2 as paddle
def decode_hsigmoid_label(bits, num_classes):
    """Decode one row of thresholded h-sigmoid node outputs into a label.

    The nodes are walked as an implicit binary tree stored in array order:
    starting at node 0, a true bit moves to node ``2 * node + 2`` and a false
    bit moves to node ``2 * node + 1``. Each visited bit is appended to a
    code that starts at 1; the walk ends once the node index runs past the
    end of ``bits``.

    NOTE(review): this presumably mirrors the class coding used by Paddle's
    hsigmoid layer (class ``c`` coded as ``c + num_classes``) -- confirm
    against the Paddle implementation.

    Args:
        bits: Indexable sequence of truthy/falsy values, one per tree node
            (``num_classes - 1`` entries in the training script below).
        num_classes: Number of classes; subtracted from the accumulated
            code to obtain the label.

    Returns:
        The decoded integer label.
    """
    node = 0
    code = 1
    node_count = len(bits)
    while node < node_count:
        bit = 1 if bits[node] else 0
        code = (code << 1) | bit
        # bit == 1 -> child 2 * node + 2, bit == 0 -> child 2 * node + 1
        node = node * 2 + 1 + bit
    return code - num_classes


def main():
    """Train an MNIST classifier with an h-sigmoid cost and report test error.

    After every pass the parameters are written to ``params.tar.gz`` and the
    test set is run through a sigmoid layer that shares the h-sigmoid weights
    and bias; its thresholded outputs are decoded into labels to compute the
    error rate, which is printed.
    """
    paddle.init(use_gpu=False, trainer_count=1)

    # define network topology
    images = paddle.layer.data(
        name='pixel', type=paddle.data_type.dense_vector(784))
    label = paddle.layer.data(
        name='label', type=paddle.data_type.integer_value(10))
    hidden = paddle.layer.fc(input=images, size=200)
    num_classes = 10
    cost = paddle.layer.hsigmoid(
        input=hidden,
        label=label,
        num_classes=num_classes,
        param_attr=paddle.attr.Param(name='sigmoid_w'),
        bias_attr=paddle.attr.Param(name='sigmoid_b'))

    # Inference-side layer: reuses the parameters named 'sigmoid_w' and
    # 'sigmoid_b' from the cost layer and emits one sigmoid output per
    # internal tree node (num_classes - 1 of them).
    with paddle.layer.mixed(
            size=num_classes - 1,
            act=paddle.activation.Sigmoid(),
            bias_attr=paddle.attr.Param(name='sigmoid_b')) as prediction:
        prediction += paddle.layer.trans_full_matrix_projection(
            input=hidden, param_attr=paddle.attr.Param(name='sigmoid_w'))

    parameters = paddle.parameters.create([cost, prediction])

    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.1 / 128.0,
        momentum=0.9,
        regularization=paddle.optimizer.L2Regularization(rate=0.0005 * 128))

    trainer = paddle.trainer.SGD(cost=cost,
                                 parameters=parameters,
                                 update_equation=optimizer)

    # Materialise the test set once so every pass evaluates the same data.
    test_set = list(paddle.dataset.mnist.test()())
    test_imgs = [(x[0],) for x in test_set]
    test_lbls = [x[1] for x in test_set]

    def event_handler(event):
        """Checkpoint parameters and print the test error at each EndPass."""
        if not isinstance(event, paddle.event.EndPass):
            return

        with gzip.open('params.tar.gz', 'w') as f:
            parameters.to_tar(f)

        rst = paddle.infer(output_layer=prediction,
                           parameters=parameters,
                           input=test_imgs)
        # Binarise the per-node sigmoid outputs into left/right decisions.
        rst = rst > 0.5

        # very naive decode h-sigmoid tree. If the performance is bad,
        # we could move it to C/C++ part of Paddle.
        correct_count = sum(
            1 for bits, lbl in zip(rst, test_lbls)
            if decode_hsigmoid_label(bits, num_classes) == lbl)

        error_rate = (1 - float(correct_count) / rst.shape[0]) * 100
        # Single-argument call form: valid on both Python 2 and Python 3.
        print("Pass %d, Test Error Rate %.2f%%" % (event.pass_id, error_rate))

    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.mnist.train(), buf_size=8192),
            batch_size=128),
        event_handler=event_handler,
        num_passes=100)
# Run the training script only when this file is executed directly.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment