Last active
January 29, 2020 16:46
-
-
Save dvsseed/6f9db80edaf3a62c6507aa338959d3fb to your computer and use it in GitHub Desktop.
To do the CaffeNet Predict with Python2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This Python file uses the following encoding: utf-8 | |
import numpy as np | |
import sys, os | |
import time | |
import matplotlib.pyplot as plt | |
import caffe | |
# 開始計時 | |
since = time.time() | |
# 設定目前的工作環境在 caffe 目錄下 | |
caffe_root = 'C:/Users/user/caffe-windows/' | |
# 新增 caffe/python 到目前的環境 | |
sys.path.insert(0, caffe_root + 'python') | |
# 切換工作目錄 | |
os.chdir(caffe_root) | |
# 設定Caffe網路結構 | |
net_file = caffe_root + 'models/bvlc_reference_caffenet/deploy.prototxt' | |
# 新增訓練後的引數 | |
# 請自行下載, 地址:http://dl.caffe.berkeleyvision.org/bvlc_reference_caffenet.caffemodel | |
caffe_model = caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel' | |
# 均值檔案 | |
mean_file = caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy' | |
# 新增上述兩變數,建構一個Caffe Net | |
net = caffe.Net(net_file, caffe_model, caffe.TEST) | |
# 得到data的shape,此圖片是由預設 matplotlib 底層載入的 | |
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) | |
# matplotlib 載入的image是[0-1]像素, 圖片的資料格式[weight, high, channels], RGB | |
# caffe 載入的image是[0-255]像素, 圖片的資料格式[channels, weight, high], BGR | |
# 需要轉換, channels 放到前面 => (channels=2, weight=0, high=1) | |
transformer.set_transpose('data', (2, 0, 1)) | |
transformer.set_mean('data', np.load(mean_file).mean(1).mean(1)) | |
# 將圖片像素放大到[0-255] | |
transformer.set_raw_scale('data', 255) | |
# 原始圖片RGB格式 => Caffe圖片BGR格式 轉換 | |
transformer.set_channel_swap('data', (2, 1, 0)) | |
# 讀取範例貓圖片 | |
im = caffe.io.load_image(caffe_root + 'examples/images/cat.jpg') | |
# 圖片預處理 | |
net.blobs['data'].data[...] = transformer.preprocess('data', im) | |
# 使用 GPU-0 | |
caffe.set_mode_gpu() | |
caffe.set_device(0) | |
# 網路開始向前傳播 | |
out = net.forward() | |
# 計算耗時 | |
time_elapsed = time.time() - since | |
minute = time_elapsed / 60 | |
seconds = time_elapsed % 60 | |
# 畫分隔線 | |
print('=' * 55) | |
print('Elapsed time is: %.4f minute : %.4f seconds' % (minute, seconds)) | |
# 載入標籤label檔 | |
imagenet_labels_filename = caffe_root + 'data/ilsvrc12/synset_words.txt' | |
labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t') | |
# 輸出結果: 載入的圖片是屬於那一種分類的機率(列表表示) | |
output_prob = out['prob'][0] | |
# 找出其中 最有可能=最大 分類的機率 | |
outmax = output_prob.argmax() | |
print('Predicted class is: [%d => %s]' % (outmax, labels[outmax])) | |
print('=' * 55) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment