Skip to content

Instantly share code, notes, and snippets.

@LaurentBerger
Last active February 17, 2017 14:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save LaurentBerger/0e08ad4ecfe1e0df5f4109095adb478f to your computer and use it in GitHub Desktop.
Save LaurentBerger/0e08ad4ecfe1e0df5f4109095adb478f to your computer and use it in GitHub Desktop.
#include <ctype.h>
#include <cfloat>
#include <fstream>
#include <iostream>
#include "opencv2/opencv.hpp"
using namespace cv;
using namespace std;
int main(int argc, char* argv[])
{
double varX=2,varY=4;
float scale=10;
int nbData=400;
Mat h=(Mat_<float>(2,3)<<scale,0,7*varX*scale,0,-scale, 7 * varY*scale);
Mat groups;
Mat samples(nbData, 2, CV_32F);
RNG r;
Mat img(int(2*h.at<float>(1,2)),int(2*h.at<float>(0, 2) ),CV_8UC3,Scalar::all(0));
Rect rImg(0,0,img.cols,img.rows);
int nbClasse=1;
for (int i = 0; i < nbData; ++i)
{
int ind=0;
if (i < nbData )
{
samples.at<float>(i, 0) = r.gaussian(varX);
samples.at<float>(i, 1) = r.gaussian(varY);
ind=0;
}
else
{
samples.at<float>(i, 0) = r.gaussian(varX/2)+5;
samples.at<float>(i, 1) = r.gaussian(varY/2)+5;
ind=1;
nbClasse=2;
}
groups.push_back(ind);
Mat ph(3,1,CV_32F,Scalar(1));
ph(Range(0,2),Range(0,1))=samples.row(i).t();
Mat p=h*ph;
if (rImg.contains(Point(p)))
if (ind)
circle(img, Point(p.at<float>(0, 0), p.at<float>(1, 0)), 2, Scalar(0, 0, 255), 2);
else
circle(img, Point(p.at<float>(0, 0), p.at<float>(1, 0)), 2, Scalar(255, 0, 255), 2);
}
cout << samples << endl;
cout << groups << endl;
Ptr<ml::SVM> classifierSVM = ml::SVM::create();
if (nbClasse==2)
classifierSVM->setType(ml::SVM::C_SVC);
else
classifierSVM->setType(ml::SVM::ONE_CLASS);
classifierSVM->setKernel(ml::SVM::LINEAR);
classifierSVM->setDegree(3);
classifierSVM->setGamma(1);
classifierSVM->setCoef0(0);
classifierSVM->setC(1);
classifierSVM->setNu(0.1);
classifierSVM->setP(0);
classifierSVM->setTermCriteria(cvTermCriteria(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS,500, FLT_EPSILON));
classifierSVM->train(samples, ml::ROW_SAMPLE, groups);
classifierSVM->save("trainingData.yml");
for (int i = 0; i < 20; ++i) {
Mat test(1, 2, CV_32F);
if (i%2==0)
{
test.at<float>(0, 0) = float(i-10);
test.at<float>(0, 1) = float(i-10);
}
else
test = samples.row(i);
float result = classifierSVM->predict(test);
cout << test << ": class " << result << endl;
Mat ph(3, 1, CV_32F, Scalar(1));
ph(Range(0, 2), Range(0, 1)) = test.t();
Mat p = h*ph;
if (rImg.contains(Point(p)))
{
if (result==0)
rectangle(img, Rect(p.at<float>(0, 0) - 1, p.at<float>(1, 0) - 1, 3, 3), Scalar(0, 255, 0), 2);
else
rectangle(img, Rect(p.at<float>(0, 0) - 1, p.at<float>(1,0) - 1, 3, 3), Scalar(255, 255, 255), 2);
}
}
imshow("SVM",img);
waitKey(0);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment