Created March 1, 2016 13:15
Save pierric/35751cb7f65a9185b8d2 to your computer and use it in GitHub Desktop.
CYK algorithm for PCFG
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- | A terminal symbol (a single token of the input sentence).
--
-- NOTE(review): Prelude also exports a type named 'Word' (base >= 4.8);
-- unqualified uses of this alias are ambiguous unless Prelude's 'Word' is
-- hidden in the (not visible) import list — confirm.
type Word = String

-- | A grammar nonterminal, stored as a strict 'B.ByteString'.
type NonTerminal = B.ByteString
-- | A lexical rule @A -> w@: nonterminal @A@ rewrites to the word @w@.
data UnaryRule = NonTerminal :- Word
  deriving (Show, Eq)

-- | A binary rule @A -> B C@: nonterminal @A@ rewrites to the ordered pair
-- of nonterminals @(B, C)@.
data BinaryRule = NonTerminal := (NonTerminal, NonTerminal)
  deriving (Show, Eq)
-- | Threads the salt through the left-hand side first, then the word.
instance Hashable UnaryRule where
  hashWithSalt salt (lhs :- word) = hashWithSalt (hashWithSalt salt lhs) word
-- | Threads the salt through the left-hand side, then the first and the
-- second right-hand-side nonterminal, in that order.
instance Hashable BinaryRule where
  hashWithSalt salt (lhs := (rhs1, rhs2)) =
    let salted = hashWithSalt salt lhs
    in hashWithSalt (hashWithSalt salted rhs1) rhs2
-- | Count statistics of a training corpus; 'cyk' turns them into PCFG rule
-- probabilities by dividing a rule's count by its nonterminal's count.
class CorpusModel m where
  -- | Every nonterminal the model knows about. 'cyk' numbers them by their
  -- position in this list.
  nonterminals           :: m -> [NonTerminal]
  -- | Count of a nonterminal; the denominator of every rule probability.
  nonterminal_count      :: m -> NonTerminal -> Int
  -- | Count of a lexical (nonterminal-to-word) rule.
  unaryrule_count        :: m -> UnaryRule -> Int
  -- | Count of a binary (nonterminal-to-pair) rule.
  binaryrule_count       :: m -> BinaryRule -> Int
  -- | Binary rules 'cyk' tries when expanding the given nonterminal;
  -- presumably exactly those with it as left-hand side — TODO confirm.
  binaryrules_start_with :: m -> NonTerminal -> [BinaryRule]
-- | Most probable parse of a sentence under a PCFG in Chomsky normal form,
-- computed with the CYK (max-product) dynamic programme.
--
-- Rule probabilities are relative frequencies read from the model:
-- @P(A -> w) = count(A -> w) / count(A)@ and
-- @P(A -> B C) = count(A -> B C) / count(A)@.
--
-- Arguments: the corpus model @m@, the start symbol @s@ and the sentence @w@.
--
-- Returns 'Nothing' when the sentence is empty, when @s@ is not one of
-- @'nonterminals' m@, or when no derivation of the whole sentence has a
-- non-zero probability.
cyk :: CorpusModel m => m -> NonTerminal -> [Word] -> Maybe ParseTree
cyk m s w
  | null w = Nothing
  | otherwise =
      case M.lookup s nonterms_index of
        Nothing -> Nothing
        Just start -> snd (best (1, sent_len, start))
  where
    sent_len = length w
    sentence = listArray (1, sent_len) w
    nonterms = nonterminals m
    nont_len = length nonterms
    nonterms_array = listArray (1, nont_len) nonterms
    -- 1-based position of each nonterminal; the chart's third coordinate.
    nonterms_index = M.fromList (zip nonterms [1 ..])

    -- best (i, j, nt): probability of the most likely derivation of words
    -- i..j (1-based, inclusive) from the nt-th nonterminal, paired with that
    -- derivation's tree, or (0, Nothing) when no such derivation exists.
    best (i, j, nt)
      | i == j =
          let nont = nonterms_array !# nt
              word = sentence !# i
              prob = uprob (nont :- word)
          -- A leaf the grammar cannot produce carries no tree, exactly like
          -- an unparseable span in the binary case below.
          in if prob > 0
               then (prob, Just (RuleU nont word))
               else (0, Nothing)
      | i < j =
          let nont = nonterms_array !# nt
              -- Split point outermost, then rule: this order fixes the
              -- candidate list and therefore which of several equally
              -- probable trees maximumBy keeps.
              candidates =
                [ (p1 * p2 * prob, Just (RuleB nont r1 r2))
                | k <- [i .. j - 1]
                , rule@(_ := (nt1, nt2)) <- binaryrules_start_with m nont
                , let prob = bprob rule
                , prob > 0
                -- A child that is not a known nonterminal cannot be derived,
                -- so the rule is skipped instead of crashing.
                , Just c1 <- [M.lookup nt1 nonterms_index]
                , Just c2 <- [M.lookup nt2 nonterms_index]
                , let (p1, mr1) = chart !# (i, k, c1)
                , p1 > 0
                , r1 <- maybeToList mr1
                , let (p2, mr2) = chart !# (k + 1, j, c2)
                , p2 > 0
                , r2 <- maybeToList mr2
                ]
              byProb a b = compare (fst a) (fst b)
          in if null candidates
               then (0, Nothing)
               else maximumBy byProb candidates
      | otherwise =
          error ("i should not be greater than j, i=" ++ show i ++ " j=" ++ show j)

    -- Memo table for 'best'. Relies on the array being lazy in its elements:
    -- an entry is computed only when first demanded, and entries with i > j
    -- are never demanded.
    chart =
      let bnds = ((1, 1, 1), (sent_len, sent_len, nont_len))
      in listArray bnds (map best (range bnds))

    uprob r@(n :- _) =
      fromIntegral (unaryrule_count m r) / fromIntegral (nonterminal_count m n) :: Double
    bprob r@(n := _) =
      fromIntegral (binaryrule_count m r) / fromIntegral (nonterminal_count m n) :: Double
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.