Created June 8, 2011 22:03
-
-
Save serialhex/1015544 to your computer and use it in GitHub Desktop.
Shogun LibSVM minimal example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env ruby | |
## copy from ../examples/undocumented/python_modular/classifier_libsvm_minimal_modular.py | |
require '../../../src/ruby_modular/Features' | |
require '../../../src/ruby_modular/Classifier' | |
require '../../../src/ruby_modular/Kernel' | |
# for randn func | |
require '../../../src/ruby_modular/Library' | |
# interactive console for playin wit stuff | |
require 'pry' | |
include Features | |
include Classifier | |
include Kernel | |
# Experiment parameters. The @-variables live on the top-level `main` object
# and are read by the helper methods defined below.
#   @num   - number of samples generated per label (-1 and +1)
#   @dist  - offset added to the random samples of each cluster
#   @width - width argument passed to GaussianKernel.new
#   C      - the SVM C parameter, first argument to LibSVM.new
@num, @dist, @width = 5, 1, 2.1
C = 1
# helper methods for all this fun stuff

# Build the label vector for a data set from gen_rand_ary:
# @num copies of -1 followed by @num copies of +1.
#
# @return [Array<Integer>] 2 * @num labels
def gen_ones_vec
  negative_half = @num.times.map { -1 }
  positive_half = @num.times.map { 1 }
  negative_half + positive_half
end
# One random draw from Shogun's Library::Math.randn_double.
# NOTE(review): presumably a standard-normal sample, as numpy's randn in the
# Python original — confirm against the Shogun Math documentation.
def randn
  Library::Math.randn_double()
end
# Build a feature matrix with 2 rows and 2 * @num columns.
# In each row, the first @num entries are random samples shifted by -@dist
# and the last @num by +@dist (see ary_fill). This mirrors the Python original
# `concatenate((randn(2,num)-dist, randn(2,num)+dist), axis=1)`, so the -dist
# cluster lines up with the -1 labels that gen_ones_vec puts first.
# Each row is drawn independently.
#
# @return [Array<Array>] two rows of 2 * @num samples each
def gen_rand_ary
  Array.new(2) { ary_fill(-@dist) + ary_fill(@dist) }
end
# Draw @num values from randn, adding `dist` to each one.
#
# @param dist [Numeric] offset added to every sample
# @return [Array] @num shifted samples, in draw order
def ary_fill(dist)
  @num.times.map { randn + dist }
end
# (An identity placeholder `mean` used to be defined here. It was dead code,
# always redefined by the real `mean` implementation later in this script.)
Numeric.class_eval do
  # Sign of the receiver: -1 if negative, 0 if equal to zero, 1 if positive.
  # When none of those comparisons holds (e.g. Float::NAN), it yields nil.
  def sign
    if self < 0
      -1
    elsif self == 0
      0
    elsif self > 0
      1
    end
  end
end
Array.class_eval do
  # Element-wise sign: calls #sign on every element (Numeric#sign is defined
  # earlier in this script) and returns the results as a new array.
  # The previous version built the result in a local but returned the value
  # of `each`, i.e. the receiver itself.
  #
  # @return [Array] signs of the elements, in order
  def sign
    map(&:sign)
  end

  # Element-wise equality against another array of the same length.
  # NOTE: despite the trailing `?`, this returns an Array, not a boolean; the
  # name is kept so existing callers keep working.
  #
  # @param other [Array] array to compare against
  # @return [Array<Boolean>] true at index i when self[i] == other[i]
  # @raise [ArgumentError] if other is not an Array or the lengths differ
  def eql_items?(other)
    raise ArgumentError, "Argument is not an Array" unless other.kind_of?(Array)
    raise ArgumentError, "Arrays don't have the same number of elements" if size != other.size

    zip(other).map { |mine, theirs| mine == theirs }
  end
end
# Arithmetic mean of an array of numbers and/or booleans.
# true counts as 1 and false as 0, so the mean of an element-wise comparison
# (e.g. the result of Array#eql_items?) is the fraction of true entries.
#
# The previous version had two bugs:
# - `inject` had no seed, so a leading boolean became the running sum.
# - The bare `next` in the block returned nil, discarding the running sum
#   for every true/false element.
#
# @param ary [Array<Numeric, Boolean>]
# @return [Float] the mean; Float::NAN for an empty array (0.0 / 0)
def mean(ary)
  return Float::NAN if ary.empty?

  total = ary.inject(0) do |sum, item|
    value = case item
            when true then 1
            when false then 0
            else item
            end
    sum + value
  end
  total.to_f / ary.size
end
# the actual example
# Train a LibSVM with a Gaussian kernel on two shifted random clusters, then
# report the error rate on a freshly generated test set.

puts "generating training data"
traindata_real = gen_rand_ary
testdata_real = gen_rand_ary

puts "generating labels"
trainlab = gen_ones_vec
testlab = gen_ones_vec

puts "doing feature stuff"
feats_train = RealFeatures.new
feats_train.set_feature_matrix(traindata_real)
feats_test = RealFeatures.new
feats_test.set_feature_matrix(testdata_real)
kernel = GaussianKernel.new(feats_train, feats_train, @width)

puts "labeling stuff"
labels = Labels.new
labels.set_labels(trainlab)
svm = LibSVM.new(C, kernel, labels)
svm.train

puts "the grand finale"
# Re-init the kernel with (train, test) features, as the Python original
# does; presumably this makes svm.apply score the test set — TODO confirm.
kernel.init(feats_train, feats_test)
out = svm.apply.get_labels

# NOTE(review): presumably `out` holds real-valued SVM outputs; as in the
# Python original (`mean(sign(out) != testlab)`), take the sign before
# comparing with the -1/+1 labels.
# Error rate = fraction of test points whose predicted sign differs from the
# true label. (Previously this printed the fraction of raw outputs *equal*
# to the labels, i.e. an exact-match accuracy, under the name testerr.)
mismatches = out.map(&:sign).zip(testlab).map { |pred, truth| pred == truth ? 0 : 1 }
testerr = mean(mismatches)
puts testerr
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.