Created
January 14, 2019 06:47
-
-
Save mahxn0/52166858a9660cd4b7da4551346e054b to your computer and use it in GitHub Desktop.
googlenet caffe
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import PIL
from PIL import Image

# Location of the local Caffe checkout and of this model's files
# (deploy.prototxt, weights, mean.npy, label.txt).
caffe_root = '/home/mahxn0/caffe/'
work_dir = '/home/mahxn0/workspace/src/detectAndRecog/src/caffenet/model-googlenet-watch/'

# Put the checkout's pycaffe bindings first on the import path, so that the
# `import caffe` below picks them up.
import sys
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe

MODEL_FILE = os.path.join(work_dir, 'deploy.prototxt')
PRETRAINED = os.path.join(work_dir, 'bvlc_googlenet.caffemodel')

# Run inference on the GPU.
caffe.set_mode_gpu()

# Build the classifier. The stored mean array is averaged over its two
# trailing axes to give one mean value per channel. This presumes mean.npy is
# laid out as (channels, height, width); TODO confirm.
# channel_swap=(2, 1, 0) reverses channel order (RGB input -> BGR, as the
# reference ImageNet models expect). raw_scale=255 rescales [0, 1] inputs to
# the [0, 255] range.
net = caffe.Classifier(
    MODEL_FILE,
    PRETRAINED,
    mean=np.load(os.path.join(work_dir, 'mean.npy')).mean(axis=1).mean(axis=1),
    channel_swap=(2, 1, 0),
    raw_scale=255,
    image_dims=(224, 224),
)

# One label string per line (tab-delimited) in label.txt.
imagenet_labels_filename = os.path.join(work_dir, 'label.txt')
labels = np.loadtxt(imagenet_labels_filename, dtype=str, delimiter='\t')
# Walk every file in the input directory, classify it with `net`, draw the
# predicted class index on the image, show it briefly, and save the
# annotated copy as <outputdir>/<n>.jpg.
if __name__ == "__main__":
    inputdir = '/home/mahxn0/Downloads/train/true/'
    outputdir = '/media/mahxn0/DATA/dataset/test_output_googlenet/'
    # cv2.imwrite reports failure by returning False rather than raising,
    # so a missing output directory would silently drop every result.
    if not os.path.isdir(outputdir):
        os.makedirs(outputdir)
    i = 0
    font = cv2.FONT_HERSHEY_SCRIPT_COMPLEX
    # Sort so the numbered outputs come out in a reproducible order.
    for filename in sorted(os.listdir(inputdir)):
        # Load the image to classify (OpenCV returns BGR uint8).
        srcFile = os.path.join(inputdir, filename)
        input_image = cv2.imread(srcFile)
        # imread returns None (no exception) for unreadable / non-image files.
        if input_image is None:
            print('skipping unreadable file:', srcFile)
            continue
        # The classifier is configured with raw_scale=255 and
        # channel_swap=(2, 1, 0), so it takes RGB floats in [0, 1].
        image = cv2.resize(input_image, (256, 256))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image / 255.
        # Predict the class.
        prediction = net.predict([image])
        cls = prediction[0].argmax()
        print('predicted class:', cls)
        i += 1
        # Annotate a copy of the original BGR uint8 frame. The RGB float
        # network input would be shown with swapped colours, and imwrite would
        # save it as an almost black 8-bit image.
        result = input_image.copy()
        cv2.putText(result, str(cls), (50, 100), font, 3, (0, 0, 255), 5)
        cv2.imshow("classify_result", result)
        cv2.imwrite(os.path.join(outputdir, str(i) + '.jpg'), result)
        cv2.waitKey(30)
    cv2.destroyAllWindows()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment