Skip to content

Instantly share code, notes, and snippets.

@hengzhe-zhang
Created August 4, 2017 15:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hengzhe-zhang/4404fa3203cf1b24740e4131971c7939 to your computer and use it in GitHub Desktop.
Save hengzhe-zhang/4404fa3203cf1b24740e4131971c7939 to your computer and use it in GitHub Desktop.
基于ELM的图片多分类器
# 基于ELM的图片多分类器
import os
import cv2
import hpelm
import numpy as np
class FishElm(object):
    """Multi-class image classifier built on an Extreme Learning Machine.

    Images are read from ``<photo_dir>/<kind>/<id>.jpg``, converted to
    grayscale, resized to ``image_size`` and flattened into one feature
    vector per image. The model is trained as soon as the object is
    constructed and stored in ``self.elm``.
    """

    sample_num = 0  # number of training images requested per kind
    photo_dir = ''  # root directory that holds one sub-directory per kind
    kind_list = []  # kind (class) names; always rebound per instance in __init__
    image_size = (100, 100)  # (width, height) every image is resized to

    def __init__(self, sample_num, photo_dir, kind_list):
        """Store the configuration and train the model immediately.

        Args:
            sample_num: number of training images to use for each kind.
            photo_dir: root directory containing one folder per kind.
            kind_list: sequence of kind (folder) names; its order defines
                the 1-based class numbers returned by ``predict_photo``.
        """
        self.sample_num = sample_num
        self.photo_dir = photo_dir
        self.kind_list = kind_list
        self.elm = self.train()

    def _image_path(self, kind, image_id):
        """Return the path ``<photo_dir>/<kind>/<image_id>.jpg``."""
        return os.path.join(self.photo_dir, kind, '{}.jpg'.format(image_id))

    def _available_ids(self, kind, from_id):
        """Return the sorted numeric ids >= from_id that exist for ``kind``.

        Only files named exactly ``<id>.jpg`` are considered. A missing or
        unreadable kind directory yields an empty list.
        """
        try:
            names = os.listdir(os.path.join(self.photo_dir, kind))
        except OSError:
            return []
        ids = []
        for name in names:
            if not name.endswith('.jpg'):
                continue
            try:
                image_id = int(name[:-4])
            except ValueError:
                continue
            # The round-trip check rejects names such as '01.jpg' that the
            # '{}.jpg' naming scheme would never produce.
            if image_id >= from_id and name == '{}.jpg'.format(image_id):
                ids.append(image_id)
        ids.sort()
        return ids

    def _load_image(self, file_path):
        """Load one image as a flat grayscale vector, or None on failure."""
        if not os.path.isfile(file_path):
            return None
        try:
            image = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
            if image is None:  # imread signals an undecodable file with None
                return None
            resized = cv2.resize(image, tuple(self.image_size),
                                 interpolation=cv2.INTER_CUBIC)
        except cv2.error:
            return None
        return resized.reshape(-1)

    @staticmethod
    def _best_class(scores):
        """Return ``(1-based index of the largest score, that score)``."""
        flat = np.asarray(scores, dtype=float).reshape(-1)
        best = int(np.argmax(flat))
        return best + 1, float(flat[best])

    def read_as_list(self, kind, from_id, all_num):
        """Read up to ``all_num`` images of one kind as flat grayscale vectors.

        Image ids may have gaps (some downloads failed), so ids are taken
        from the files actually present instead of being probed one by one;
        this guarantees termination even when fewer than ``all_num`` images
        exist.

        Args:
            kind: kind (folder) name.
            from_id: first image id to consider.
            all_num: maximum number of images to return.

        Returns:
            A list of 1-D pixel arrays, one per successfully read image. It
            is shorter than ``all_num`` when not enough readable images exist.
        """
        images = []
        if all_num <= 0:
            return images
        for image_id in self._available_ids(kind, from_id):
            pixels = self._load_image(self._image_path(kind, image_id))
            if pixels is None:
                print('图片{}读取失败'.format(image_id))
                continue
            images.append(pixels)
            if len(images) >= all_num:
                break
        return images

    def train(self):
        """Build and train the ELM on ``sample_num`` images of every kind.

        Returns:
            The trained ``hpelm.ELM`` instance.

        Raises:
            ValueError: if no training image could be read at all.
        """
        num_kinds = len(self.kind_list)
        width, height = self.image_size
        elm = hpelm.ELM(width * height, num_kinds)

        input_data = []
        output_data = []
        for pos, kind in enumerate(self.kind_list):
            images = self.read_as_list(kind, 0, self.sample_num)
            if len(images) < self.sample_num:
                print('kind {}: only {} of {} requested images could be read'
                      .format(kind, len(images), self.sample_num))
            input_data.extend(images)
            # One label row per image actually read keeps inputs and targets
            # aligned even when some images are missing.
            one_hot = [float(i == pos) for i in range(num_kinds)]
            output_data.extend([one_hot] * len(images))

        if not input_data:
            raise ValueError(
                'no training images found under {!r}'.format(self.photo_dir))

        elm.add_neurons(30, 'lin')
        elm.add_neurons(15, 'rbf_linf')
        elm.train(np.asarray(input_data, dtype=np.float64),
                  np.asarray(output_data, dtype=np.float64))
        return elm

    def predict_photo(self, kind_name, photo_num):
        """Classify the image ``<photo_dir>/<kind_name>/<photo_num>.jpg``.

        Args:
            kind_name: kind (folder) the image is stored under.
            photo_num: id of the image inside that folder.

        Returns:
            ``(class_number, score)`` where ``class_number`` is the 1-based
            position in ``kind_list`` of the class with the largest model
            output and ``score`` is that output value. ``(-1, -1)`` is
            returned when the image cannot be read. Errors raised by the
            model itself are not swallowed.
        """
        pixels = self._load_image(self._image_path(kind_name, photo_num))
        if pixels is None:
            return (-1, -1)
        test_data = np.asarray([pixels], dtype=np.float64)
        return self._best_class(self.elm.predict(test_data)[0])

    def test_performance(self, begin_num, end_num):
        """Measure accuracy on images ``begin_num..end_num`` of every kind.

        Args:
            begin_num: first image id to evaluate (inclusive).
            end_num: last image id to evaluate (inclusive).

        Returns:
            Fraction of correctly classified images among those that could
            be read; ``0.0`` when no image could be evaluated.
        """
        correct = 0
        evaluated = 0
        for pos, kind in enumerate(self.kind_list, 1):
            for photo_num in range(begin_num, end_num + 1):
                ans = self.predict_photo(kind, photo_num)
                if ans == (-1, -1):
                    # Unreadable/missing image: excluded from the statistics.
                    continue
                evaluated += 1
                if ans[0] == pos:
                    correct += 1
        if evaluated == 0:
            return 0.0
        return correct / float(evaluated)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment