Created May 4, 2012 15:37
Save nicococo/2595569 to your computer and use it in GitHub Desktop.
Shogun Structured Output Toolbox Interface
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CVanillaStructuredOutputMachine : public CMachine { | |
// Constructor, Destructor | |
CStructuredOutputMachine( | |
CStructuredData* trainset,CLoss* loss,CModel* model) | |
virtual ~CStructuredOutputMachine | |
// heritable data | |
CStructuredData* trainset | |
CLoss* loss | |
CModel* model | |
// nice to have \dots | |
// simple solution: deliver zeros or rand | |
virtual vector presolver() { | |
return zeros(trainset->get_dimensionality(),1) | |
} | |
// application specific methods | |
virtual void init_op() | |
virtual CResultSet compute_argmax(int index,vector w); | |
// vanilla SO-SVM training | |
void train() { | |
init_op | |
// assume diagonal regularization matrix with just one value | |
lambda = C(1,1) | |
w = presolver | |
List results = empty | |
repeat | |
// amount of constraints generated so far | |
lens = length(results) | |
for (i=0;i<trainset->get_size);i++) { | |
res = compute_argmax(i,w) | |
if (res->delta+ w*res->phi_pred > | |
max_j(results_i(j)->delta+w*results_i(j)->phi_pred) { | |
results_i->add(res) | |
} | |
} | |
// Solve QP | |
until lens&=&length(results) | |
} | |
}; | |
struct CResultSet { | |
vector phi_example | |
vector phi_pred | |
double delta | |
} | |
class CStructuredData { | |
// class containing the structured data, e.g. sequences, | |
// trees, \dots | |
int get_size() | |
int get_dimensionality() | |
virtual generic get_example(int i) | |
virtual generic get_label(int i) | |
} | |
class CDeltaLoss;  // defined further down in this file

/**
 * @brief Contains the application-specific model, e.g. the state model of an
 *        HMM-SVM.
 */
class CModel {
public:
    virtual ~CModel() {}

    /// Application-specific task loss (Delta) for this model.
    /// NOTE(review): ownership of the returned object is not specified — TODO document.
    virtual CDeltaLoss* get_delta_func() = 0;
};
/**
 * @brief Scalar surrogate loss, plus flags describing its properties.
 *
 * All flags are false until a subclass sets them.
 */
class CLoss {
public:
    CLoss() : is_smooth(false), is_convex(false), is_positive(false) {}
    virtual ~CLoss() {}

    bool is_smooth;
    bool is_convex;
    // Is the loss a map R -> R_+ (non-negative everywhere)?
    bool is_positive;

    /// Evaluate the loss at @p x.
    virtual double calc(double x) = 0;

    // TODO: derivatives are not part of the interface yet.
};
class CDeltaLoss { | |
// Application specific loss. | |
double calc(generic y_sol,y_pred) | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.