Skip to content

Instantly share code, notes, and snippets.

@maddenpj
Created June 24, 2017 02:43
Show Gist options
  • Save maddenpj/553eb96b05fb3886eac8a277ffba9c52 to your computer and use it in GitHub Desktop.
Save maddenpj/553eb96b05fb3886eac8a277ffba9c52 to your computer and use it in GitHub Desktop.
OCR for Numbers using KNN (Haskell)
import System.IO
import Data.List
-- | One labelled sample: the digit it depicts together with its pixel values.
data Digit = Digit
  { actualDigit :: Int
    -- ^ The label, i.e. which digit this sample represents.
  , pixels :: [Int]
    -- ^ The sample's pixel values, in input order.
  } deriving Show
-- | A classifier is a flat list of reference digits that a sample is compared
-- against.
type Classifier = [Digit]
-- | A list of digit lists. No other definition in this file as shown refers to
-- this alias.
type KNNClassifier = [[Digit]]
-- | Split a string into the maximal runs of characters that do NOT satisfy
-- the predicate.
--
-- Characters satisfying @p@ act as delimiters. Runs of consecutive delimiters
-- are collapsed, and leading or trailing delimiters are ignored, so the result
-- never contains an empty token.
wordsWhen :: (Char -> Bool) -> String -> [String]
wordsWhen p s
  | null trimmed = []
  | otherwise    = token : wordsWhen p remainder
  where
    -- Drop any delimiters in front of the next token.
    trimmed            = dropWhile p s
    -- Everything up to the next delimiter is the token; recurse on the rest.
    (token, remainder) = break p trimmed
-- | Parse one comma-separated line into a list of integers.
--
-- Each field is converted with 'read', so a field that is not a valid 'Int'
-- raises a runtime error when that element is evaluated.
lineToPixel :: String -> [Int]
lineToPixel = map read . wordsWhen isComma
  where
    isComma c = c == ','
-- | Build a 'Digit' from a parsed row: the first value is the label and the
-- remaining values are the pixels.
--
-- An empty row has no label and cannot form a 'Digit'. It is rejected with
-- a descriptive error rather than an anonymous pattern-match failure, so that
-- a blank line in the input is easy to diagnose.
pixelsToDigit :: [Int] -> Digit
pixelsToDigit [] =
  error "pixelsToDigit: empty row (expected a label followed by pixel values)"
pixelsToDigit (label : rest) = Digit { actualDigit = label, pixels = rest }
-- | Parse one comma-separated line straight into a 'Digit': the line is split
-- into integers by 'lineToPixel' and then packaged by 'pixelsToDigit'.
lineToDigit :: String -> Digit
lineToDigit line = pixelsToDigit (lineToPixel line)
-- | Integer average of a pair of values.
--
-- Uses 'div', so an odd total is rounded toward negative infinity.
avgFunc :: (Int, Int) -> Int
avgFunc pair = total `div` 2
  where
    total = uncurry (+) pair
-- | Combine two digits into one whose pixels are the pairwise integer average
-- (via 'avgFunc') of the two inputs.
--
-- The label is taken from the first argument. If the pixel lists differ in
-- length, the result is as long as the shorter one.
averageDigits :: Digit -> Digit -> Digit
averageDigits acc x =
  Digit
    { actualDigit = actualDigit acc
    , pixels = zipWith (curry avgFunc) (pixels acc) (pixels x)
    }
-- | Collapse a group of digits into a single digit whose pixels are the
-- per-position arithmetic mean of the whole group.
--
-- The label is taken from the first digit of the group. Each pixel position
-- is summed across all members and divided (with 'div') by the member count.
-- This replaces a left fold of pairwise averages, which weighted the last
-- sample by 1/2, the one before by 1/4 and so on, and truncated at every
-- step, so it was not a mean of the group.
--
-- If the members have pixel lists of different lengths, the result is as long
-- as the shortest one. An empty group is rejected with a descriptive error.
averageGroup :: [Digit] -> Digit
averageGroup [] = error "averageGroup: cannot average an empty group"
averageGroup (d : rest) =
  Digit
    { actualDigit = actualDigit d
    , pixels = map (`div` count) columnSums
    }
  where
    count = 1 + length rest
    -- Column-wise totals, accumulated one member at a time.
    columnSums = foldl' (zipWith (+)) (pixels d) (map pixels rest)
-- | Order two digits by their label ('actualDigit').
sortFunc :: Digit -> Digit -> Ordering
sortFunc a b = compare (actualDigit a) (actualDigit b)
-- | Build a classifier with one averaged reference digit per label.
--
-- Every input line is parsed into a 'Digit', the digits are sorted by label,
-- grouped into runs of equal label, and each run is reduced to a single digit
-- with 'averageGroup'.
buildClassifier :: [String] -> Classifier
buildClassifier =
  map averageGroup . groupBy sameLabel . sortBy sortFunc . map lineToDigit
  where
    sameLabel a b = actualDigit a == actualDigit b
-- | Sum of squared differences between a digit's pixels and a sample's pixels.
--
-- No square root is taken, so this is the squared Euclidean distance. Pixels
-- are paired positionally; if the two lists differ in length, the surplus of
-- the longer one is ignored.
distance :: Digit -> [Int] -> Float
distance digit samplePixels =
  sum (zipWith squaredGap (pixels digit) samplePixels)
  where
    -- Difference is taken in Int, then converted and squared in Float.
    squaredGap p q = fromIntegral (p - q) ** 2
-- | Return the classifier entry whose pixels are closest (by 'distance') to
-- the sample's pixels.
--
-- A later entry replaces the current best only when it is strictly closer, so
-- ties go to the entry that appears first. There is no equation for an empty
-- classifier.
classify :: Classifier -> Digit -> Digit
classify (first : rest) sampleDigit = go first rest
  where
    target = pixels sampleDigit
    go best [] = best
    go best (candidate : others)
      | distance candidate target < distance best target = go candidate others
      | otherwise = go best others
-- | Order two digits by how far each one is (by 'distance') from the given
-- sample pixels; the nearer digit sorts first.
distanceSortFunc :: [Int] -> Digit -> Digit -> Ordering
distanceSortFunc sample a b = compare (toSample a) (toSample b)
  where
    toSample d = distance d sample
-- | Return the classifier entry nearest to the sample, i.e. a
-- single-nearest-neighbour lookup.
--
-- The nearest entry is selected with 'minimumBy' instead of sorting the whole
-- classifier and taking the head. On ties the earliest entry wins, exactly as
-- the stable sort did. An empty classifier is rejected with a descriptive
-- error rather than the generic failure of 'head'.
knnClassify :: Classifier -> Digit -> Digit
knnClassify [] _ =
  error "knnClassify: empty classifier (no training samples to compare against)"
knnClassify classifier sampleDigit =
  minimumBy (distanceSortFunc (pixels sampleDigit)) classifier
-- | Build a nearest-neighbour classifier by parsing every input line into its
-- own 'Digit'; no grouping or averaging is applied.
buildKnnClassifier :: [String] -> Classifier
buildKnnClassifier = map lineToDigit
-- | Check whether a classification is correct: 'True' when both digits of the
-- pair carry the same label.
testDigit :: (Digit, Digit) -> Bool
testDigit pair = expectedLabel == predictedLabel
  where
    expectedLabel  = actualDigit (fst pair)
    predictedLabel = actualDigit (snd pair)
-- | Split raw file contents into lines and discard the first line.
--
-- Takes the whole file as a single 'String', which is what 'main' passes in
-- from 'readFile'. The earlier version was declared as taking @[String]@ and
-- its body referred to a name (@validation@) that is not in scope here, so it
-- could not compile.
prepFile :: String -> [String]
prepFile file = drop 1 (lines file)
-- Note: a Classifier is a [Digit].
-- | Entry point: build a nearest-neighbour classifier from
-- @digitssample.csv@, classify the first five data rows of
-- @digitscheck.csv@, and print how many were labelled correctly alongside the
-- number of rows checked.
main :: IO ()
main = do
  file <- readFile "digitssample.csv"
  putStrLn "Building Classifier"
  let classifier = buildKnnClassifier (prepFile file)
  putStrLn "Done. Reading in validation File"
  validation <- readFile "digitscheck.csv"
  -- Only the first five validation rows are classified.
  let validationSet = map lineToDigit (take 5 (prepFile validation))
  putStrLn "Classifying"
  let predictions = map (knnClassify classifier) validationSet
  putStrLn "Checking success"
  -- One Bool per validation row: did the predicted label match?
  let outcomes = zipWith (curry testDigit) validationSet predictions
  putStrLn "Num Correct: "
  print (length (filter id outcomes))
  putStrLn "Num Total: "
  print (length validationSet)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment