Created June 2, 2017 08:03
-
-
Save ytbilly3636/8bdcf7683b06b7fc75db66d6b2679261 to your computer and use it in GitHub Desktop.
caffe2用 (for Caffe2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""Write the CIFAR-10 training split into a Caffe2 DB.

Each record is a serialized ``caffe2_pb2.TensorProtos`` message holding two
tensors: the image array and its label (cast to float32). Records are stored
under the unique key ``train_<index>``.
"""
# Needed to build a Caffe2 DB.
import numpy as np
from StringIO import StringIO
from caffe2.python import core, utils, workspace
from caffe2.proto import caffe2_pb2
# Borrow the CIFAR-10 dataset shipped with Chainer.
from chainer import datasets

# NOTE: only the training split is written below; `test` is loaded but unused.
train, test = datasets.get_cifar10()

# Settings
DB_TYPE = "minidb"

# The DB itself.
db_train = core.C.create_db(DB_TYPE, 'cifar_train.' + DB_TYPE, core.C.Mode.write)
trans = db_train.new_transaction()
for i in xrange(len(train)):
    # Fetch the (image, label) pair once rather than indexing the dataset twice.
    image, label = train[i]
    # Wrap the NumPy values as Caffe2 tensors inside a TensorProtos message.
    # The author observed that numpy.int32 (the label's type) did not seem to
    # be supported, so the label is cast to numpy.float32.
    image_and_label = caffe2_pb2.TensorProtos()
    image_and_label.protos.extend([
        utils.NumpyArrayToCaffe2Tensor(image),
        utils.NumpyArrayToCaffe2Tensor(label.astype(np.float32)),
    ])
    # Serialize the message and store it under a per-sample key.
    # '%' is required here: str.format() ignores a '%d' placeholder, which
    # previously gave every record the same literal key 'train_%d'.
    trans.put('train_%d' % i, image_and_label.SerializeToString())
# Drop the last references so the transaction and DB objects are destroyed
# right away under CPython. Presumably their destructors flush/commit pending
# writes and close the file; TODO confirm for the chosen DB_TYPE.
del trans
del db_train
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.