@djv
Last active August 29, 2015 14:02
{-# LANGUAGE BangPatterns #-}

import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as UV
import qualified Data.ByteString.Char8 as C
import Data.Vector.Strategies

-- One CSV row: a digit label and its pixel intensities.
data LabelPixel = LabelPixel {_label :: !Int, _pixels :: !(UV.Vector Int)}

-- Strip leading and trailing spaces and tabs.
trim :: C.ByteString -> C.ByteString
trim = C.reverse . C.dropWhile (`elem` " \t") . C.reverse . C.dropWhile (`elem` " \t")

-- Read a CSV file, skipping the header line.
slurpFile :: FilePath -> IO (V.Vector LabelPixel)
slurpFile = fmap (V.fromList . map make . tail . C.lines) . C.readFile where
  readInt' !bs = let Just (i, _) = C.readInt bs in i
  labelPixel (lbl:pxls) = LabelPixel (readInt' lbl) (UV.fromList $ map readInt' pxls)
  make = labelPixel . C.split ',' . trim

-- Label of the training row nearest to the given pixels (squared Euclidean distance).
classify :: V.Vector LabelPixel -> UV.Vector Int -> Int
classify !training !pixels = _label mini where
  dist !x !y = UV.sum . UV.map (^2) $ UV.zipWith (-) x y
  comp !p1 !p2 = dist (_pixels p1) pixels `compare` dist (_pixels p2) pixels
  mini = V.minimumBy comp training

main :: IO ()
main = do
  trainingSet <- slurpFile "trainingsample.csv"
  validationSample <- slurpFile "validationsample.csv"
  -- Score each validation row as 1 (correct) or 0, evaluating in parallel chunks.
  let isCorrect x = fromEnum $ classify trainingSet (_pixels x) == _label x
      numCorrect = V.sum $
        ((V.map isCorrect validationSample) `using` (parVector 50))
      flt = fromIntegral
      percentCorrect = flt numCorrect / flt (V.length validationSample) * 100.0
  putStrLn $ "Percentage correct: " ++ show percentCorrect
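
The classify function above is a 1-nearest-neighbour search: it picks the training row whose pixels have the smallest squared Euclidean distance to the query. Skipping the square root is safe because it does not change the ordering of the distances. Below is a minimal sketch of the same idea using only base and plain lists. The names Sample, sqDist and nearestLabel and the three-pixel rows are made up for illustration and are not part of the gist.

```haskell
import Data.List (minimumBy)
import Data.Ord (comparing)

-- Hypothetical stand-in for LabelPixel: (label, pixels) as plain lists.
type Sample = (Int, [Int])

-- Squared Euclidean distance; zip truncates to the shorter list.
sqDist :: [Int] -> [Int] -> Int
sqDist xs ys = sum [ (x - y) ^ (2 :: Int) | (x, y) <- zip xs ys ]

-- Label of the nearest training sample.
-- Partial: minimumBy raises an error on an empty training set.
nearestLabel :: [Sample] -> [Int] -> Int
nearestLabel training pixels =
  fst (minimumBy (comparing (\(_, p) -> sqDist p pixels)) training)

main :: IO ()
main = do
  -- Made-up three-pixel "images", not taken from the real data files.
  let training = [(0, [0, 0, 0]), (1, [255, 255, 255]), (7, [0, 255, 0])]
  print (nearestLabel training [10, 240, 5])  -- prints 7
```

The gist keeps the pixels in unboxed vectors instead of lists, which avoids per-element boxing in the inner loop.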

djv commented Jun 11, 2014

Run with:

ghc -O2 run.hs -fllvm -threaded
time ./run +RTS -N8

The data files are at https://github.com/philtomson/ClassifyDigits
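
Judging from the loader in the gist, each data file has one header line followed by rows whose first comma-separated field is the label and whose remaining fields are pixel values. A rough sketch of that layout using only base is below. The helper names and the sample rows are hypothetical, and the header text "label,pixel0,pixel1" is an assumption, not copied from the repository. Unlike the gist, which fails on a field that is not an integer, this version silently skips rows that do not parse.

```haskell
import Data.Maybe (mapMaybe)
import Text.Read (readMaybe)

-- Split a string on commas (hypothetical helper; the gist uses C.split ',').
splitCommas :: String -> [String]
splitCommas s = case break (== ',') s of
  (field, [])       -> [field]
  (field, _ : rest) -> field : splitCommas rest

-- Parse one row into (label, pixels); Nothing if any field is not an integer.
parseRow :: String -> Maybe (Int, [Int])
parseRow line = case mapM readMaybe (splitCommas line) of
  Just (lbl : pxls) -> Just (lbl, pxls)
  _                 -> Nothing

-- Drop the header line, then keep only the rows that parse.
parseCsv :: String -> [(Int, [Int])]
parseCsv = mapMaybe parseRow . drop 1 . lines

main :: IO ()
main = print (parseCsv "label,pixel0,pixel1\n5,0,255\n3,12,0\n")
-- prints [(5,[0,255]),(3,[12,0])]
```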
