Skip to content

Instantly share code, notes, and snippets.

@ytbilly3636
Created June 2, 2017 08:03
Show Gist options
  • Save ytbilly3636/8bdcf7683b06b7fc75db66d6b2679261 to your computer and use it in GitHub Desktop.
Save ytbilly3636/8bdcf7683b06b7fc75db66d6b2679261 to your computer and use it in GitHub Desktop.
caffe2用
# -*- coding: utf-8 -*-
# caffe2のDB作成に必要
import numpy as np
from StringIO import StringIO
from caffe2.python import core, utils, workspace
from caffe2.proto import caffe2_pb2
# 今回はchainerのデータセットを拝借する
from chainer import datasets
train, test = datasets.get_cifar10()
# Settings
# Caffe2 DB backend; it is also used as the output file extension below.
DB_TYPE = "minidb"


def write_db(dataset, db_name, key_prefix, db_type=DB_TYPE):
    """Write every (image, label) pair of *dataset* into a new Caffe2 DB.

    Each record is a ``caffe2_pb2.TensorProtos`` holding two tensors, the
    image followed by the label (converted to float32), serialized to a
    string and stored under the key ``'<key_prefix>_<zero-padded index>'``.

    Args:
        dataset: sequence supporting ``len()`` and integer indexing whose
            items unpack as ``(image, label)``; the label must support
            ``.astype`` (e.g. a NumPy scalar, as in Chainer's CIFAR-10).
        db_name: path of the DB to create (opened in write mode).
        key_prefix: prefix for each record key, e.g. ``'train'``.
        db_type: Caffe2 DB backend name, e.g. ``"minidb"``.
    """
    db = core.C.create_db(db_type, db_name, core.C.Mode.write)
    trans = db.new_transaction()
    n = len(dataset)
    # Zero-pad indices so lexicographic key order (as used by sorted
    # backends such as leveldb/lmdb) matches insertion order.
    width = len(str(max(n - 1, 0)))
    for i in range(n):
        image, label = dataset[i]
        # Pack the image and label into one TensorProtos record.
        record = caffe2_pb2.TensorProtos()
        record.protos.extend([
            utils.NumpyArrayToCaffe2Tensor(image),
            # The original author reports numpy.int32 labels as unsupported
            # by NumpyArrayToCaffe2Tensor, hence the float32 conversion.
            # NOTE(review): confirm against the installed Caffe2 version.
            utils.NumpyArrayToCaffe2Tensor(label.astype(np.float32)),
        ])
        # %-formatting is required here: the original used str.format with
        # a '%d' placeholder, which left every key as the literal 'train_%d'.
        key = '%s_%0*d' % (key_prefix, width, i)
        trans.put(key, record.SerializeToString())
    # No explicit commit is issued; like the original script, this relies on
    # dropping the transaction and then the DB handle to finish the write.
    # Keep this order.
    del trans
    del db


# Convert the training split. The test split can be written the same way,
# e.g. write_db(test, 'cifar_test.' + DB_TYPE, 'test').
write_db(train, 'cifar_train.' + DB_TYPE, 'train')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment