Last active
December 17, 2015 08:39
-
-
Save benjamin-bader/5581393 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
open System | |
open System.IO | |
/// Location of the labelled training CSV that readData loads into `data`.
[<Literal>] let trainingDataFile = @"C:\Users\ben\Downloads\digitssample.csv"
/// One labelled digit sample.
/// ActualNumber is the label column; Pixels holds the remaining columns as floats.
type ImageData =
    { ActualNumber: byte
      Pixels: float[] }
/// Loads a digits CSV file into an array of ImageData records.
/// The first line is treated as a header and skipped. Each remaining non-blank line
/// is split on ',' and every field is parsed as a byte: the first field becomes
/// ActualNumber and the rest become Pixels (converted to float).
/// An empty file yields an empty array instead of throwing.
/// Raises FormatException/OverflowException (from Convert.ToByte) on a malformed field.
let readData f =
    let lines = File.ReadAllLines f
    lines
    // Skip the header row, but don't fail when the file has no lines at all.
    |> Seq.skip (min 1 lines.Length)
    // Trailing/blank lines would otherwise make Convert.ToByte "" throw.
    |> Seq.filter (fun line -> not (String.IsNullOrWhiteSpace line))
    |> Seq.map (fun line ->
        let fields = line.Split(',') |> Array.map Convert.ToByte
        { ActualNumber = fields.[0]; Pixels = fields.[1..] |> Array.map float })
    |> Array.ofSeq
/// Training set: every labelled image loaded from trainingDataFile.
let data = trainingDataFile |> readData
/// Squared Euclidean distance between two equal-length vectors: the sum of
/// (x - y)^2 over corresponding elements. No square root is taken; ranking by
/// this value gives the same nearest neighbour as the true Euclidean distance.
/// Raises ArgumentException when the arrays differ in length. Previously a longer
/// `xs` threw IndexOutOfRange, while a longer `ys` was silently truncated.
let distanceOf (xs: float[]) (ys: float[]) =
    if xs.Length <> ys.Length then
        invalidArg "ys" (sprintf "Vectors differ in length (%d vs %d)" xs.Length ys.Length)
    Array.fold2 (fun acc x y -> let d = x - y in acc + d * d) 0.0 xs ys
/// Distance between two images, computed by distanceOf on their pixel arrays.
let compareImages imageOne imageTwo =
    distanceOf imageOne.Pixels imageTwo.Pixels
/// Returns the element of `neighbors` with the smallest compareImages distance to `image`.
/// On ties, Array.minBy keeps the first minimal element; on an empty array it raises.
let findNearestNeighbor image neighbors =
    Array.minBy (compareImages image) neighbors
/// Predicts the label for a raw pixel array.
/// Returns the ActualNumber of its nearest image in the training set `data`.
let classify (unknown: float[]) =
    // Wrap the pixels in a record so compareImages can be reused; the label is a placeholder.
    let probe = { ActualNumber = 0uy; Pixels = unknown }
    let nearest = data |> findNearestNeighbor probe
    nearest.ActualNumber
/// Held-out images used to measure classification accuracy.
let testData = @"C:\Users\ben\Downloads\digitscheck.csv" |> readData

// Counters updated by the test tasks: images tested so far, and how many were classified correctly.
let totalRecords, correctRecords = ref 0, ref 0
/// Builds an async task that classifies one test image.
/// The task always increments totalRecords. It also increments correctRecords when
/// the prediction matches the image's ActualNumber.
/// Interlocked replaces the previous `lock testData`, which used a shared data array
/// as a monitor purely to protect two integer counters.
let testAnImage image = async {
    let predicted = classify image.Pixels
    System.Threading.Interlocked.Increment(&totalRecords.contents) |> ignore
    if predicted = image.ActualNumber then
        System.Threading.Interlocked.Increment(&correctRecords.contents) |> ignore
}
/// Times the classification of every test image in parallel and reports accuracy.
let sw = System.Diagnostics.Stopwatch.StartNew()

testData
|> Array.map testAnImage
|> Async.Parallel
|> Async.RunSynchronously
|> ignore

sw.Stop()

// The label says "percent", so scale the 0..1 ratio by 100.
// Guard against an empty test set, which previously printed NaN.
if totalRecords.Value = 0 then
    printfn "No test records found (%d ms)" sw.ElapsedMilliseconds
else
    let percentCorrect = 100.0 * float correctRecords.Value / float totalRecords.Value
    printfn "Percent correct: %f (%d ms)" percentCorrect sw.ElapsedMilliseconds
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment