Skip to content

Instantly share code, notes, and snippets.

@spaghetti-source
Created May 26, 2013 03:51
Show Gist options
  • Save spaghetti-source/5651656 to your computer and use it in GitHub Desktop.
Save spaghetti-source/5651656 to your computer and use it in GitHub Desktop.
Support Vector Machine (gradient descent with log-sum-exp approximation)
//
// Support Vector Machine
// (gradient descent with log-sum-exp approximation)
//
// original:
// G(w) := min_{+,-} { <w, x_+> - <w,x_-> } --> maximize
// approx:
// G'(w) := -log sum_{+,-} exp -(<w, x_+> - <w, x_->) --> maximize
// <=>
// L(w) := sum_{+,-} exp -<w, x_+ - x_-> --> minimize
//
// algorithm:
// w = w - eps * dL(w), then w = w / |w|
// where
// dL(w) = -sum_{+,-} (x_+ - x_-) exp -<w, x_+ - x_->
//
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <map>
#include <vector>
#include <cstring>
#include <functional>
#include <algorithm>
using namespace std;
// Shorthand macros in the competitive-programming style.
// ALL(c): expands to the begin/end iterator pair of container c.
#define ALL(c) c.begin(), c.end()
// FOR(i,c): walks iterator i over container c. Relies on `typeof`, which is a
// compiler extension rather than standard C++.
#define FOR(i,c) for(typeof(c.begin())i=c.begin();i!=c.end();++i)
// REP(i,n): runs an int counter i over 0 .. n-1. The bound n is re-evaluated
// on every iteration and is compared against a signed int.
#define REP(i,n) for(int i=0;i<n;++i)
// fst / snd: aliases for the member names `first` / `second`.
#define fst first
#define snd second
// Draws an approximately standard-normal sample: twelve uniform [0, 1]
// variates are summed (mean 6, variance 1) and the mean is subtracted, so the
// result always lies in [-6, 6]. Consumes twelve values from the global
// rand() stream per call, so output is reproducible under srand().
double randn() {
  double total = 0;
  int remaining = 12;
  while (remaining-- > 0) {
    total += static_cast<double>(std::rand()) / RAND_MAX;
  }
  return total - 6;
}
// Number of components in every sample vector and in the weight vector.
int DIM = 100;
// Number of random samples drawn when building the dataset; main() also
// divides the gradient step size by this value.
int SIZE = 200;
// Dense real-valued vector.
typedef vector<double> Vector;
int main() {
// dataset
Vector v(DIM);
REP(i, DIM) v[i] = randn();
vector< Vector > xp, xn;
REP(i, SIZE) {
Vector x(DIM);
REP(i, DIM) x[i] = randn();
double dot = 0;
REP(i, DIM) dot += x[i] * v[i];
if (dot > 0) xp.push_back( x );
else if (dot <-0) xn.push_back( x );
}
// learning
Vector w(DIM);
REP(i, DIM) w[i] = randn();
REP(epoch, 100) {
Vector dw(DIM);
REP(p, xp.size()) REP(q, xn.size()) {
double dot = 0;
REP(i, DIM) dot += (xp[p][i] - xn[q][i]) * w[i];
REP(i, DIM) dw[i] -= (xp[p][i] - xn[q][i]) * exp(-dot);
}
double norm = 0;
double eps = 0.001/SIZE;
REP(i, DIM) {
w[i] -= eps * dw[i];
norm = norm + w[i] * w[i];
}
norm = sqrt(norm);
REP(i, DIM) {
w[i] /= norm;
}
// verify
int miss = 0;
REP(p, xp.size()) {
double dot = 0;
REP(i, DIM) dot += xp[p][i] * w[i];
if (dot < 0) ++miss;
}
REP(p, xn.size()) {
double dot = 0;
REP(i, DIM) dot += xn[p][i] * w[i];
if (dot > 0) ++miss;
}
printf(" miss: %5d\n", miss);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment