Skip to content

Instantly share code, notes, and snippets.

@vtta
Last active May 1, 2019 18:17
Show Gist options
  • Save vtta/61b4939109c8e104cd8744a87a870143 to your computer and use it in GitHub Desktop.
Save vtta/61b4939109c8e104cd8744a87a870143 to your computer and use it in GitHub Desktop.
A simple perceptron classifier that uses quadratic features to learn non-linear decision boundaries
% Train a single-layer perceptron on a quadratic expansion of 2-D inputs,
% so each output unit learns a conic (non-linear) decision boundary
%   w1 + w2*x + w3*y + w4*x*y + w5*x^2 + w6*y^2 = 0.
%
% Expects the script perceptron_data to define:
%   input  - numInst-by-2 matrix, one instance per row
%   target - numInst-by-numClasses matrix of labels (0/1 in the shipped data)
% Leaves in the workspace: weights (numClasses-by-6), iter, err, rmse.

LEARNING_RATE = 0.1;     % step size of the perceptron update
MAX_ITERATIONS = 1e3;    % upper bound on training epochs
MIN_ERROR = 1e-3;        % stop once every class's RMSE is below this

% Load input and target (defined by the perceptron_data script)
perceptron_data;
[numInst, numDims] = size(input);
numClasses = size(target,2);

% The feature map uses the product of the first two coordinates, so it is
% only defined for two-dimensional inputs; fail with a clear message
% instead of an index/dimension error deep inside the loop.
if numDims ~= 2
    error('perceptron:badInput', ...
        'Expected 2-D inputs (numDims == 2), got numDims = %d.', numDims);
end

% Quadratic feature map as a row vector: [1, x, y, x*y, x^2, y^2].
% Shared by the forward pass and the weight update so they cannot diverge.
featureMap = @(x) [1, x, x(1)*x(2), x.^2];

% One row of weights per output unit; the width follows the feature map
% instead of being hard-coded (it is 6 for 2-D inputs).
numFeatures = numel(featureMap(zeros(1, numDims)));
weights = randn(numClasses, numFeatures);

isDone = false;
iter = 0;
while ~isDone
    iter = iter + 1;
    err = zeros(numInst, numClasses);
    % Online training: weights are updated after every instance, so err
    % records the mistakes made during this pass over the data.
    for i = 1:numInst
        features = featureMap(input(i,:));
        % Threshold activation: a unit outputs 1 when w*phi >= 0.
        output = (weights * features' >= 0);
        % Per-class error (target - output), stored as row i.
        err(i,:) = target(i,:) - output';
        % Perceptron rule: W <- W + eta * (t - y)' * phi.
        weights = weights + LEARNING_RATE * err(i,:)' * features;
    end
    % Per-class root-mean-squared error over this epoch.
    rmse = sqrt(sum(err.^2,1)/numInst);
    if ( iter >= MAX_ITERATIONS || all(rmse < MIN_ERROR) )
        isDone = true;
    end
end
% Plot the training points coloured by class and overlay each output
% unit's learned boundary, i.e. the zero level set of
%   w1 + w2*x + w3*y + w4*x*y + w5*x^2 + w6*y^2,
% with weights(i,:) holding the six coefficients in that order.
% NOTE(review): the grouping below assumes target has a single 0/1 column;
% gscatter requires the Statistics and Machine Learning Toolbox.

% For a single 0/1 column: group = 1 where target is 1, 2 where it is 0.
[~,group] = max([target ~target],[],2);
gscatter(input(:,1), input(:,2), group), hold on

% Draw the boundaries over the current axes' extent. For an implicit
% f(x,y), ezplot's documented domain argument is one vector
% [xmin xmax ymin ymax], so the x and y limits are concatenated.
xLimits = get(gca,'xlim');
yLimits = get(gca,'ylim');
for i = 1:numClasses
    ezplot(sprintf(...
        '%f + %f*x + %f*y + %f*x*y + %f*x.^2 + %f*y.^2', ...
        weights(i,:)), [xLimits yLimits])
end

% Set the title last: ezplot replaces it with the plotted expression.
title('Perceptron decision boundaries')
hold off
% -------------------------------------------------------------------
% Generated by MATLAB on 15-Apr-2019 00:52:40
% MATLAB version: 9.5.0.944444 (R2018b)
% -------------------------------------------------------------------
% input: 100-by-2 matrix of training instances, one [x y] point per row.
% All coordinates in the rows below lie within [-1, 1].
input = ...
[-0.639 -0.574;
-0.355 0.251;
0.872 -0.024;
-0.983 -0.108;
-0.249 0.897;
-0.589 0.128;
0.662 0.266;
0.224 0.85;
-0.831 -0.527;
0.614 0.095;
0.028 0.319;
0.158 -0.036;
-0.545 -0.657;
0.061 0.663;
-0.363 -0.077;
-0.016 0.584;
0.544 0.057;
0.036 -0.728;
0 0.872;
0.732 -0.254;
-0.033 -0.132;
-0.358 -0.691;
-0.759 0.024;
0.586 0.131;
-0.552 0.488;
0.552 0.337;
0.612 -0.073;
-0.076 0.008;
-0.581 -0.569;
0.45 -0.799;
0.619 0.063;
0.16 -0.728;
0.292 -0.771;
-0.857 0.351;
-0.484 0.451;
0.229 -0.592;
-0.425 0.225;
0.483 0.733;
-0.579 0.717;
0.152 -0.326;
-0.671 -0.539;
-0.595 -0.293;
0.083 0.656;
0.697 0.618;
0.53 -0.205;
-0.918 0.38;
0.517 0.153;
0.41 0.294;
-0.976 0.009;
-0.935 -0.084;
-0.444 -0.366;
0.337 0.683;
0.063 -0.578;
0.272 0.492;
0.254 -0.356;
0.22 -0.149;
0.467 0.738;
0.196 0.669;
0.666 -0.245;
0.109 -0.7;
0.103 0.45;
0.696 0.064;
0.008 -0.393;
0.187 0.714;
0.971 -0.072;
-0.317 0.259;
0.37 0.249;
0.437 -0.075;
0.061 -0.192;
-0.105 -0.89;
0.78 0.502;
-0.252 -0.613;
-0.111 -0.796;
-0.098 -0.58;
-0.964 -0.244;
-0.561 -0.317;
0.868 0.231;
0.245 0.447;
0.353 0.514;
-0.619 -0.221;
0.213 0.113;
0.114 0.277;
0.013 0.783;
0.748 0.233;
0.583 -0.21;
-0.519 -0.424;
-0.042 0.219;
-0.651 -0.368;
-0.051 -0.759;
0.391 0.798;
0.825 -0.401;
0.293 -0.302;
0.729 0.257;
0.046 -0.302;
-0.109 0.023;
0.279 -0.292;
0.575 0.529;
0.484 -0.764;
-0.344 -0.293;
0.743 -0.329];
% target: 100-by-1 column of 0/1 class labels, one per row of input.
target = [1; 0; 0; 1; 0; 0; 1; 1; 1; 1; 1; 0; 1; 1; 1; 0; 1; 0; 1; 0; 1; ...
1; 0; 1; 0; 1; 0; 0; 1; 0; 1; 0; 0; 0; 0; 0; 0; 1; 0; 0; 1; 1; ...
1; 1; 0; 0; 1; 1; 0; 1; 1; 1; 0; 1; 0; 0; 1; 1; 0; 0; 1; 1; 0; ...
1; 0; 0; 1; 0; 0; 1; 1; 1; 1; 1; 1; 1; 1; 1; 1; 1; 1; 1; 1; 1; ...
0; 1; 0; 1; 1; 1; 0; 0; 1; 0; 0; 0; 1; 0; 1; 0];
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment