Skip to content

Instantly share code, notes, and snippets.

@deepredsky
Created December 30, 2016 17:05
Show Gist options
  • Save deepredsky/955ba3859f873ae8e38d7d32a9c2fa52 to your computer and use it in GitHub Desktop.
Save deepredsky/955ba3859f873ae8e38d7d32a9c2fa52 to your computer and use it in GitHub Desktop.
defmodule KNN do
  @moduledoc """
  A 1-nearest-neighbour classifier over CSV files.

  Each CSV file starts with a header row, which is skipped. Every following
  row is `label,feature1,feature2,...`. The label is kept as a string, the
  features are parsed as numbers, and samples are compared with the squared
  Euclidean distance.
  """

  defmodule LabeLWithFeatures do
    @moduledoc """
    One parsed CSV row: the raw `label` string and its numeric `pixels`.
    """
    # NOTE(review): "LabeL" has a stray capital L. The name is kept as-is
    # because renaming the struct would break any external references to it.
    defstruct label: [], pixels: []
  end

  @doc """
  Reads `file` and returns one `%KNN.LabeLWithFeatures{}` per data row.

  The first line (the header) is skipped. Every line is trimmed, so both
  Unix and Windows line endings work. Blank lines are ignored, including the
  empty string left after a trailing newline. Such a line would otherwise
  become a row with no pixels, and that row has distance 0 to every sample.

  Raises `File.Error` if the file cannot be read.
  """
  @spec slurp_file(Path.t()) :: [%LabeLWithFeatures{}]
  def slurp_file(file) do
    file
    |> File.stream!()
    |> Stream.drop(1)
    |> Stream.map(&String.trim/1)
    |> Stream.reject(&(&1 == ""))
    |> Stream.map(fn line -> line |> String.split(",") |> csv_row_to_label() end)
    |> Enum.to_list()
  end

  @doc """
  Builds a `%KNN.LabeLWithFeatures{}` from an already-split CSV row.

  The first field is kept verbatim as the label. The remaining fields are
  parsed with `parse_binary_to_float/1`.
  """
  @spec csv_row_to_label([String.t(), ...]) :: %LabeLWithFeatures{}
  def csv_row_to_label([label | pixels]) do
    %LabeLWithFeatures{label: label, pixels: Enum.map(pixels, &parse_binary_to_float/1)}
  end

  @doc """
  Parses one numeric CSV field.

  Plain integer fields such as `"255"` are returned as integers, as before.
  Fields with a fractional part such as `"0.5"` are returned as floats
  instead of being silently truncated. Surrounding whitespace is ignored.

  Raises `MatchError` if the field is not a number.
  """
  @spec parse_binary_to_float(String.t()) :: number()
  def parse_binary_to_float(binary) do
    trimmed = String.trim(binary)

    case Integer.parse(trimmed) do
      {integer, ""} ->
        integer

      _ ->
        # Not a plain integer (e.g. "0.5"), so the whole field must parse as a float.
        {float, ""} = Float.parse(trimmed)
        float
    end
  end

  @doc """
  Returns the squared Euclidean distance between two feature vectors.

  Features are compared position by position: the result is the sum of
  `(x_i - y_i)^2`, which is 0 for empty input. If the vectors differ in
  length, the longer one's extra trailing features are ignored
  (`Enum.zip/2` semantics).
  """
  @spec distance([number()], [number()]) :: number()
  def distance(x_pixels, y_pixels) do
    x_pixels
    |> Enum.zip(y_pixels)
    |> Enum.reduce(0, fn {x, y}, acc -> acc + :math.pow(x - y, 2) end)
  end

  @doc """
  Returns the label of the training sample nearest to `pixels`.

  Ties go to the earliest sample in `training_labels` (`Enum.min_by/2`
  semantics). Raises `Enum.EmptyError` when `training_labels` is empty.
  """
  @spec classify([%LabeLWithFeatures{}, ...], [number()]) :: term()
  def classify(training_labels, pixels) do
    nearest = Enum.min_by(training_labels, &distance(&1.pixels, pixels))
    nearest.label
  end

  @doc """
  Classifies every row of `validation_file` against `training_file` and
  prints the percentage that were classified correctly.

  Both paths default to the original hard-coded file names, so `main/0`
  behaves as before. If the validation file has no data rows, a message is
  printed instead of dividing by zero.
  """
  @spec main(Path.t(), Path.t()) :: :ok
  def main(training_file \\ "trainingsample.csv", validation_file \\ "validationsample.csv") do
    training = slurp_file(training_file)
    validation = slurp_file(validation_file)

    case validation do
      [] ->
        IO.puts("No validation samples found in #{validation_file}")

      _ ->
        correct = Enum.count(validation, &(classify(training, &1.pixels) == &1.label))
        IO.puts("Percentage correct: #{correct / length(validation) * 100}")
    end
  end
end
# Script entry point: loads "trainingsample.csv" and "validationsample.csv"
# (relative to the current working directory), classifies each validation
# row by its nearest training row, and prints the percentage correct.
KNN.main()
# Debug helpers (uncomment as needed):
# IO.inspect KNN.slurp_file("trainingsample.csv") |> hd()
# IO.puts KNN.classify(KNN.slurp_file("trainingsample.csv"), hd(KNN.slurp_file("validationsample.csv")).pixels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment