@lambday
Created March 24, 2017 05:14
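import numpy as np
import shogun as sg

# Load the sparse training/test features and the training labels from disk
train_feats = sg.RealSparseFeatures('../train.dat')
test_feats = sg.RealSparseFeatures('../test.dat')
train_labels = sg.RealDenseLabels('../train_labels.dat')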
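# Candidate values of the LDA gamma parameter for the grid search (1 through 9)
gammas = np.arange(1, 10)
lda = sg.LDA()
# Preprocessors attached to the machine, applied in order:
# dimension subset, unit-norm scaling, Fisher LDA projection
lda.preprocessor_chain().enque(sg.DimensionSubset()).enque(sg.NormOne()).enque(sg.FisherLDA())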
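# A minimal NumPy sketch of what the NormOne and FisherLDA preprocessors are
# assumed to compute, for two classes labelled -1/+1: NormOne rescales each
# example to unit Euclidean norm, and the two-class Fisher projection keeps the
# single direction w proportional to S_w^{-1} (mu_+ - mu_-). norm_one and
# fisher_direction are hypothetical helpers, not Shogun API; DimensionSubset
# is left out because its behaviour is not specified here.
import numpy as np


def norm_one(X):
    # Divide every row (one example) by its L2 norm; all-zero rows stay zero.
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / np.where(norms == 0, 1.0, norms)


def fisher_direction(X, y):
    # Class means and the within-class scatter matrix S_w.
    pos, neg = X[y == 1], X[y == -1]
    mu_p, mu_n = pos.mean(axis=0), neg.mean(axis=0)
    S_w = (pos - mu_p).T @ (pos - mu_p) + (neg - mu_n).T @ (neg - mu_n)
    # lstsq returns a (minimum-norm) least-squares solution even if S_w is singular.
    w = np.linalg.lstsq(S_w, mu_p - mu_n, rcond=None)[0]
    return w / np.linalg.norm(w)


# Usage (hypothetical X, y): projected = norm_one(X) @ fisher_direction(norm_one(X), y)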
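# Register the gamma candidates as 32-bit integer values for model selection
params = sg.ModelSelectionParameters()
params.set_values(gammas, sg.MSPT_INT32)
# Stratified cross-validation, run in parallel and scored by mean squared error
splitting_strategy = sg.StratifiedCrossValidationSplittingStrategy()
cv = sg.CrossValidation(lda, train_feats, train_labels, splitting_strategy, sg.MeanSquaredError()).parallel()
# Grid search evaluates every candidate and returns the one with the best cross-validation score
ms = sg.GridSearchModelSelection(cv, params)
param = ms.select_model().get_parameter()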
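# A minimal NumPy sketch of the grid search above, assuming binary labels in
# {-1, +1} and a shrinkage-regularized LDA in which gamma in (0, 1] blends the
# pooled within-class covariance with a scaled identity. The shrinkage formula,
# the misclassification-rate score (swapped in for MeanSquaredError; on hard
# +/-1 predictions the squared error is exactly 4x the error rate, so the
# ranking of candidates is the same) and all helper names are assumptions,
# not Shogun's implementation.
import numpy as np


def stratified_folds(y, k, seed=0):
    # Deal each class's shuffled indices round-robin into k folds so every
    # fold keeps roughly the original class proportions.
    rng = np.random.default_rng(seed)
    folds = [[] for _ in range(k)]
    for c in np.unique(y):
        for i, j in enumerate(rng.permutation(np.flatnonzero(y == c))):
            folds[i % k].append(j)
    return [np.array(sorted(f)) for f in folds]


def train_lda(X, y, gamma):
    # Class means and pooled within-class covariance.
    mu_p, mu_n = X[y == 1].mean(axis=0), X[y == -1].mean(axis=0)
    centered = np.vstack([X[y == 1] - mu_p, X[y == -1] - mu_n])
    S = centered.T @ centered / len(X)
    d = X.shape[1]
    # Shrink toward a scaled identity; gamma > 0 keeps S_reg invertible
    # whenever the features have nonzero variance.
    S_reg = (1 - gamma) * S + gamma * (np.trace(S) / d) * np.eye(d)
    w = np.linalg.solve(S_reg, mu_p - mu_n)
    b = -w @ (mu_p + mu_n) / 2  # threshold halfway between the projected means
    return w, b


def predict(w, b, X):
    # Binary decision rule: sign of the linear score.
    return np.where(X @ w + b >= 0, 1, -1)


def grid_search_gamma(X, y, gammas, k=5):
    # Return (best_gamma, mean_cv_error) over all candidates.
    folds = stratified_folds(y, k)
    best = None
    for g in gammas:
        errors = []
        for test_idx in folds:
            train_mask = np.ones(len(y), dtype=bool)
            train_mask[test_idx] = False
            w, b = train_lda(X[train_mask], y[train_mask], g)
            errors.append(np.mean(predict(w, b, X[test_idx]) != y[test_idx]))
        mean_err = float(np.mean(errors))
        if best is None or mean_err < best[1]:
            best = (g, mean_err)
    return best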
# Retrain on the full training set with the selected gamma, then classify the test set
lda.set_gamma(param)
lda.set_features(train_feats)
lda.set_labels(train_labels)
lda.train()
predicted_labels = lda.apply_binary(test_feats)