Skip to content

Instantly share code, notes, and snippets.

@nicococo
Created May 4, 2012 15:37
Show Gist options
  • Save nicococo/2595569 to your computer and use it in GitHub Desktop.
Save nicococo/2595569 to your computer and use it in GitHub Desktop.
Shogun Structured Output Toolbox Interface
class CVanillaStructuredOutputMachine : public CMachine {
// Constructor, Destructor
CStructuredOutputMachine(
CStructuredData* trainset,CLoss* loss,CModel* model)
virtual ~CStructuredOutputMachine
// heritable data
CStructuredData* trainset
CLoss* loss
CModel* model
// nice to have \dots
// simple solution: deliver zeros or rand
virtual vector presolver() {
return zeros(trainset->get_dimensionality(),1)
}
// application specific methods
virtual void init_op()
virtual CResultSet compute_argmax(int index,vector w);
// vanilla SO-SVM training
void train() {
init_op
// assume diagonal regularization matrix with just one value
lambda = C(1,1)
w = presolver
List results = empty
repeat
// amount of constraints generated so far
lens = length(results)
for (i=0;i<trainset->get_size);i++) {
res = compute_argmax(i,w)
if (res->delta+ w*res->phi_pred >
max_j(results_i(j)->delta+w*results_i(j)->phi_pred) {
results_i->add(res)
}
}
// Solve QP
until lens&=&length(results)
}
};
/**
 * Result of one loss-augmented argmax call for a single training example.
 *
 * NOTE(review): `vector` is this sketch's placeholder dense-vector type.
 */
struct CResultSet {
    // joint feature vector for the example (presumably with its true label)
    vector phi_example;
    // joint feature vector for the predicted label
    vector phi_pred;
    // structured loss of the prediction (presumably Delta(y_true, y_pred))
    double delta;
};
/**
 * Class containing the structured data, e.g. sequences, trees, ...
 *
 * NOTE(review): `generic` is this sketch's placeholder for the
 * application-specific example/label type.
 */
class CStructuredData {
public:
    // Polymorphic base (declares virtual accessors), so deletion through a
    // base pointer needs a virtual destructor.
    virtual ~CStructuredData() {}

    /** Number of examples; the training loop iterates i over [0, get_size()). */
    int get_size();
    /** Dimensionality used to size the initial weight vector. */
    int get_dimensionality();
    /** Example i. */
    virtual generic get_example(int i);
    /** Label of example i. */
    virtual generic get_label(int i);
};
// CDeltaLoss is defined further below; it is forward-declared here so the
// return type of get_delta_func() is known.
class CDeltaLoss;

/**
 * Contains the application-specific model, e.g. the state model for an
 * HMM-SVM.
 */
class CModel {
public:
    /** Application-specific structured loss used with this model. */
    CDeltaLoss* get_delta_func();
};
/**
 * Scalar loss: calc() maps a real value to a real value.
 * The flags describe its analytic properties.
 */
class CLoss {
public:
    bool is_smooth;
    bool is_convex;
    // Open question: is the loss always non-negative, i.e. R -> R_{>=0}?
    bool is_positive;
    /** Evaluates the loss at x. */
    double calc(double x);
    // TODO: derivatives are still missing from this interface.
};
/**
 * Application-specific structured loss between a true label and a
 * predicted label.
 *
 * NOTE(review): `generic` is this sketch's placeholder label type.
 */
class CDeltaLoss {
public:
    /** Loss of predicting y_pred when the true label is y_sol. */
    double calc(generic y_sol, generic y_pred);
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment