Skip to content

Instantly share code, notes, and snippets.

@msakai
Created December 16, 2023 17:05
Show Gist options
  • Save msakai/528d33493e716bc4de3632fabbc07ba3 to your computer and use it in GitHub Desktop.
Save msakai/528d33493e716bc4de3632fabbc07ba3 to your computer and use it in GitHub Desktop.
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
-- | Okapi BM25 ranking of documents against a bag-of-words query.
--
-- See <https://en.wikipedia.org/wiki/Okapi_BM25>.
module OkapiBM25
  ( Database
  , Params (..)
    -- Exported so that callers can construct the first argument of 'query'.
  , mkDatabase
  , query
  ) where
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Vector (Vector)
import qualified Data.Vector as Vector
-- | Integer identifier of a distinct corpus word. 'mkDatabase' numbers
-- words 0, 1, 2, ... in ascending 'Ord' order of the words.
type WordId = Int
-- | Map keyed by 'WordId'.
type WordIdMap = IntMap
-- | Tuning parameters of the BM25 scoring function.
data Params = Params
  { k1 :: Double
    -- ^ Term-frequency saturation. With @0@, a matching term contributes
    -- only its IDF, regardless of how often it occurs.
  , b :: Double
    -- ^ Strength of document-length normalisation. @0@ disables it; @1@
    -- scales fully by document length relative to the average length.
  }
  deriving (Eq, Ord, Show, Read)
-- | Precomputed corpus statistics used for BM25 scoring; build with
-- 'mkDatabase'. @a@ is the document type and @w@ the word type.
data Database a w = Database
  { documents :: [(a, WordIdMap Int)]
    -- ^ Every document, paired with its term-frequency table
    -- (word id to number of occurrences).
  , averageDocumentLength :: Double
    -- ^ Mean number of words per document, counting repetitions.
  , numDocumentsWithWord :: WordIdMap Int
    -- ^ For each word id, the number of documents containing that word.
  , wordTable :: Vector w
    -- ^ Vocabulary indexed by word id.
  , wordIdTable :: Map w Int
    -- ^ Inverse of 'wordTable': the id of each vocabulary word.
  }
  deriving (Show)
-- | Build a 'Database' from a corpus.
--
-- @wordsOf@ splits a document into words. Each distinct word receives a
-- 'WordId' equal to its position in ascending 'Ord' order. Documents keep
-- their input order. For an empty corpus the average document length is
-- @0@ rather than @NaN@.
mkDatabase :: Ord w => (a -> [w]) -> [a] -> Database a w
mkDatabase wordsOf xs =
  Database
    { documents = docs
    , averageDocumentLength = avgLen
    , numDocumentsWithWord =
        IntMap.unionsWith (+) [fmap (const 1) f | (_, f) <- docs]
    , wordTable = Vector.fromList vocab
    , wordIdTable = idOf
    }
  where
    -- Tokenise each document once and reuse the result below.
    tokenised = [(x, wordsOf x) | x <- xs]
    -- 'Set.toList' is ascending and duplicate-free, so the ascending-list
    -- constructor is valid, and ids coincide with indices in 'wordTable'.
    vocab = Set.toList (Set.fromList (concatMap snd tokenised))
    idOf = Map.fromDistinctAscList (zip vocab [0 ..])
    -- Every word of every document is in 'vocab', so 'Map.!' cannot fail.
    docs =
      [ (x, IntMap.fromListWith (+) [(idOf Map.! w, 1) | w <- ws])
      | (x, ws) <- tokenised
      ]
    totalLen = sum [sum (IntMap.elems f) | (_, f) <- docs]
    -- Avoid 0 / 0 = NaN when the corpus is empty.
    avgLen
      | null docs = 0
      | otherwise = fromIntegral totalLen / fromIntegral (length docs)
-- | BM25 score of one document, given as its term-frequency table, for a
-- query given as word ids. Each query term is counted once per occurrence
-- in the query.
--
-- Terms absent from the document contribute @0@ and are skipped. Without
-- the skip, the general formula evaluates to @0 / 0 = NaN@ for such terms
-- when @k1 = 0@, or when @b = 1@ and the document is empty.
score :: Params -> Database a w -> [WordId] -> WordIdMap Int -> Double
score Params{ .. } Database{ .. } qs doc =
  sum
    [ idf q * (f * (k1 + 1)) / (f + k1 * lengthNorm)
    | q <- qs
    , let f = freq q
    , f > 0
    ]
  where
    n = length documents
    docLength = sum (IntMap.elems doc)
    -- Document-length normalisation; it does not depend on the query term.
    lengthNorm = 1 - b + b * fromIntegral docLength / averageDocumentLength
    idf q = log (1 + (fromIntegral (n - fq) + 0.5) / (fromIntegral fq + 0.5))
      where
        -- Number of documents containing q (0 for an id never seen).
        fq = IntMap.findWithDefault 0 q numDocumentsWithWord
    freq q = fromIntegral (IntMap.findWithDefault 0 q doc)
-- | Score every document of a database against a query, using that same
-- database's corpus statistics. Results follow the order of 'documents'.
--
-- Query words outside the corpus vocabulary occur in no document. They are
-- ignored instead of raising an error.
query :: Ord w => Params -> Database a w -> [w] -> [(a, Double)]
query params db ws =
  [ (doc, score params db wordIds f) | (doc, f) <- documents db ]
  where
    -- 'Map.lookup' rather than 'Map.!' so unknown words are dropped safely.
    wordIds = [ i | w <- ws, Just i <- [Map.lookup w (wordIdTable db)] ]
-- ------------------------------------------------------------------------
-- Example from
-- https://www.elastic.co/jp/blog/practical-bm25-part-2-the-bm25-algorithm-and-its-variables
-- ------------------------------------------------------------------------
-- | Example corpus of six name-like documents, split into words on
-- whitespace.
people :: Database Text Text
people = mkDatabase Text.words names
  where
    names =
      [ "Shane"
      , "Shane C"
      , "Shane P Connelly"
      , "Shane Connelly"
      , "Shane Shane Connelly Connelly"
      , "Shane Shane Shane Connelly Connelly Connelly"
      ]
-- | Scores of 'people' for the query @Shane@ with @k1 = 0@ and @b = 0.5@.
-- With @k1 = 0@, a matching term contributes only its IDF, so term frequency
-- and document length have no effect.
example1 :: [(Text, Double)]
example1 = query params people ["Shane"]
  where
    params = Params { k1 = 0, b = 0.5 }

-- | Scores of 'people' for the query @Shane@ with @k1 = 10@ and @b = 0@.
-- @b = 0@ disables document-length normalisation.
example2 :: [(Text, Double)]
example2 = query params people ["Shane"]
  where
    params = Params { k1 = 10, b = 0 }

-- | Scores of 'people' for the query @Shane@ with @k1 = 5@ and @b = 1@.
-- @b = 1@ applies full document-length normalisation.
example3 :: [(Text, Double)]
example3 = query params people ["Shane"]
  where
    params = Params { k1 = 5, b = 1 }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment