Skip to content

Instantly share code, notes, and snippets.

@cympfh
Created March 8, 2014 09:35
Show Gist options
  • Save cympfh/9427895 to your computer and use it in GitHub Desktop.
Save cympfh/9427895 to your computer and use it in GitHub Desktop.
決定木のノードを一つだけ作る (Build just a single node of a decision tree)
import Data.List
-- | Class label attached to each sample. "Ayame" is Japanese for iris;
-- what SV and C stand for is not stated in this file.
data Ayame
  = SV
  | C
  deriving (Show, Eq)
-- | Training set: integer 2-D points, each labelled with an 'Ayame' class.
datum :: [((Int, Int), Ayame)]
datum =
  [ ((5, 4), SV)
  , ((2, 3), C)
  , ((9, 2), SV)
  , ((31, 39), SV)
  , ((20, 20), C)
  ]
-- | Split 'datum' once with the rule chosen by 'findRule', then print:
-- the impurity of the whole set, the two resulting groups, and the
-- impurity of each group.
main :: IO ()
main = do
  let (left, right) = partition (findRule datum) datum
  print (gib datum)
  print (left, right)
  print (gib left, gib right)
-- | Split a list by a predicate: items satisfying @f@ go left, the rest go
-- right. Unlike 'partition', each side comes out in REVERSE input order,
-- because the original implementation accumulated by consing. Reversing the
-- input first reproduces that ordering exactly.
doRule :: (a -> Bool) -> [a] -> ([a], [a])
doRule f ls = partition f (reverse ls)
-- | Choose the best single-coordinate threshold split of a labelled dataset.
--
-- Each coordinate is handled the same way:
--
-- * The samples are sorted on that coordinate.
-- * Every boundary between two adjacent, distinct values is scored as
--   @k * gib left + (n - k) * gib right@, where @k@ is the left size.
-- * The lowest (score, threshold) pair wins.
--
-- The threshold is the integer midpoint of the two neighbouring values. The
-- returned predicate sends a sample to the left side when its coordinate is
-- @<=@ that threshold. When both coordinates score equally, the second
-- coordinate is used (as in the original comparison @g1 < g2@).
--
-- Only proper splits (both sides non-empty) are considered. The previous
-- version also scored empty-sided splits, whose threshold evaluated
-- @last []@ / @head []@ and crashed on ties, e.g. single-class input.
-- If neither coordinate has two distinct values, no split exists and the
-- predicate is @const True@: every sample goes left.
findRule ::
  (Integral a, Integral b) =>
  [((a, b), Ayame)] ->
  ((a, b), c) ->
  Bool
findRule ls =
  case (bestSplit fst ls, bestSplit snd ls) of
    (Just (g1, t1), Just (g2, t2))
      | g1 < g2 -> byFst t1
      | otherwise -> bySnd t2
    (Just (_, t1), Nothing) -> byFst t1
    (Nothing, Just (_, t2)) -> bySnd t2
    (Nothing, Nothing) -> const True
  where
    -- Threshold predicates on the first / second coordinate.
    byFst t x = fst (fst x) <= t
    bySnd t x = snd (fst x) <= t

    -- Lowest (score, threshold) over the proper boundaries of the feature
    -- selected by @f@, or Nothing if it has fewer than two distinct values.
    bestSplit f xs =
      let sorted = sortOn (f . fst) xs
          vals = map (f . fst) sorted
          n = length sorted
          -- Size-weighted impurity of splitting off the first k samples.
          score k =
            let (l, r) = splitAt k sorted
             in k * gib l + (n - k) * gib r
          candidates =
            [ (score k, div (a + b) 2)
            | (k, a, b) <- zip3 [1 ..] vals (drop 1 vals)
            , -- A threshold cannot separate equal neighbours, so such a
              -- boundary would score a partition the predicate never makes.
              a /= b
            ]
       in case candidates of
            [] -> Nothing
            cs -> Just (minimum cs)
-- | Impurity score of a labelled list: with @sv@ samples labelled 'SV' and
-- @c@ samples not labelled 'SV' (i.e. 'C'), the score is
-- @sv * c * (sv + c)@. It is zero exactly when one of the two counts is
-- zero, i.e. the list is pure or empty.
gib :: [(a, Ayame)] -> Int
gib ls = sv * c * (sv + c)
  where
    (svs, cs) = partition ((== SV) . snd) ls
    sv = length svs
    c = length cs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment