Created April 10, 2018 04:49
-
-
Save axi345/6c3e2f4042b8be4c46cddc4ba354859e to your computer and use it in GitHub Desktop.
A lightweight package for artificial intelligence
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
__version__ = "0.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
import numpy as np


class PSO:
    """Particle swarm optimizer that *maximizes* a fitness function.

    Each of ``size`` particles is a point in a ``param_len``-dimensional
    space. Every epoch, each particle's fitness ``func(*position)`` is
    evaluated, the per-particle and global bests are updated, and then all
    velocities and positions are moved toward those bests and clipped to
    their bounds.

    Parameters
    ----------
    func : callable
        Fitness function, called as ``func(*params)``; larger is better.
    param_len : int
        Number of parameters (dimensions) being optimized.
    size : int
        Number of particles.
    w : float
        Inertia coefficient.
    c1 : float
        Cognitive coefficient (pull toward each particle's own best).
    c2 : float
        Social coefficient (pull toward the global best).
    x_min, x_max : float or sequence of float
        Lower/upper bound of each position coordinate. A scalar applies to
        every dimension; a sequence (list, tuple, ndarray) must have length
        ``param_len``.
    v_min, v_max : float or sequence of float
        Lower/upper bound of each velocity coordinate; same rules as above.
    r1, r2 : float or None
        Fixed random factors for the cognitive/social terms. ``None`` (the
        default) draws fresh uniform(0, 1) factors every epoch for every
        particle and dimension.

    Raises
    ------
    ValueError
        If a bound sequence does not have length ``param_len``.
    """

    def __init__(self, func=None, param_len=1, size=5, w=0.9, c1=2., c2=2., x_min=-10., x_max=10., v_min=-0.5,
                 v_max=0.5, r1=None, r2=None):
        self.func = func  # fitness function
        self.param_len = param_len  # number of parameters
        self.size = size  # number of particles
        self.w = w  # inertia coefficient
        self.c1 = c1  # cognitive coefficient
        self.c2 = c2  # social coefficient
        # Per-dimension position and velocity bounds, stored as lists.
        self.x_min = self._expand_bound(x_min, param_len, 'x_min')
        self.x_max = self._expand_bound(x_max, param_len, 'x_max')
        self.v_min = self._expand_bound(v_min, param_len, 'v_min')
        self.v_max = self._expand_bound(v_max, param_len, 'v_max')
        self.r1 = r1  # fixed random factor 1, or None to draw fresh values
        self.r2 = r2  # fixed random factor 2, or None to draw fresh values
        self.x = None  # positions, shape (size, param_len)
        self.v = None  # velocities, shape (size, param_len)
        self.best_all_x = None  # global best position
        self.best_all_score = None  # global best fitness
        self.best_each_x = None  # per-particle best positions
        self.best_each_score = None  # per-particle best fitness values
        self._init_fit()

    @staticmethod
    def _expand_bound(value, length, name):
        """Return *value* as a list with one entry per dimension.

        Scalars (including 0-d arrays) are repeated ``length`` times; any
        sequence is copied and must already have ``length`` entries.
        """
        if np.ndim(value) == 0:
            return [value] * length
        values = list(value)
        if len(values) != length:
            raise ValueError('%s must have %d entries (param_len), got %d'
                             % (name, length, len(values)))
        return values

    def _init_fit(self):
        """Randomly place all particles and reset the best-so-far records."""
        shape = (self.size, self.param_len)
        # low/high broadcast over the last axis, so each dimension gets its
        # own uniform range.
        self.x = np.random.uniform(self.x_min, self.x_max, size=shape)
        self.v = np.random.uniform(self.v_min, self.v_max, size=shape)
        self.best_all_x = np.zeros(self.param_len)
        self.best_all_score = -np.inf
        self.best_each_x = self.x.copy()
        self.best_each_score = np.full(self.size, -np.inf)

    def solve(self, epoch=5, verbose=True):
        """Run ``epoch`` PSO iterations and return the best position found.

        :param epoch: number of iterations to run.
        :param verbose: print a progress line after each iteration.
        :return: the global best position (1-d array of length ``param_len``).
        :raises TypeError: if ``func`` was not provided.
        """
        if self.func is None:
            raise TypeError('PSO.func must be a callable fitness function, got None')
        shape = (self.size, self.param_len)
        x_low = np.asarray(self.x_min, dtype=float)
        x_high = np.asarray(self.x_max, dtype=float)
        v_low = np.asarray(self.v_min, dtype=float)
        v_high = np.asarray(self.v_max, dtype=float)
        for epoch_idx in range(epoch):
            # Evaluate fitness; update per-particle and global bests.
            for i in range(self.size):
                fitness = self.func(*self.x[i])
                if fitness > self.best_each_score[i]:
                    self.best_each_score[i] = fitness
                    self.best_each_x[i] = self.x[i]
                if fitness > self.best_all_score:
                    self.best_all_score = fitness
                    self.best_all_x = self.x[i].copy()
            # Random factors must be redrawn every iteration; a user-supplied
            # r1/r2 stays fixed.
            r1 = np.random.uniform(0, 1, size=shape) if self.r1 is None else self.r1
            r2 = np.random.uniform(0, 1, size=shape) if self.r2 is None else self.r2
            # Standard inertia-weighted velocity update, then clip and move.
            velocity = (self.w * self.v
                        + self.c1 * r1 * (self.best_each_x - self.x)
                        + self.c2 * r2 * (self.best_all_x - self.x))
            self.v = np.clip(velocity, v_low, v_high)
            self.x = np.clip(self.x + self.v, x_low, x_high)
            if verbose:
                print('已完成第%i次寻找,最优参数值为' % (epoch_idx + 1), self.best_all_x, '目前最优适合度为%.4f' % self.best_all_score)
        return self.best_all_x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""Packaging script for the ``ailearn`` distribution."""
import io
import os

from setuptools import find_packages, setup

PACKAGE = "ailearn"
NAME = "ailearn"
DESCRIPTION = "A lightweight package for artificial intelligence"
AUTHOR = "ZHAO Xingyu"
AUTHOR_EMAIL = "757008724@qq.com"
URL = ""
# NOTE: this imports the package itself to read __version__, so the
# package's __init__ must stay importable without its runtime dependencies.
VERSION = __import__(PACKAGE).__version__

HERE = os.path.dirname(os.path.abspath(__file__))


def read(fname):
    """Return the UTF-8 text of *fname* (relative to this script), or "" if it is missing."""
    path = os.path.join(HERE, fname)
    try:
        with io.open(path, encoding="utf-8") as f:
            return f.read()
    except IOError:
        return ""


setup(
    name=NAME,
    version=VERSION,
    description=DESCRIPTION,
    long_description=read("README.rst"),
    author=AUTHOR,
    author_email=AUTHOR_EMAIL,
    license="BSD",
    url=URL,
    packages=find_packages(exclude=["tests.*", "tests"]),
    # The previous find_package_data(...) helper was never defined or
    # imported here; include_package_data ships data files declared in
    # MANIFEST.in instead.
    include_package_data=True,
    # The optimizer modules import numpy at runtime.
    install_requires=["numpy"],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: BSD License",
        "Operating System :: OS Independent",
        "Programming Language :: Python",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    zip_safe=False,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.