%% Tutorial on Multi-class classification using structured output SVM
% This tutorial shows how multi-class classification can be cast and solved
% using a structured output SVM (introduced in [1]). The structured output SVM
% generalizes both binary SVM and SVM regression, as it allows one to predict
% _structured objects_, which can be seen as elements of a structured set.
% Formally, structured output learning is cast in the following way: given
% input vectors $\bf{x} \in \mathcal{X}$, learn a mapping to the response
% variables $\bf{y} \in \mathcal{Y}$. For binary classification we have
% \begin{equation}\mathcal{Y}_1=\{-1,1\},\end{equation} for regression
% \begin{equation}\mathcal{Y}_2=\mathbb{R},\end{equation} and for multi-class
% classification
% \begin{equation}\mathcal{Y}_3=\{1,\ldots,K\}.\end{equation}
% In fact, the response variable can be any structured object, such as a
% sequence or a tree [1]. The function to be learnt has the form
% $f:\mathcal{X} \rightarrow \mathcal{Y}$. We define the latter in terms of a
% _discriminant function_
% \begin{equation}F: \mathcal{X} \times \mathcal{Y} \rightarrow \mathbb{R},\end{equation}
% which can be seen as a compatibility function between a pattern
% $\bf{x}\in \mathcal{X}$ and a response $\bf{y} \in \mathcal{Y}$. A prediction
% is made by finding the most compatible response in the set $\mathcal{Y}$:
% \begin{equation} f(\mathbf{x};\mathbf{w})=\arg \max_{\mathbf{y} \in \mathcal{Y}}
% F(\mathbf{x},\mathbf{y};\mathbf{w}),\end{equation}
% where $\mathbf{w}$ is the vector of coefficients to be learnt. We also
% assume that $F$ is linear in a _combined feature representation_ of
% patterns and responses, that is,
% \begin{equation}
% F(\mathbf{x},\mathbf{y};\mathbf{w})=\langle\mathbf{w},\Psi(\mathbf{x},\mathbf{y})\rangle
% \end{equation}
% (the kernel trick can be applied to form a joint kernel function
% $K((\mathbf{x},\mathbf{y}),(\mathbf{x}',\mathbf{y}'))=\langle \Psi(\mathbf{x},\mathbf{y}),\Psi(\mathbf{x}',\mathbf{y}')\rangle$).
% Several variations of the max-margin problem exist; here we follow the
% slack-rescaling formulation from [1]. The parameters $\mathbf{w}$ are found
% by solving a convex quadratic optimization problem:
% \begin{align}
% \min_{\mathbf{w},\mathbf{\xi}} \quad & \frac{1}{2}\|\mathbf{w}\|^2 + \frac{C}{n} \sum_{i=1}^n \xi_i
% \\
% \text{s.t. } &\forall i, \forall \mathbf{y} \in \mathcal{Y}\setminus \{\mathbf{y}_i\}: \langle \mathbf{w}, \Psi(\mathbf{x}_i,\mathbf{y}_i)-\Psi(\mathbf{x}_i,\mathbf{y})\rangle\geq 1- \frac{\xi_i}{\Delta(\mathbf{y}_i,\mathbf{y})}
% \end{align}
% where
% \begin{equation}
% \Delta: \mathcal{Y} \times \mathcal{Y} \rightarrow \mathbb{R}
% \end{equation} is a _loss function_: $\Delta(y_1,y_2)$ is the loss
% incurred by predicting $y_1$ when the correct response is $y_2$.
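% With the $0$-$1$ loss used later in this tutorial, $\Delta(\mathbf{y}_i,\mathbf{y})=1$
% for every $\mathbf{y} \neq \mathbf{y}_i$, so the slack-rescaling constraints
% reduce to the familiar margin constraints
% \begin{equation}
% \langle \mathbf{w}, \Psi(\mathbf{x}_i,\mathbf{y}_i)-\Psi(\mathbf{x}_i,\mathbf{y})\rangle \geq 1 - \xi_i.
% \end{equation}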
%% Dependencies
% Author: Ivan Bogun
%
% Date: April 6, 2014
%
% Required:
%
% * Original SSVM solver code by Thorsten Joachims:
% <http://www.cs.cornell.edu/people/tj/svm_light/svm_struct.html ssvm-solver>
% * SSVM Matlab wrappers by Andrea Vedaldi:
% <http://www.robots.ox.ac.uk/~vedaldi/code/svm-struct-matlab.html matlab-wrappers>
%
% Optional:
%
% * Download source used to generate this tutorial:
% <https://gist.github.com/ibogun/10010417#file-multiclass-svm this file>
function tutorial_multiClass_ssvm
%% Generate data
% The following code generates a random dataset consisting of three
% separable classes.
tic;
% number of patterns
N=200;
% dimension
d=2;
% set seeds for reproducibility (legacy syntax; newer MATLAB uses rng(0))
randn('state',0) ;
rand('state',0) ;
training_data=randn(d,N)*10;
% class 1: lower half-plane (y < 0)
group1=training_data(2,:)<0;
negX_idx=training_data(1,:)<0;
% class 2: upper-left quadrant; class 3: upper-right quadrant
group2=~group1 & negX_idx;
group3=~group1 & ~negX_idx;
labels=zeros(1,N);
labels(group1)=1;
labels(group2)=2;
labels(group3)=3;
% get the minimum and maximum data values for plotting
xmin=round(min(training_data(1,:)));
xmax=round(max(training_data(1,:)));
ymin=round(min(training_data(2,:)));
ymax=round(max(training_data(2,:)));
patterns=num2cell(training_data,1);
labels=num2cell(labels);
%plot the data
figure;
scatter(training_data(1,group1),training_data(2,group1),8,'filled','b');
hold on;
scatter(training_data(1,group2),training_data(2,group2),8,'filled','g');
hold on;
scatter(training_data(1,group3),training_data(2,group3),8,'filled','r');
title('Training data');
hold off;
legend('class 1','class 2','class 3');
snapnow;
%% Feature function for SSVM
% The joint feature map is $\Psi: \mathcal{X} \times \mathcal{Y} \rightarrow \mathbb{R}^{dK}$.
% For binary SVM the feature map is $\Psi_1(x,y)=xy/2$. We generalize it to
% multi-class classification using the following transformation:
% \begin{equation}
% \Psi(\mathbf{x},y)=\left[\frac{\mathbf{x}y_1}{2},\cdots,\frac{\mathbf{x}y_K}{2}\right]
% \end{equation}
% where
% \[y_i = \left\{
% \begin{array}{lr}
% 1 & : y=i\\
% -1 & : y \neq i
% \end{array}
% \right.
% \]
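% For example, with $d=2$ and $K=3$ (as in this tutorial),
% $\Psi(\mathbf{x},2)=[-\mathbf{x}/2;\;\mathbf{x}/2;\;-\mathbf{x}/2]\in\mathbb{R}^6$.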
function psi = featureCB(param, x, y)
% sign pattern: +1 on the block of the true class, -1 elsewhere; named
% 'signs' rather than 'd' because nested functions share the parent
% workspace, and 'd' would clobber the data dimension defined above
signs=-1*ones(1,param.dimension);
% stack x once per class (K=3)
newX=[x;x;x];
% possible labels: y=1,2,3 and dimension of x is 2 -> flip the
% appropriate stack to +1
signs(2*y-1:2*y)=1;
psi = sparse(signs'.*newX/2) ;
if param.verbose
fprintf('psi([%8.3f,%8.3f], %3d) = [%8.3f, %8.3f, ...]\n', ...
x, y, full(psi(1)), full(psi(2))) ;
end
end
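% A quick, illustrative sanity check of the feature map (toyParam is a
% hypothetical stand-in for the parameters defined further below); for
% x=[1;2] and y=2 this prints [-0.5 -1.0 0.5 1.0 -0.5 -1.0]:
toyParam.dimension = 6 ;
toyParam.verbose = 0 ;
disp(full(featureCB(toyParam, [1;2], 2))') ;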
%% Constraint function for SSVM
% The optimization problem above may have exponentially many constraints,
% which makes solving it directly infeasible even for trivial cases. A
% solution is found by relaxing the problem (the cutting-plane algorithm):
% instead of trying to satisfy all constraints at once, on every iteration
% the most violated constraint is added and optimization continues.
% Practically, this requires a function that finds such a constraint. For
% the slack-rescaling variant we need to find
% \begin{equation} \hat{\mathbf{y}}=\arg \max_{\mathbf{y} \in \mathcal{Y}}
% \Delta(\mathbf{y}_i,\mathbf{y})(1+\langle\Psi(\mathbf{x}_i,\mathbf{y}),\mathbf{w}\rangle-\langle \Psi(\mathbf{x}_i,\mathbf{y}_i),\mathbf{w}\rangle),
% \end{equation}
% which is easy here since $|\mathcal{Y}|=3$.
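% Writing $H(\mathbf{y})$ for the quantity being maximized: with the $0$-$1$
% loss, $H(\mathbf{y}_i)=0$, while for $\mathbf{y}\neq\mathbf{y}_i$ we get
% $H(\mathbf{y})=1+\langle\Psi(\mathbf{x}_i,\mathbf{y}),\mathbf{w}\rangle-\langle\Psi(\mathbf{x}_i,\mathbf{y}_i),\mathbf{w}\rangle$,
% so among the incorrect labels the maximizer is simply the highest-scoring one.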
function yhat = constraintCB(param, model, x, y)
% slack rescaling: argmax_y delta(yi, y) (1 + <psi(x,y), w> - <psi(x,yi), w>)
% margin rescaling: argmax_y delta(yi, y) + <psi(x,y), w>
yhat=1;
H_y_hat_best=-Inf;
for j=1:3
H_y=lossCB(param,y,j)*(1-dot(featureCB(param, x, y)-...
featureCB(param, x, j),model.w));
if (H_y_hat_best<H_y)
H_y_hat_best=H_y;
yhat=j;
end
end
if param.verbose
fprintf('yhat = violslack([%8.3f,%8.3f], %3d) = %3d\n', ...
x, y, yhat) ;
end
end
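% An illustrative check with a hypothetical zero-weight model (reusing the
% toyParam defined above): with w=0, every wrong label has H(y)=1, so the
% first wrong label, here 1, is returned.
toyModel.w = zeros(6,1) ;
fprintf('most violated label for y=2: %d\n', ...
constraintCB(toyParam, toyModel, [1;2], 2)) ;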
%% Loss function
% For classification we use the $0$-$1$ loss function, defined as
% \[\Delta(y_i,y_j) = \left\{
% \begin{array}{lr}
% 1 & : y_i\neq y_j\\
% 0 & : y_i= y_j
% \end{array}
% \right.
% \]
% For example, $\Delta(1,2)=1$ and $\Delta(2,2)=0$.
function delta = lossCB(param, y, ybar)
delta = double(y ~= ybar) ;
if param.verbose
fprintf('delta = loss(%3d, %3d) = %f\n', y, ybar, delta) ;
end
end
%% Prediction function
% Define the discriminant function $F(x,y;w)=\langle w,\Psi(x,y) \rangle$ and
% use it to predict the response variable by maximizing over $y$:
% \begin{equation}f(x;w)=\arg \max_{y \in \mathcal{Y}} F(x,y;w)\end{equation}
function argmax_y=predictCB(param,model,x)
% evaluate F(x,y;w)=<w,psi(x,y)> for each of the three classes and
% return the maximizer
F=@(z) dot(model.w,featureCB(param,x,z));
best=-Inf;
argmax_y=1;
for j=1:3
score=F(j);
if (score>best)
best=score;
argmax_y=j;
end
end
end
%% Run SVM struct
% Set up the SSVM parameters and run the solver. The options
% | -c 10.0 -o 1 -v 1 |
% mean $C=10$, slack rescaling, and verbosity level 1.
parm.patterns = patterns ;
parm.labels = labels ;
parm.lossFn = @lossCB ;
parm.constraintFn = @constraintCB ;
parm.featureFn = @featureCB ;
parm.dimension = 6 ;
parm.verbose = 0 ;
model = svm_struct_learn(' -c 10.0 -o 1 -v 1 ', parm) ;
w = model.w ;
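% Illustrative check: classify a single point. (5,-5) lies in the lower
% half-plane, so a well-trained model should assign it class 1.
fprintf('predicted class of (5,-5): %d\n', predictCB(parm, model, [5;-5])) ;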
% evaluating the grid takes some time (set step to 1 if it takes too long)
step=0.1;
[x,y]=meshgrid(xmin:step:xmax,ymin:step:ymax);
[n,m]=size(x);
C=zeros(n,m);
patterns_test=num2cell([x(:)';y(:)'],1);
N=size(patterns_test,2);
labels_test=zeros(N,1);
for i=1:N
labels_test(i)=predictCB(parm,model,patterns_test{i});
end
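% Illustrative sanity check: training accuracy. The model has no bias term,
% so expect high, though not necessarily perfect, accuracy.
train_pred = cellfun(@(p) predictCB(parm, model, p), patterns) ;
fprintf('training accuracy: %.3f\n', mean(train_pred == cell2mat(labels))) ;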
%% Plot results
% The following code plots the separating hyperplane of each class, as well
% as how the space is partitioned by the learnt function $f$.
v=1000;
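% each segment traces the null line {x : <w_k,x> = 0} of class k's weight
% block: the direction (w(2k),-w(2k-1)) is orthogonal to the weight vector
% (w(2k-1),w(2k)), and v stretches the segment beyond the plotting window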
set(line(v*[w(2) -w(2)], v*[-w(1) w(1)]), ...
'color', 'b', 'linewidth', 1, 'linestyle', '-') ;
set(line(v*[w(4) -w(4)], v*[-w(3) w(3)]), ...
'color', 'g', 'linewidth', 1, 'linestyle', '-') ;
set(line(v*[w(6) -w(6)], v*[-w(5) w(5)]), ...
'color', 'r', 'linewidth', 1, 'linestyle', '-') ;
title('Training data with separating hyperplanes');
axis([xmin,xmax,ymin,ymax]);
g1=(labels_test==1);
g2=(labels_test==2);
g3=(labels_test==3);
c1=[patterns_test{g1}];
c2=[patterns_test{g2}];
c3=[patterns_test{g3}];
getConvHull=@(s) s(:,convhull(s(1,:),s(2,:)));
v1=getConvHull(c1);
v2=getConvHull(c2);
v3=getConvHull(c3);
figure;
fill(v1(1,:),v1(2,:),'b');
hold on;
fill(v2(1,:),v2(2,:),'g');
hold on;
fill(v3(1,:),v3(2,:),'r');
hold on;
axis([xmin,xmax,ymin,ymax]);
title('Space separated by discriminative function');
%% References
% [1] I. Tsochantaridis, T. Hofmann, T. Joachims, and Y. Altun. Support Vector Machine Learning for Interdependent and Structured Output Spaces. ICML, 2004.
toc;
end