Skip to content

Instantly share code, notes, and snippets.

@dvsseed
Last active January 29, 2020 16:46
Show Gist options
  • Save dvsseed/6f9db80edaf3a62c6507aa338959d3fb to your computer and use it in GitHub Desktop.
Save dvsseed/6f9db80edaf3a62c6507aa338959d3fb to your computer and use it in GitHub Desktop.
Run CaffeNet prediction (image classification) with Python 2
# This Python file uses the following encoding: utf-8
"""Classify the Caffe example cat image with the BVLC reference CaffeNet.

The script loads the deploy network definition and the trained weights from a
local Caffe checkout, preprocesses ``examples/images/cat.jpg``, runs a single
forward pass on GPU 0, and prints the elapsed time together with the most
probable ImageNet label.

The trained weights are not shipped with Caffe; download them from
http://dl.caffe.berkeleyvision.org/bvlc_reference_caffenet.caffemodel
"""
import os
import sys
import time

import matplotlib.pyplot as plt  # not used below; kept so the file's imports are unchanged
import numpy as np

# Root directory of the local Caffe checkout.
caffe_root = 'C:/Users/user/caffe-windows/'

# Make the checkout's bundled Python package importable. This must run BEFORE
# `import caffe`; otherwise the path entry has no effect on that import.
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe  # noqa: E402  (deliberately imported after the sys.path tweak)

# Start timing (covers model loading, preprocessing and the forward pass).
since = time.time()

# Work from the Caffe root directory.
os.chdir(caffe_root)

# Network structure definition.
net_file = os.path.join(
    caffe_root, 'models/bvlc_reference_caffenet/deploy.prototxt')
# Trained weights (download separately, see the module docstring).
caffe_model = os.path.join(
    caffe_root,
    'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel')
# Mean image file.
mean_file = os.path.join(
    caffe_root, 'python/caffe/imagenet/ilsvrc_2012_mean.npy')
# Sample image to classify.
image_file = os.path.join(caffe_root, 'examples/images/cat.jpg')
# Label file, indexed below by the predicted class number.
imagenet_labels_filename = os.path.join(
    caffe_root, 'data/ilsvrc12/synset_words.txt')

# Use GPU 0. Select the device and mode before the net is built so the whole
# pipeline is set up in GPU mode (same order as the pycaffe examples).
caffe.set_device(0)
caffe.set_mode_gpu()

# Build the network in test (inference) mode from the definition and weights.
net = caffe.Net(net_file, caffe_model, caffe.TEST)

# Preprocessor configured for the shape of the network's 'data' input blob.
# Per the original author's notes, the loaded image is RGB with values in
# [0, 1] and channels on the last axis, whereas the network expects BGR,
# [0, 255], channels first -- NOTE(review): confirm against caffe.io docs.
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
# Move the last axis (channels) to the front.
transformer.set_transpose('data', (2, 0, 1))
# Average the mean array over its last two axes; presumably this yields one
# mean value per channel -- TODO confirm the layout of the .npy file.
transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))
# Scale raw pixel values by 255.
transformer.set_raw_scale('data', 255)
# Reverse the channel order (RGB -> BGR according to the original notes).
transformer.set_channel_swap('data', (2, 1, 0))

# Load the sample image and copy its preprocessed form into the input blob.
im = caffe.io.load_image(image_file)
net.blobs['data'].data[...] = transformer.preprocess('data', im)

# Run the forward pass.
out = net.forward()

# Split the elapsed time into whole minutes plus the remaining seconds.
# (Printing `elapsed / 60` next to `elapsed % 60` would count the fractional
# minute twice, e.g. 90 s -> "1.5 minute : 30 seconds".)
time_elapsed = time.time() - since
minute, seconds = divmod(time_elapsed, 60)

# Separator line.
print('=' * 55)
print('Elapsed time is: %d minute : %.4f seconds' % (minute, seconds))

# One label per line; splitting on tabs keeps each line as a single string as
# long as the file contains no tab characters.
labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')

# Probabilities for the first (only) image written into the input blob.
output_prob = out['prob'][0]
# Index of the most probable class.
outmax = output_prob.argmax()
print('Predicted class is: [%d => %s]' % (outmax, labels[outmax]))
print('=' * 55)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment