Skip to content

Instantly share code, notes, and snippets.

@pierric
Created March 1, 2016 13:15
Show Gist options
  • Save pierric/35751cb7f65a9185b8d2 to your computer and use it in GitHub Desktop.
Save pierric/35751cb7f65a9185b8d2 to your computer and use it in GitHub Desktop.
CYK algorithm for PCFG
-- | A terminal symbol: one token of the sentence being parsed.
-- NOTE(review): this alias has the same name as Prelude's 'Word' type;
-- presumably the import section (not visible here) hides the Prelude one — confirm.
type Word = String
-- | The name of a grammar nonterminal.
type NonTerminal = B.ByteString
-- | A unary rule @X :- w@: a nonterminal on the left paired with a single
-- word on the right. 'cyk' uses these for spans that cover exactly one word.
data UnaryRule
  = NonTerminal :- Word
  deriving (Eq, Show)
-- | A binary rule @X := (Y, Z)@: a nonterminal on the left paired with an
-- ordered pair of child nonterminals. 'cyk' uses these for spans that cover
-- more than one word, assigning the first child to the left part of the split.
data BinaryRule
  = NonTerminal := (NonTerminal, NonTerminal)
  deriving (Eq, Show)
-- | Hash a unary rule by folding the left-hand nonterminal and then the word
-- into the salt, in that order.
instance Hashable UnaryRule where
  hashWithSalt salt (lhs :- word) =
    hashWithSalt (hashWithSalt salt lhs) word
-- | Hash a binary rule by folding the left-hand nonterminal, then the left
-- child, then the right child into the salt, in that order.
instance Hashable BinaryRule where
  hashWithSalt salt (lhs := (left, right)) =
    hashWithSalt (hashWithSalt (hashWithSalt salt lhs) left) right
-- | Interface to the counts that 'cyk' needs from a grammar model.
--
-- 'cyk' turns these counts into rule probabilities by dividing a rule's
-- count by the count of the rule's left-hand nonterminal.
class CorpusModel m where
  -- | Every nonterminal known to the model. 'cyk' numbers them in list
  -- order to index its table.
  nonterminals           :: m -> [NonTerminal]

  -- | Count for a nonterminal; used as the denominator of rule
  -- probabilities in 'cyk'.
  nonterminal_count      :: m -> NonTerminal -> Int

  -- | Count for a unary rule; numerator of its probability in 'cyk'.
  unaryrule_count        :: m -> UnaryRule -> Int

  -- | Count for a binary rule; numerator of its probability in 'cyk'.
  binaryrule_count       :: m -> BinaryRule -> Int

  -- | Binary rules associated with the given nonterminal. 'cyk' treats
  -- them as the rules that can expand that nonterminal (presumably those
  -- whose left-hand side is the argument — confirm against instances).
  binaryrules_start_with :: m -> NonTerminal -> [BinaryRule]
-- | Most probable parse of a sentence under a PCFG, computed with the CYK
-- dynamic program.
--
-- Arguments: the corpus model @m@ supplying rule counts, the start symbol
-- @s@, and the sentence @w@ as a list of words.
--
-- Returns @Just tree@ for the highest-probability derivation of the whole
-- sentence from @s@, or 'Nothing' when
--
--   * the sentence is empty,
--   * @s@ is not among @nonterminals m@, or
--   * no derivation with strictly positive probability exists.
--
-- Rule probabilities are relative frequencies: the rule count divided by
-- the count of its left-hand nonterminal. A zero nonterminal count yields
-- NaN or Infinity; NaN fails the @> 0@ checks and is treated as "no parse".
cyk :: CorpusModel m => m -> NonTerminal -> [Word] -> Maybe ParseTree
cyk m s w
  | null w    = Nothing
  | otherwise =
      case M.lookup s nonterms_index of
        Nothing    -> Nothing
        Just start -> snd (best (1, sent_len, start))
  where
    sent_len = length w
    sentence = listArray (1, sent_len) w

    -- Nonterminals are numbered 1..nont_len in the order the model lists them.
    nonterms       = nonterminals m
    nont_len       = length nonterms
    nonterms_array = listArray (1, nont_len) nonterms
    nonterms_index = M.fromList $ zip nonterms [1..]

    -- best (i, j, nt): probability and tree of the best derivation of words
    -- i..j from nonterminal number nt; (0, Nothing) if there is none.
    best (i, j, nt)
      | i == j =
          let nont = nonterms_array !# nt
              word = sentence !# i
              prob = uprob (nont :- word)
          in if prob > 0
               then (prob, Just $ RuleU nont word)
               else (0, Nothing)
      | i < j =
          let nont  = nonterms_array !# nt
              -- Usable rules do not depend on the split point, so they are
              -- computed once per cell. Rules that mention a nonterminal the
              -- model does not list are skipped.
              rules = [ (prob, ix1, ix2)
                      | rule@(_ := (nt1, nt2)) <- binaryrules_start_with m nont
                      , let prob = bprob rule
                      , prob > 0
                      , ix1 <- maybeToList (M.lookup nt1 nonterms_index)
                      , ix2 <- maybeToList (M.lookup nt2 nonterms_index)
                      ]
              candidates = [ (p1 * p2 * prob, Just $ RuleB nont r1 r2)
                           | k <- [i .. j - 1]
                           , (prob, ix1, ix2) <- rules
                           , let (p1, mr1) = chart !# (i, k, ix1)
                           , p1 > 0
                           , r1 <- maybeToList mr1
                           , let (p2, mr2) = chart !# (k + 1, j, ix2)
                           , p2 > 0
                           , r2 <- maybeToList mr2
                           ]
              value_compare a b = compare (fst a) (fst b)
          in if null candidates
               then (0, Nothing)
               else maximumBy value_compare candidates
      | otherwise =
          error ("i should not be greater than j, i=" ++ show i ++ " j=" ++ show j)

    -- Memo table over every (i, j, nt). Entries are defined in terms of the
    -- table itself, so this relies on array elements being evaluated lazily
    -- (assumed of the array type behind '!#' — confirm). Cells with i > j are
    -- never demanded.
    chart =
      let index_range = ((1, 1, 1), (sent_len, sent_len, nont_len))
      in listArray index_range (map best (range index_range))

    uprob r@(n :- _) =
      fromIntegral (unaryrule_count m r) / fromIntegral (nonterminal_count m n) :: Double
    bprob r@(n := _) =
      fromIntegral (binaryrule_count m r) / fromIntegral (nonterminal_count m n) :: Double
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment