// Shogun Structured Output Toolbox Interface
// (design sketch; source: GitHub gist nicococo/2634487)
class CVanillaStructuredOutputMachine : public CMachine { | |
// Constructor, Destructor | |
CStructuredOutputMachine( | |
CStructuredData* trainset,CLoss* loss,CStructuredApplication* model) | |
virtual ~CStructuredOutputMachine | |
// heritable data | |
CStructuredData* trainset | |
CLoss* loss | |
CStructuredApplication* model | |
// nice to have \dots | |
// simple solution: deliver zeros or rand | |
virtual vector presolver() { | |
return zeros(model->get_dimensionality(),1) | |
} | |
// vanilla SO-SVM training | |
void train() { | |
[A,a,B,b,lb,ub,C] = init_op | |
// assume diagonal regularization matrix with just one value | |
lambda = C(1,1) | |
w = presolver() | |
List results = empty | |
repeat | |
// amount of constraints generated so far | |
lens = length(results) | |
for (i=0;i<trainset->get_size);i++) { | |
res = model.compute_argmax(i,w) | |
slack = w*res.psi_pred + res.delta - w*res.psi_truth | |
slack = loss.calc(slack) | |
if (slack > max_j(loss.calc( | |
results_i(j).delta+w*results_i(j).psi_pred | |
- results_i(j).psi_truth)) { | |
results_i->add(res) | |
} | |
} | |
// Solve QP | |
until lens&=&length(results) | |
} | |
}; | |
struct CResultSet { | |
// joint feature vector for the given truth | |
vector psi_truth | |
// joint feature vector for the prediction | |
vector psi_pred | |
// corresponding score | |
double score | |
// delta loss for the prediction vs. truth | |
double delta | |
} | |
class CStructuredData { | |
// class containing the structured data, e.g. sequences, | |
// trees, \dots | |
int get_size() | |
int get_dimensionality() | |
virtual generic get_example(int i) | |
virtual generic get_label(int i) | |
} | |
class CStructuredApplication { | |
// Containing the application specific model. | |
// e.g. for HMM-SVM | |
// - state model | |
// - viterbi | |
// - delta loss | |
// - join feature vector representation | |
// Application specific loss. | |
double delta(generic y_sol,y_pred) | |
// latent loss | |
double delta(generic y_sol,y_pred,h_hat) | |
// init the optimization problem | |
// gives back A,a,B,b,lb,ub and C | |
virtual [A,a,B,b,lb,ub,C] init_op() | |
// what is the size of w | |
virtual int get_dimensionality(); | |
// compute most likely configuration given | |
// data index and solution vector w | |
virtual CResultSet compute_argmax(CStructuredData* data,int index,vector w) { | |
// call apply | |
} | |
// returns the most likely configurations and its corresponding score | |
virtual (generic,double) apply(CStructuredData* data,int index,vector w); | |
// private members | |
// get the psi-vector from example with index i | |
virtual vector get_joint_feature_representation(CStructuredData data,int i) | |
virtual vector get_joint_feature_representation(generic x, generic y) | |
// for latent model | |
virtual vector get_joint_feature_representation(generic x, generic y, generic h) | |
} | |
// Abstract scalar loss function used by the structured-output trainer.
class CLoss {
public:
    // Polymorphic base: deleting through a base pointer must be safe.
    virtual ~CLoss() {}

    // Properties of the concrete loss. Conservative defaults; concrete losses
    // set the ones that apply.
    bool is_smooth = false;
    bool is_convex = false;
    // loss: Re -> Re^+ ?
    bool is_positive = false;

    // Calculate the loss,
    // e.g. hinge-loss: calc(z) = max(0,z)
    virtual double calc(double z) = 0;

    // (Sub)gradient; makes use of the chain rule.
    // e.g. linear classifier z = w'x
    //   dLoss/dw = [dLoss/dz] [dz/dw]
    // The first term is calculated by grad(); the second term is
    // classifier-specific and has to be calculated outside.
    virtual double grad(double z) = 0;

    // Second derivative (if any).
    virtual double hesse(double z) = 0;
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment