Last active
August 7, 2017 09:45
-
-
Save liangzai-cool/f93614a8ad2c826a27461aadbd1bb006 to your computer and use it in GitHub Desktop.
零基础入门深度学习(1) - 感知器
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;
/**
 * A single perceptron with a pluggable activation function.
 *
 * <p>The output is {@code activator(sum_i(x_i * w_i) + bias)}. Training uses the perceptron
 * rule: {@code w_i += rate * (label - prediction) * x_i} and {@code bias += rate * (label - prediction)}.
 *
 * <p>Training mutates the weights and bias in place without synchronization, so an instance must
 * not be trained concurrently from several threads.
 */
public class Perceptron {
    /** Maps the weighted sum to the integer output; {@code null} only for no-arg instances. */
    private final Function<Double, Integer> activator;
    /** Bias term, updated during training. */
    private double bias;
    /** One weight per input; {@code null} only for no-arg instances. */
    private final double[] weights;

    /**
     * Creates an uninitialized perceptron with no weights and no activation function.
     *
     * <p>Such an instance cannot predict or be trained (those calls throw
     * {@link IllegalStateException}); it only serves as a receiver for helpers such as
     * {@link #trainAndPerceptron()}.
     */
    public Perceptron() {
        this.activator = null;
        this.weights = null;
    }

    /**
     * Creates a perceptron with {@code inputNum} inputs, all weights and the bias set to zero.
     *
     * @param inputNum  number of inputs, i.e. the required length of every input vector
     * @param activator activation function mapping the weighted sum (a {@code Double}) to the
     *                  perceptron's {@code Integer} output
     * @throws IllegalArgumentException if {@code inputNum} is negative
     * @throws NullPointerException     if {@code activator} is {@code null}
     */
    public Perceptron(int inputNum, Function<Double, Integer> activator) {
        if (inputNum < 0) {
            throw new IllegalArgumentException("inputNum must be non-negative: " + inputNum);
        }
        this.activator = Objects.requireNonNull(activator, "activator");
        // New double arrays are zero-initialized, so every weight starts at 0.0.
        this.weights = new double[inputNum];
        this.bias = 0.0;
    }

    /**
     * Returns the input vectors of the AND truth table, used as training data.
     *
     * <p>Row {@code i} corresponds to entry {@code i} of {@link #getTrainingDatasetLabels()}.
     *
     * @return a newly allocated array of four 2-element input vectors
     */
    public int[][] getTrainingDatasetInputVecs() {
        return new int[][]{{1, 1}, {0, 0}, {1, 0}, {0, 1}};
    }

    /**
     * Returns the expected AND outputs, aligned one-to-one with
     * {@link #getTrainingDatasetInputVecs()}:
     * {@code [1,1] -> 1, [0,0] -> 0, [1,0] -> 0, [0,1] -> 0}.
     *
     * @return a newly allocated array of the four labels
     */
    public int[] getTrainingDatasetLabels() {
        return new int[]{1, 0, 0, 0};
    }

    /**
     * Misspelled alias of {@link #getTrainingDatasetLabels()}, kept for existing callers.
     *
     * @return the same labels as {@link #getTrainingDatasetLabels()}
     * @deprecated use {@link #getTrainingDatasetLabels()} instead
     */
    @Deprecated
    public int[] getTrainingDatasetLables() {
        return getTrainingDatasetLabels();
    }

    /**
     * Trains the perceptron by making {@code iteration} full passes over the data.
     *
     * @param inputVecs input vectors; each must have exactly one entry per weight
     * @param labels    expected output for each input vector, in the same order
     * @param iteration number of passes over the whole data set (none if {@code <= 0})
     * @param rate      learning rate
     * @throws IllegalArgumentException if a pass is made and the data set is malformed
     * @throws IllegalStateException    if a pass is made on a no-arg instance
     */
    public void train(int[][] inputVecs, int[] labels, int iteration, double rate) {
        for (int i = 0; i < iteration; i++) {
            oneIteration(inputVecs, labels, rate);
        }
    }

    /**
     * Makes one pass over all training samples, updating the weights after each one.
     *
     * <p>The whole data set is validated before any weight is changed.
     *
     * @param inputVecs input vectors; each must have exactly one entry per weight
     * @param labels    expected output for each input vector, in the same order
     * @param rate      learning rate
     * @throws IllegalArgumentException if the arrays differ in length or a vector has the wrong size
     * @throws IllegalStateException    if this perceptron was created with the no-arg constructor
     */
    public void oneIteration(int[][] inputVecs, int[] labels, double rate) {
        checkDataset(inputVecs, labels);
        for (int i = 0; i < labels.length; i++) {
            int[] inputVec = inputVecs[i];
            // Output under the current weights, then apply the perceptron rule.
            int result = predict(inputVec);
            updateWeights(inputVec, result, labels[i], rate);
        }
    }

    /**
     * Computes the perceptron's output for one input vector.
     *
     * @param inputVec input vector with exactly one entry per weight
     * @return {@code activator(sum_i(inputVec[i] * weights[i]) + bias)}
     * @throws IllegalArgumentException if {@code inputVec} has the wrong length
     * @throws IllegalStateException    if this perceptron was created with the no-arg constructor
     */
    public int predict(int[] inputVec) {
        checkInputVec(inputVec);
        double d = 0.0;
        for (int i = 0; i < inputVec.length; i++) {
            d += inputVec[i] * weights[i];
        }
        // The bias is added last to keep the original floating-point summation order.
        d += bias;
        return activator.apply(d);
    }

    /**
     * Applies the perceptron learning rule for one sample.
     *
     * <p>When {@code result == label} the update is zero, so weights only move on mistakes.
     *
     * @param inputVec input vector with exactly one entry per weight
     * @param result   the perceptron's output for {@code inputVec}
     * @param label    the expected output for {@code inputVec}
     * @param rate     learning rate
     * @throws IllegalArgumentException if {@code inputVec} has the wrong length
     * @throws IllegalStateException    if this perceptron was created with the no-arg constructor
     */
    public void updateWeights(int[] inputVec, int result, int label, double rate) {
        checkInputVec(inputVec);
        int delta = label - result;
        for (int i = 0; i < inputVec.length; i++) {
            weights[i] += rate * delta * inputVec[i];
        }
        bias += rate * delta;
    }

    /**
     * Trains a new 2-input perceptron on the AND truth table for 10 passes at learning rate 0.1.
     *
     * <p>This instance's own weights are not used; only its training-data getters are.
     *
     * @return the newly trained perceptron
     */
    public Perceptron trainAndPerceptron() {
        return trainAndPerceptron(10, 0.1);
    }

    /**
     * Trains a new 2-input perceptron on the AND truth table.
     *
     * @param iteration number of passes over the truth table
     * @param rate      learning rate
     * @return the newly trained perceptron
     */
    public Perceptron trainAndPerceptron(int iteration, double rate) {
        // Step activation: output 1 only when the weighted sum is strictly positive.
        Function<Double, Integer> f = t -> t > 0 ? 1 : 0;
        // AND is a binary function, so the perceptron needs two inputs.
        Perceptron perceptron = new Perceptron(2, f);
        perceptron.train(getTrainingDatasetInputVecs(), getTrainingDatasetLabels(), iteration, rate);
        return perceptron;
    }

    /**
     * Formats the learned weights and bias, one per line.
     *
     * @return a string of the form {@code "weights\t:[...]\nbias\t:<value>\n"}
     */
    @Override
    public String toString() {
        return String.format("weights\t:%s\nbias\t:%f\n", Arrays.toString(weights), bias);
    }

    /** Ensures this instance is initialized and {@code inputVec} has one entry per weight. */
    private void checkInputVec(int[] inputVec) {
        if (weights == null || activator == null) {
            throw new IllegalStateException(
                    "Perceptron is not initialized; use Perceptron(int, Function)");
        }
        Objects.requireNonNull(inputVec, "inputVec");
        if (inputVec.length != weights.length) {
            throw new IllegalArgumentException(
                    "expected " + weights.length + " inputs but got " + inputVec.length);
        }
    }

    /** Validates a whole training set up front, so a bad sample cannot leave a half-applied pass. */
    private void checkDataset(int[][] inputVecs, int[] labels) {
        Objects.requireNonNull(inputVecs, "inputVecs");
        Objects.requireNonNull(labels, "labels");
        if (inputVecs.length != labels.length) {
            throw new IllegalArgumentException(
                    inputVecs.length + " input vectors but " + labels.length + " labels");
        }
        for (int[] inputVec : inputVecs) {
            checkInputVec(inputVec);
        }
    }

    /** Trains the AND perceptron, prints its parameters, and prints its output on every input. */
    public static void main(String[] args) {
        // The no-arg instance only acts as the receiver; training happens on a new perceptron.
        Perceptron perceptron = new Perceptron().trainAndPerceptron();
        System.out.println(perceptron);
        for (int[] inputVec : new int[][]{{1, 1}, {0, 0}, {1, 0}, {0, 1}}) {
            System.out.println(String.format("%d and %d = %d",
                    inputVec[0], inputVec[1], perceptron.predict(inputVec)));
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment