Skip to content

Instantly share code, notes, and snippets.

/iris.cr Secret

Created November 24, 2017 17:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/6acbb9787167f72fb2ea2b5ff17a30a9 to your computer and use it in GitHub Desktop.
Save anonymous/6acbb9787167f72fb2ea2b5ff17a30a9 to your computer and use it in GitHub Desktop.
require "csv"
require "../crystal-fann"
# https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
# Drop blank/short rows (e.g. a trailing empty line) so every remaining row
# has 4 numeric features followed by a class label in column 4.
data = CSV.parse(File.read("./iris.data")).reject { |row| row.size < 5 }
puts "Size: #{data.size}" # => 150
data.shuffle!

# One-hot target vector for each class label found in column 4.
classes = {
  "Iris-setosa"     => [1.0, 0.0, 0.0],
  "Iris-versicolor" => [0.0, 1.0, 0.0],
  "Iris-virginica"  => [0.0, 0.0, 1.0],
}

# Split train and test data: first 2/3 for training, the rest for testing
# (100/50 for the full data set). Start+count and "from index to end" slices
# guarantee the two sets are disjoint — inclusive ranges sharing an endpoint
# would put that row in both sets.
train_size = data.size * 2 // 3
train, test = data[0, train_size], data[train_size..-1]

# 4 inputs, one hidden layer of 10 neurons, 3 outputs (one per class).
ann = Fann::Network::Standard.new(4, [10], 3)

# Columns 0-3 are the numeric features; column 4 is the label.
train_X = train.map { |row| row[0, 4].map(&.to_f64) }
train_y = train.map { |row| classes[row[4]] }
test_X = test.map { |row| row[0, 4].map(&.to_f64) }
test_y = test.map { |row| classes[row[4]] }

# 10_000 epochs of per-sample (online) training over the training set.
10_000.times do
  train_X.each_with_index do |features, i|
    ann.train_single(features, train_y[i])
  end
end
# Converts a vector of network outputs into a one-hot vector: 1 at the
# position of the largest output, 0 everywhere else.
#
# The returned vector has the same length as `y`, so this works for any
# number of output classes. Ties resolve to the first maximum, because
# `index` returns the first matching position.
#
# Returns an empty array for empty input instead of raising on `y.max`.
def result(y)
  return [] of Int32 if y.empty?
  max_index = y.index(y.max)
  Array.new(y.size) { |i| i == max_index ? 1 : 0 }
end
# Evaluate on the held-out test set. A prediction counts as correct when the
# one-hot form of the network output equals the expected target vector.
errors = 0
total = test_X.size
test_X.each_with_index do |features, i|
  if result(ann.run(features)) != test_y[i]
    errors += 1
  else
    # Echo the target vector of each correctly classified sample.
    puts test_y[i]
  end
end

# Round the final accuracy, not the error rate: `1 - rate.round(2)` can print
# float noise such as 0.9299999999999999.
accuracy = (1 - errors.to_f / total).round(2)
puts "result: #{accuracy} (#{errors}/#{total})"
ann.close
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment