Skip to content

Instantly share code, notes, and snippets.

@haoliplus
Created May 13, 2016 07:01
Show Gist options
  • Save haoliplus/59e036341344e7e7e4163573cbf1e087 to your computer and use it in GitHub Desktop.
Save haoliplus/59e036341344e7e7e4163573cbf1e087 to your computer and use it in GitHub Desktop.
Using a Caffe model to make predictions
#! /usr/bin/env python
#################################################################################
# File Name : test-one.py
# Created By : Hao Li
# Creation Date : [2016-04-08 13:44]
# Last Modified : [2016-04-26 22:24]
# Description :
#################################################################################
import os
import sys
import numpy
sys.path.insert(0, "path-to-caffe-python-lib")
import caffe
class CaffeModel(object):
    """Wrap a trained Caffe classifier and run predictions on image files.

    The network definition, trained weights and mean file are loaded from
    paths built under ``caffe_root`` (see ``__init__``).
    """

    def convert_mean_file(self, mean_filename):
        """Load a Caffe ``.binaryproto`` mean file as a numpy array.

        :param mean_filename: path to the serialized ``BlobProto`` file.
        :returns: index 0 along the first axis of the decoded blob, as a
            numpy array.
        """
        # Context manager guarantees the file handle is closed.
        with open(mean_filename, "rb") as proto_file:
            proto_data = proto_file.read()
        blob = caffe.proto.caffe_pb2.BlobProto()
        blob.ParseFromString(proto_data)
        return numpy.array(caffe.io.blobproto_to_array(blob))[0]

    def get_mean(self, mean_filename):
        """Return the mean file's values averaged over axes 1 and 2.

        Assumes the loaded array is (channels, height, width) -- TODO
        confirm -- in which case the result is one mean value per channel,
        which is what is passed to ``caffe.Classifier`` as ``mean``.

        :param mean_filename: path to the ``.binaryproto`` mean file.
        """
        npy = self.convert_mean_file(mean_filename)
        return npy.mean(1).mean(1)

    def __init__(self, data_name="current",
                 caffe_root="/mnt/disk0/lihao/plate-test/data/caffe-model",
                 gpu_id=0):
        """Initialize the model: load weights and mean, select the GPU.

        :param data_name: sub-directory name selecting which trained weights
            (``<caffe_root>/<data_name>/``) and mean file
            (``<caffe_root>/data/<data_name>/``) to load.
        :param caffe_root: directory containing ``deploy.prototxt`` and the
            per-dataset weight and mean-file directories.
        :param gpu_id: GPU index passed to ``caffe.set_device``.
        """
        model_file = '%s/deploy.prototxt' % (caffe_root,)
        pretrained = '%s/%s/caffe_alexnet_train_iter_2268000.caffemodel' % (
            caffe_root, data_name)
        mean_filename = '%s/data/%s/plate_mean.binaryproto' % (
            caffe_root, data_name)
        mean = self.get_mean(mean_filename)

        caffe.set_device(gpu_id)
        caffe.set_mode_gpu()

        # caffe.Classifier wraps input preprocessing together with the Net's
        # prediction, and predicts in caffe.TEST mode. raw_scale must match
        # the pixel range used when the model was trained.
        self.net = caffe.Classifier(model_file, pretrained,
                                    mean=mean,
                                    raw_scale=255,
                                    image_dims=(64, 64))

    def getProb(self, img_path):
        """Return the classifier's prediction for a single image file.

        Equivalent to ``predict([img_path])``.

        :param img_path: path to the image file.
        """
        return self.predict([img_path])

    def predict(self, segments):
        """Run the classifier on a batch of image files.

        :param segments: iterable of image file paths. Each is loaded with
            ``color=False`` (single-channel).
        :returns: numpy array of the classifier's predictions for the batch
            (presumably class probabilities, as ``getProb``'s name suggests).
        """
        imgs = [caffe.io.load_image(path, color=False) for path in segments]
        # Oversampling would average predictions over corner, center and
        # mirrored crops; it is disabled here.
        probs = self.net.predict(imgs, oversample=False)
        return numpy.array(probs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment