Skip to content

Instantly share code, notes, and snippets.

@atomictom
Created April 23, 2015 02:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save atomictom/b8cc50f9d42229af7d44 to your computer and use it in GitHub Desktop.
Save atomictom/b8cc50f9d42229af7d44 to your computer and use it in GitHub Desktop.
import System.Environment
import Control.Applicative
import Data.Maybe
import Data.List
import Data.List.Split
import Data.Map.Strict ((!), Map, fromList)
-- | Label naming the class an instance belongs to.
type Class = String
-- | Arithmetic mean of one feature.
type Mean = Double
-- | Variance of one feature, i.e. its standard deviation squared.
type StdDevSquared = Double
-- | A single numeric feature value.
type Feature = Double
-- | A probability (or probability-density) value.
type Probability = Double
-- | One labelled example: a class label together with its numeric features.
data Instance = Instance
  { getClass :: Class
    -- ^ The class label attached to this example.
  , getFeatures :: [Feature]
    -- ^ The feature values of this example, in column order.
  }
  deriving (Show)
-- | A trained Gaussian naive Bayes model.
data BayesClassifier = BayesClassifier
  { getClasses :: [Class]
    -- ^ Every class label the model knows about.
  , getClassProbabilities :: Map Class Probability
    -- ^ Prior probability of each class.
  , getGaussianDistributions :: Map Class [(Mean, StdDevSquared)]
    -- ^ Per class, one (mean, variance) pair for each feature column.
  }
  deriving (Show)
-- | Entry point: load an ARFF-style data file, train a naive Bayes classifier
-- on every instance in it, classify those same instances, and print each
-- prediction followed by summary accuracy statistics.
--
-- The file path is the first command-line argument and defaults to
-- @iris.arff@. Accuracy is measured on the training data itself.
main :: IO ()
main = do
  file <- fromMaybe "iris.arff" . listToMaybe <$> getArgs
  dataMatrix <- dataFilter <$> readFile file
  let instances = map toInstance dataMatrix
      classifier = trainClassifier instances
      classifications = map (classify classifier) instances
  -- Print classifications
  mapM_ putStrLn $ zipWith format instances classifications
  -- Print stats
  let total = length instances
      correct = length . filter id $ zipWith (==) (map getClass instances) classifications
      incorrect = total - correct
  putStrLn ""
  putStrLn $ "Correct: " ++ show correct
  putStrLn $ "Incorrect: " ++ show incorrect
  putStrLn $ "Total: " ++ show total
  putStrLn $ "Percentage correct: " ++ formatPercentage correct total
  where
    -- Build an instance from one row: the last column is the class label,
    -- every other column is a numeric feature. Rows come from 'splitOn',
    -- which never yields an empty list, so 'last' and 'init' are safe here.
    toInstance :: [String] -> Instance
    toInstance row = Instance (last row) (map readDouble (init row))

    -- True for lines that carry data: non-empty and not starting with the
    -- comment marker '%' or the header marker '@'.
    isDataLine :: String -> Bool
    isDataLine "" = False
    isDataLine (c : _) = c `notElem` "%@"

    -- Strip surrounding spaces, tabs and carriage returns, so that files
    -- with CRLF line endings or padded fields do not leak stray characters
    -- into class labels or produce bogus instances from blank-looking lines.
    trim :: String -> String
    trim = dropWhileEnd isBlank . dropWhile isBlank
      where
        isBlank = (`elem` " \t\r")

    -- Split the file into rows of trimmed, comma-separated fields, keeping
    -- only the data lines.
    dataFilter :: String -> [[String]]
    dataFilter = map (map trim . splitOn ",") . filter isDataLine . map trim . lines

    -- Parse one feature value, failing with a message that names the
    -- offending text instead of a bare "Prelude.read: no parse".
    readDouble :: String -> Double
    readDouble s = case reads s of
      [(value, rest)] | null (words rest) -> value
      _ -> error ("Cannot parse feature value " ++ show s ++ " as a number")

    -- Percentage of correct predictions; avoids printing NaN (0 / 0) when
    -- the file contained no data rows.
    formatPercentage :: Int -> Int -> String
    formatPercentage _ 0 = "n/a (no instances)"
    formatPercentage numCorrect numTotal =
      show (100 * (fromIntegral numCorrect / fromIntegral numTotal) :: Double)

    -- One report line per instance: predicted class, features, actual class.
    format :: Instance -> Class -> String
    format (Instance c fs) c' =
      "Predicting class " ++ c'
        ++ " for instance " ++ show fs
        ++ ", class: " ++ c
        ++ " -- " ++ (if c == c' then "Correct!" else "Incorrect")
-- | Fit a naive Bayes model to the given training instances.
--
-- For every distinct class label this records the class's share of the
-- training set and, per feature column, the mean and variance of that
-- column over the instances belonging to the class.
trainClassifier :: [Instance] -> BayesClassifier
trainClassifier instances =
  BayesClassifier
    { getClasses = labels
    , getClassProbabilities = priors
    , getGaussianDistributions = gaussians
    }
  where
    -- Distinct class labels, in order of first appearance.
    labels :: [Class]
    labels = nub (map getClass instances)

    priors :: Map Class Probability
    priors = fromList [(c, classProbability c instances) | c <- labels]

    gaussians :: Map Class [(Mean, StdDevSquared)]
    gaussians = fromList [(c, columnStats c) | c <- labels]

    -- Transposing the feature rows of one class groups the values of each
    -- feature together (first column, second column, ...), so each column
    -- can be summarised on its own.
    columnStats :: Class -> [(Mean, StdDevSquared)]
    columnStats c =
      [ (mean column, stdDevSquared column)
      | column <- transpose (map getFeatures (instancesByClass c instances))
      ]
-- | Keep only the instances whose class label equals the given one.
instancesByClass :: Class -> [Instance] -> [Instance]
instancesByClass c is = [i | i <- is, getClass i == c]
-- | Fraction of the given instances that belong to the given class.
classProbability :: Class -> [Instance] -> Probability
classProbability c is = fromIntegral matching / fromIntegral total
  where
    -- How many instances carry the label c.
    matching :: Int
    matching = length (instancesByClass c is)
    -- How many instances there are overall.
    total :: Int
    total = length is
-- | Arithmetic mean of a list of feature values.
mean :: [Feature] -> Mean
mean xs = total / fromIntegral count
  where
    -- Accumulate the sum and the element count in one strict left-to-right
    -- pass over the list.
    (total, count) = foldl' accumulate (0, 0 :: Int) xs
    accumulate (runningSum, seen) x =
      let runningSum' = runningSum + x
          seen' = seen + 1
       in runningSum' `seq` seen' `seq` (runningSum', seen')
-- | Variance of a list of feature values: the mean of the squared
-- deviations from the list's own mean (divides by the element count).
stdDevSquared :: [Feature] -> StdDevSquared
stdDevSquared xs = mean [(x - centre) ** 2 | x <- xs]
  where
    centre = mean xs
-- | Predict the class of an instance with a trained classifier.
--
-- Each known class is scored by the log of its prior probability plus the
-- sum of the log Gaussian densities of the instance's features under that
-- class; the highest-scoring class wins, and ties go to the class listed
-- first in the classifier.
--
-- Scores are computed in log space because multiplying many raw densities
-- underflows to 0 (or overflows) and would collapse the comparison into an
-- arbitrary tie. The ranking is otherwise the same as comparing products.
--
-- Calls 'error' if the classifier contains no classes.
classify :: BayesClassifier -> Instance -> Class
classify classifier instanceData =
  case scoredClasses of
    [] -> error "classify: classifier has no classes"
    _ -> fst (foldr1 argmax scoredClasses)
  where
    scoredClasses :: [(Class, Double)]
    scoredClasses =
      [ (c, logPrior + sum logLikelihoods)
      | c <- getClasses classifier
      , let logPrior = log (getClassProbabilities classifier ! c)
      , let distributions = getGaussianDistributions classifier ! c
      , let logLikelihoods = zipWith (uncurry logDensity) distributions (getFeatures instanceData)
      ]

    -- Natural log of the Gaussian density with the given mean and variance,
    -- evaluated at attr.
    logDensity :: Mean -> StdDevSquared -> Feature -> Double
    logDensity classMean sigmaSquared attr = logNormaliser - scaledDistance
      where
        -- A small constant keeps the variance strictly positive so a
        -- feature that is constant within a class does not divide by zero.
        sigmaSquaredCorrected = sigmaSquared + 0.00000001
        logNormaliser = negate (log (2 * pi * sigmaSquaredCorrected) / 2)
        scaledDistance = ((attr - classMean) ** 2) / (2 * sigmaSquaredCorrected)

    -- Keep whichever pair has the larger score, preferring the left one on
    -- a tie.
    argmax :: (Ord b) => (a, b) -> (a, b) -> (a, b)
    argmax a1@(_, x) a2@(_, y) = if x >= y then a1 else a2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment