Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
import matplotlib
# matplotlib setting for running on a server (non-interactive Agg backend)
matplotlib.use('Agg')
import matplotlib.pyplot as plot
import chainer
from chainercv.datasets import voc_detection_label_names
from chainercv.links import SSD512
from chainercv import utils
from chainercv.visualizations import vis_bbox
def get_model():
    """Return the SSD512 detector, using a pickle cache next to this file.

    If ``sd512model.pkl`` exists in the directory containing this script,
    the object stored in it is unpickled and returned. Otherwise an
    ``SSD512`` is built with the ``'voc0712'`` pretrained weights, pickled
    to that path for later runs, and returned.

    Returns:
        The cached object, or the newly constructed ``SSD512`` instance.
    """
    import os
    import pickle
    import tempfile

    cache_dir = os.path.dirname(os.path.abspath(__file__))
    save_file = os.path.join(cache_dir, "sd512model.pkl")

    if os.path.exists(save_file):
        # NOTE(security): unpickling can execute arbitrary code; this cache
        # file must only ever be one written by this script.
        with open(save_file, 'rb') as f:
            return pickle.load(f)

    model = SSD512(
        n_fg_class=len(voc_detection_label_names),
        pretrained_model='voc0712')

    # Write to a temporary file in the same directory and rename it into
    # place, so an interrupted dump never leaves a truncated cache file
    # that would make every later run fail inside pickle.load().
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, suffix='.pkl.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            pickle.dump(model, f)
        os.replace(tmp_path, save_file)
    finally:
        # After a successful replace the temp path no longer exists; this
        # only removes leftovers when dumping or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return model
def main():
    """Detect objects in ``images/sample.jpg`` and save the plot to ``out.jpg``.

    Obtains the detector from ``get_model()``, moves it to GPU device 0,
    runs prediction on the single sample image, draws the predicted boxes
    with the VOC label names via ``vis_bbox``, and saves the current
    matplotlib figure.
    """
    detector = get_model()

    # Use the GPU
    gpu = chainer.cuda.get_device(0)
    gpu.use()
    detector.to_gpu()

    image = utils.read_image('images/sample.jpg', color=True)

    # predict() takes a batch (list) of images; only one image is passed,
    # so the first entry of each returned sequence is the one we want.
    all_bboxes, all_labels, all_scores = detector.predict([image])
    first_bbox = all_bboxes[0]
    first_label = all_labels[0]
    first_score = all_scores[0]

    vis_bbox(
        image,
        first_bbox,
        first_label,
        first_score,
        label_names=voc_detection_label_names)
    plot.savefig('out.jpg')
# Run the detection demo only when this file is executed as a script.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment