Created
June 24, 2017 02:43
-
-
Save maddenpj/553eb96b05fb3886eac8a277ffba9c52 to your computer and use it in GitHub Desktop.
OCR for Numbers using KNN (Haskell)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Data.List
import Data.Ord (comparing)
import System.IO
import Text.Read (readMaybe)
-- | One labelled sample: the digit it represents and its pixel values.
data Digit = Digit
  { actualDigit :: Int   -- ^ the label (first column of a CSV row)
  , pixels      :: [Int] -- ^ the remaining columns, in file order
  }
  deriving (Show)
-- | A classifier is just the list of reference digits new samples are
-- compared against.
type Classifier = [Digit]

-- | Reference digits grouped into sub-lists.
-- NOTE(review): not referenced by any code visible in this file.
type KNNClassifier = [[Digit]]
-- | Split a string on every character satisfying the predicate.
-- Runs of separators (including leading/trailing ones) never yield empty
-- fields, e.g. @wordsWhen (== ',') ",a,,b,"@ is @["a","b"]@.
wordsWhen :: (Char -> Bool) -> String -> [String]
wordsWhen isSep = go
  where
    go str = case break isSep (dropWhile isSep str) of
      ("", _)     -> []
      (word, rest) -> word : go rest
-- | Parse one comma-separated line into integers.
--
-- Empty fields are skipped (see 'wordsWhen'). A cell that is not an integer
-- aborts with an error naming the offending cell, rather than the
-- context-free @Prelude.read: no parse@.
lineToPixel :: String -> [Int]
lineToPixel line = map parseCell (wordsWhen (== ',') line)
  where
    parseCell cell = case readMaybe cell of
      Just value -> value
      Nothing    -> error ("lineToPixel: not an integer: " ++ show cell)
-- | Interpret a parsed row: the first value is the label, the rest are pixels.
--
-- An empty row has no label; it fails with an explicit message instead of a
-- non-exhaustive-pattern crash.
pixelsToDigit :: [Int] -> Digit
pixelsToDigit (actual : xs) = Digit { actualDigit = actual, pixels = xs }
pixelsToDigit []            = error "pixelsToDigit: empty row has no label"
-- | Parse one CSV row (label followed by pixel values) into a 'Digit'.
lineToDigit :: String -> Digit
lineToDigit line = pixelsToDigit (lineToPixel line)
-- | Integer midpoint of a pair, rounded toward negative infinity (via 'div').
avgFunc :: (Int, Int) -> Int
avgFunc = (`div` 2) . uncurry (+)
-- | Pixel-wise integer midpoint of two digits; the label is taken from the
-- first argument. Pixel lists are truncated to the shorter of the two.
averageDigits :: Digit -> Digit -> Digit
averageDigits acc x = acc { pixels = zipWith (curry avgFunc) (pixels acc) (pixels x) }

-- | Mean digit of a non-empty group. The label comes from the first element.
--
-- Each pixel is the floor of the arithmetic mean over the whole group. The
-- previous version folded 'averageDigits' pairwise, which halves the weight of
-- every earlier row at each step (the last row alone got weight 1/2). It also
-- rounded at every step. Pixel lists are truncated to the shortest row, as
-- with 'zip'.
averageGroup :: [Digit] -> Digit
averageGroup [] = error "averageGroup: cannot average an empty group"
averageGroup digits@(firstDigit : rest) =
  firstDigit { pixels = map (`div` count) totals }
  where
    count  = length digits
    totals = foldl' addRows (pixels firstDigit) (map pixels rest)
    -- Force every running sum so large groups don't build per-pixel thunks.
    addRows acc row = forceAll (zipWith (+) acc row)
    forceAll xs = foldr seq () xs `seq` xs
-- | Order digits by their label.
--
-- The previous hand-written guards were equivalent to 'compare', but GHC
-- flagged them as non-exhaustive under @-Wincomplete-patterns@.
sortFunc :: Digit -> Digit -> Ordering
sortFunc = comparing actualDigit
-- | Build a centroid classifier from CSV data rows.
-- Each row is parsed, rows are grouped by label, and each group is averaged
-- into a single representative digit per label.
buildClassifier :: [String] -> Classifier
buildClassifier =
  map averageGroup . groupBy sameLabel . sortBy sortFunc . map lineToDigit
  where
    sameLabel a b = actualDigit a == actualDigit b
-- | Sum of squared pixel differences between a digit and a sample.
-- Pixels are paired positionally; the longer list is truncated.
--
-- The sum is accumulated exactly in 'Int' and converted to 'Float' once.
-- Summing in 'Float' (24-bit mantissa) rounds once totals exceed 2^24, which
-- could reorder two close neighbours. A single final conversion is monotone,
-- so it never inverts an ordering.
distance :: Digit -> [Int] -> Float
distance digit samplePixels = fromIntegral squaredError
  where
    squaredError :: Int
    squaredError = foldl' (+) 0 (zipWith sqDiff (pixels digit) samplePixels)
    sqDiff p q = let d = p - q in d * d
-- | Return the classifier entry closest to the sample, by 'distance'.
-- On ties the earliest entry wins.
--
-- Each distance is computed once, whereas the previous lazy 'foldl'
-- recomputed the current best's distance at every step. An empty classifier
-- is reported explicitly.
classify :: Classifier -> Digit -> Digit
classify [] _ = error "classify: empty classifier"
classify classifier sampleDigit = snd (minimumBy (comparing fst) scored)
  where
    samplePixels = pixels sampleDigit
    scored = [(distance candidate samplePixels, candidate) | candidate <- classifier]
-- | Order digits by their 'distance' to the given sample pixels.
--
-- Equivalent to the previous guard chain for all non-NaN distances, but
-- exhaustive by construction (no @-Wincomplete-patterns@ warning).
distanceSortFunc :: [Int] -> Digit -> Digit -> Ordering
distanceSortFunc sample = comparing (\digit -> distance digit sample)
-- | Return the training digit nearest to the sample, by 'distance'.
--
-- Despite the name this is 1-nearest-neighbour: no vote over k neighbours is
-- taken. On ties the earliest digit in the classifier wins, as with the stable
-- 'sortBy' used previously.
--
-- This is a single linear scan over precomputed distances. The old version
-- sorted the whole training set, evaluating 'distance' twice per comparison,
-- only to take the head.
knnClassify :: Classifier -> Digit -> Digit
knnClassify [] _ = error "knnClassify: empty classifier"
knnClassify classifier sampleDigit = snd (minimumBy (comparing fst) scored)
  where
    samplePixels = pixels sampleDigit
    scored = [(distance candidate samplePixels, candidate) | candidate <- classifier]
-- | Nearest-neighbour "training": every CSV data row becomes a reference digit.
buildKnnClassifier :: [String] -> Classifier
buildKnnClassifier = fmap lineToDigit
-- | True when the classified digit carries the same label as the validation
-- digit it was produced for.
testDigit :: (Digit, Digit) -> Bool
testDigit (validation, classified) =
  let expected = actualDigit validation
      predicted = actualDigit classified
  in predicted == expected
-- | Split raw CSV file contents into data rows, dropping the first line
-- (assumed to be a header row).
--
-- Fixes a compile error: the old body read @validation@, a local of 'main'
-- that is not in scope here. Its @[String]@ argument type also contradicted
-- both 'lines' and the caller, which passes the 'String' from 'readFile'.
prepFile :: String -> [String]
prepFile file = drop 1 (lines file)
-- | Build a nearest-neighbour classifier from @digitssample.csv@. Classify the
-- first 'validationLimit' rows of @digitscheck.csv@ and print how many came
-- out correct.
main :: IO ()
main = do
  file <- readFile "digitssample.csv"
  putStrLn "Building Classifier"
  let classifier = buildKnnClassifier $ prepFile file
  putStrLn "Done. Reading in validation File"
  validation <- readFile "digitscheck.csv"
  -- Only a prefix of the validation set is checked, because each row costs a
  -- full pass over the training set. Raise this to evaluate more rows.
  let validationLimit = 5
      validationSet = map lineToDigit $ take validationLimit $ prepFile validation
  putStrLn "Classifying"
  let classifiedValidationSet = map (knnClassify classifier) validationSet
  putStrLn "Checking success"
  let successSet = zipWith (curry testDigit) validationSet classifiedValidationSet
  putStrLn "Num Correct: "
  print $ length $ filter id successSet
  putStrLn "Num Total: "
  print $ length validationSet
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment