Created
September 5, 2023 04:51
-
-
Save z1nc0r3/e213d7ecaf7fe8f3f88b33c61d8f2f9d to your computer and use it in GitHub Desktop.
One-vs-All SVM algorithm using MATLAB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function accuracy = ovasvm(data, n, C)
%OVASVM One-vs-All multiclass SVM with validation-based selection of C.
%   ACCURACY = OVASVM(DATA, N) shuffles DATA and splits it into a training
%   part and a test part. It then splits the training part again into a
%   training subset and a validation subset.
%
%   For every candidate box constraint C, it trains N binary
%   "class k vs. all" SVMs with SVMlight on the training subset and scores
%   them on the validation subset. It keeps the C with the highest
%   validation accuracy (as computed by WinnerTakesAll). With that C it
%   retrains on the full training part and evaluates on the test part.
%
%   ACCURACY = OVASVM(DATA, N, C) uses the vector C as the grid of
%   candidate box-constraint values instead of the default 2.^(-10:10).
%
%   Inputs:
%     DATA - sample matrix. The matrices returned by invertData are
%            treated as having the target in their last column.
%     N    - number of classes; class labels are taken to be 1:N.
%     C    - (optional) vector of candidate C values.
%
%   Output:
%     ACCURACY - test-set value returned by WinnerTakesAll for the
%                selected C.
%
%   Side effects:
%     - Writes the files 'SVMLTrain', 'SVMLVal', 'SVMLTest',
%       'Model<k>VsAll' and 'ModelOutput<k>VsAll' to the current folder.
%     - Displays the selected C and the test accuracy.

if nargin < 3 || isempty(C)
    C = 2 .^ (-10:10);   % same grid as the explicit list 2^-10 ... 2^10
end

labels = 1:n;

data = shuffleData(data);
[train, test] = splitData(data);
[trainSet, valSet] = splitData(train);
% NOTE(review): assumes scaleData fits its scaling on the first argument
% and applies it to both outputs - confirm against scaleData.
[trainSet, valSet] = scaleData(trainSet, valSet);

% --- Model selection: pick C by validation accuracy ---------------------
valAccuracy = zeros(1, numel(C));
for i = 1:numel(C)
    options = svmlopt('C', C(i), 'Verbosity', 0);
    scores = oneVsAllScores(options, trainSet, valSet, labels, 'SVMLVal');
    valAccuracy(i) = WinnerTakesAll(valSet, scores, labels);
end
[~, best] = max(valAccuracy);   % max returns the first maximum on ties
cOpt = C(best);
display(cOpt);

% --- Testing using optimal C value ---------------------------------------
% Scale train/test the same way as the data that was used to choose C,
% so the selected C is applied to features on the same scale.
[train, test] = scaleData(train, test);
options = svmlopt('C', cOpt, 'Verbosity', 0);
scores = oneVsAllScores(options, train, test, labels, 'SVMLTest');
accuracy = WinnerTakesAll(test, scores, labels);
display(accuracy);
end

function scores = oneVsAllScores(options, trainData, evalData, labels, evalFile)
%ONEVSALLSCORES Train one binary SVM per label and score EVALDATA with each.
%   For each k, SCORES(:, k) is the svmlread output of the
%   "labels(k) vs. all" model (trained on TRAINDATA) classifying EVALDATA.
%   EVALFILE is the SVMlight file name used for the evaluation samples.
scores = [];
for k = 1:numel(labels)
    tag = int2str(labels(k));
    modelFile  = ['Model', tag, 'VsAll'];
    % Must differ from modelFile; otherwise svm_classify's predictions
    % overwrite the trained model file.
    outputFile = ['ModelOutput', tag, 'VsAll'];

    [x, y] = splitTarget(invertData(trainData, labels(k)));
    svmlwrite('SVMLTrain', x, y);
    svm_learn(options, 'SVMLTrain', modelFile);

    [xEval, yEval] = splitTarget(invertData(evalData, labels(k)));
    svmlwrite(evalFile, xEval, yEval);
    svm_classify(options, evalFile, modelFile, outputFile);

    scores = [scores, svmlread(outputFile)]; %#ok<AGROW>
end
end

function [x, y] = splitTarget(m)
%SPLITTARGET Return the last column of M as Y and the remaining columns as X.
y = m(:, end);
x = m(:, 1:end-1);
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment