Skip to content

Instantly share code, notes, and snippets.

@purcell
Created October 21, 2015 08:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save purcell/d13b7a8a53c7b8a899ce to your computer and use it in GitHub Desktop.
Save purcell/d13b7a8a53c7b8a899ce to your computer and use it in GitHub Desktop.
naive bayes
module Classifier
( Classifier(..)
, empty
, update
, union
, classify
, singleton
, scaled
, test
) where
import Data.Function (on)
import Data.List (nub, sortBy)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Monoid (Monoid (..))
import Data.Semigroup (Semigroup (..))
import Prelude hiding (product)
{-
References:
* http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html
* http://www.research.ibm.com/people/z/zadrozny/kdd2002-Transf.pdf (Normalisation)
* https://en.wikipedia.org/wiki/Naive_Bayes_classifier
* http://ebiquity.umbc.edu/blogger/2010/12/07/naive-bayes-classifier-in-50-lines/
* http://arxiv.org/pdf/1004.5168v1.pdf
* https://code.google.com/p/ourmine/wiki/LectureNaiveBayes
* http://www.randomhacks.net/articles/2007/02/22/bayes-rule-and-drug-tests
* https://github.com/CamDavidsonPilon/Probabilistic-Programming-and-Bayesian-Methods-for-Hackers
* http://sebastianraschka.com/Articles/2014_naive_bayes_1.html
* http://gigamonkeys.com/book/practical-a-spam-filter.html
-}
-- | Number of times each feature occurred in the training data of one
-- category. Features missing from the map are treated as occurring zero
-- times (see 'featureCount').
type FeatureCounts c = Map c Int
-- | Training statistics for one category: the number of training examples
-- filed under it, paired with the feature counts summed over those examples.
type CategoryInfo c = (Int, FeatureCounts c)
-- | A naive Bayes classifier that assigns lists of discrete features of
-- type @c@ to categories of type @b@. It consists solely of the training
-- statistics gathered so far, keyed by category.
data Classifier b c = Classifier
  { categorisations :: Map b (CategoryInfo c)
    -- ^ Training statistics for every category seen so far.
  }
  deriving (Show)
-- | Combining two classifiers pools their training data: per-category
-- example counts and per-feature counts are summed (see 'union').
-- Required as the superclass of 'Monoid' since base-4.11 (GHC 8.4).
instance (Ord b, Ord c) => Semigroup (Classifier b c) where
  (<>) = union

-- | 'mempty' is the classifier trained on nothing.
instance (Ord b, Ord c) => Monoid (Classifier b c) where
  mempty = empty
  -- Canonical definition; kept explicit for bases where 'mappend' has no
  -- default in terms of '(<>)'.
  mappend = (<>)
  mconcat = foldr union empty
-- | Score that 'classify' assigns to a category; higher means more likely.
-- Use 'scaled' to normalise a list of scores so that they sum to 1.
type Score = Double
-- Public interface
-- | A classifier trained on nothing. It knows no categories, so
-- 'classify' on it yields an empty list.
empty :: Classifier b c
empty = Classifier { categorisations = M.empty }
-- | A classifier trained on exactly one example: the given features filed
-- under the given category. A feature that appears several times in the
-- list is counted once per occurrence.
singleton :: (Ord b, Ord c) => b -> [c] -> Classifier b c
singleton cat features = Classifier (M.singleton cat info)
  where
    info = (1, tally)
    tally = M.fromListWith (+) (zip features (repeat 1))
-- | Add one training example (a feature list filed under a category) to
-- an existing classifier.
update :: (Ord b, Ord c) => Classifier b c -> b -> [c] -> Classifier b c
update classifier cat = union classifier . singleton cat
-- | Pool the training data of two classifiers. For a category present in
-- both, the statistics are combined with 'mergeCat'; other categories are
-- carried over unchanged.
union :: (Ord b, Ord c) => Classifier b c -> Classifier b c -> Classifier b c
union left right = Classifier merged
  where
    merged = M.unionWith mergeCat (categorisations left) (categorisations right)
-- | Score every known category against a list of features, most likely
-- category first. 'sortBy' is stable, so tied categories stay in ascending
-- key order.
classify :: (Ord b, Ord c) => Classifier b c -> [c] -> [(b, Score)]
classify c features = sortBy byScoreDescending scored
  where
    scored = map scoreCategory (M.toList (categorisations c))
    scoreCategory (cat, info) = (cat, categoryScore c info features)
    byScoreDescending (_, x) (_, y) = compare y x
-- | Multinomial naive Bayes score of one category for a list of features:
--
-- > P(cat) * product [ P(f | cat) | f <- features ]
--
-- @P(cat)@ is the fraction of all training examples that were filed under
-- this category. @P(f | cat)@ uses add-one (Laplace) smoothing over the
-- vocabulary of every feature seen in any category. A feature repeated in
-- the input contributes once per occurrence. On the Stanford IR-book
-- worked example this gives about 3.0e-4 for Chinese and 1.35e-4 for
-- NotChinese, matching the published figures.
categoryScore :: (Ord b, Ord c) => Classifier b c -> CategoryInfo c -> [c] -> Score
categoryScore c (catExamples, fCounts) features =
  product (prior : map pFeature features)
  where
    -- Without the prior, every category would be treated as equally likely
    -- a priori, however much (or little) training data it has.
    prior = catExamples `ratio` totalCategorisations c
    pFeature f =
      (laplaceSmoothing + featureCount f fCounts) `ratio` denominator
    -- Independent of the feature, so computed once rather than per feature.
    denominator = totalFeatureCount fCounts + laplaceSmoothing * vocabularySize
    vocabularySize = numDistinctFeatures c
    laplaceSmoothing :: Int
    laplaceSmoothing = 1
-- | Normalise a list of scores so that they sum to 1. Labels and order are
-- kept. Every score is divided by the total, so if the scores sum to zero
-- the results are not finite.
scaled :: [(b, Score)] -> [(b, Score)]
scaled scoredCategories =
  [ (cat, score / total) | (cat, score) <- scoredCategories ]
  where
    total = sum (map snd scoredCategories)
-- Internals
-- | Product of a list of numbers, computed by summing their logarithms and
-- exponentiating the total.
--
-- Because the sum is exponentiated again, the result still underflows to 0
-- when the true product is too small to represent. Avoiding that would
-- require comparing the log-sums themselves. For 'Double', a zero factor
-- yields 0 (log 0 is -Infinity) and a negative factor yields NaN, so
-- inputs are expected to be non-negative.
product :: Floating a => [a] -> a
product xs = exp (sum (map log xs))
-- | Divide one integral value by another, giving a fractional result
-- (e.g. @ratio 1 4 == 0.25@). For 'Double' results, dividing by zero
-- yields Infinity or NaN rather than throwing.
ratio :: (Integral a, Floating b) => a -> a -> b
ratio numerator denominator = fromIntegral numerator / fromIntegral denominator
-- | How many times a feature occurred in a category's training data.
-- Features that were never seen count as zero.
featureCount :: (Ord c) => c -> FeatureCounts c -> Int
featureCount feature counts =
  case M.lookup feature counts of
    Just n -> n
    Nothing -> 0
-- | Total number of feature occurrences (not distinct features) recorded
-- in a category's training data.
totalFeatureCount :: (Ord c) => FeatureCounts c -> Int
totalFeatureCount counts = sum (M.elems counts)
-- | Combine two sets of statistics for the same category by adding their
-- example counts and their per-feature counts.
mergeCat :: Ord c => CategoryInfo c -> CategoryInfo c -> CategoryInfo c
mergeCat (n1, counts1) (n2, counts2) = (combinedN, combinedCounts)
  where
    combinedN = n1 + n2
    combinedCounts = M.unionWith (+) counts2 counts1
-- | Vocabulary size: the number of distinct features seen in any
-- category's training data. It appears in the Laplace-smoothing
-- denominator of the feature likelihoods.
--
-- Taking the size of the union of the per-category maps is O(n log n). The
-- previous @length . nub@ over all keys was O(n^2) in the vocabulary size
-- and gave the same count.
numDistinctFeatures :: (Ord b, Ord c) => Classifier b c -> Int
numDistinctFeatures c =
  M.size (M.unions (map snd (M.elems (categorisations c))))
-- | Total number of training examples across all categories.
totalCategorisations :: (Ord b, Ord c) => Classifier b c -> Int
totalCategorisations c = sum [ n | (n, _) <- M.elems (categorisations c) ]
-- Worked example from http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html
-- | Categories for the worked example: a document is either about China
-- or not. Constructor order determines the derived 'Ord'.
data Cat
  = Chinese
  | NotChinese
  deriving (Show, Eq, Ord)
-- | Classifier trained on the four training documents of the worked
-- example. Each document is split into whitespace-separated words.
test :: Classifier Cat String
test = foldr (union . trainOn) empty trainingSet
  where
    trainOn (cat, text) = singleton cat (words text)
    trainingSet =
      [ (Chinese, "Chinese Beijing Chinese")
      , (Chinese, "Chinese Chinese Shanghai")
      , (Chinese, "Chinese Macao")
      , (NotChinese, "Tokyo Japan Chinese")
      ]
-- λ> classify test ["Chinese", "Chinese", "Chinese", "Tokyo", "Japan"]
-- [(Chinese,3.0121377997263e-4),(NotChinese,1.3548070246744215e-4)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment