Created April 25, 2010 22:25
-
-
Save scvalex/378780 to your computer and use it in GitHub Desktop.
Owari in Haskell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Main where | |
import Data.Array.IArray
import Data.Array.Unboxed
import Data.Function
import Data.List
import Debug.Trace ( trace )
import System.IO.Unsafe ( unsafePerformIO )
import Text.Printf ( printf )
import Text.Read ( readMaybe )
---------------------------------------- | |
-- DATA TYPES | |
---------------------------------------- | |
-- | The two players. Which of them moved first (and therefore owns bowls
-- 1..6) is recorded separately in 'firstPlayer'.
data Player
  = Max
  | Min
  deriving (Eq, Show)
-- | The opponent of the given player.
otherPlayer :: Player -> Player
otherPlayer p = case p of
  Max -> Min
  Min -> Max
-- | A single position in the game.
data GameState = GameState
  { firstPlayer :: !Player            -- ^ player who moved first; owns bowls 1..6 (see 'isValidMove')
  , player      :: !Player            -- ^ player whose turn it is
  , board       :: !(UArray Int Int)  -- ^ stone count per bowl, indexed 1..12
  , score       :: !(Player -> Int)   -- ^ stones captured so far by each player
  , value       :: !Int               -- ^ evaluation from the point of view of 'player'
  }
-- | Two-line board rendering.
--
-- The top line is the second player's side: their score, then bowls 12 down
-- to 7. The bottom line is the first player's side: their score, then bowls
-- 1 to 6. A '>' marks the side whose turn it is.
instance Show GameState where
  show gs = unlines
    [ row topMark (score gs second) (reverse (cells 7 12))
    , row bottomMark (score gs first) (cells 1 6)
    ]
    where
      first = firstPlayer gs
      second = otherPlayer first
      firstToMove = first == player gs
      topMark = if firstToMove then " " else ">"
      bottomMark = if firstToMove then ">" else " "
      -- each bowl in [lo, hi] rendered right-aligned in two columns
      cells :: Int -> Int -> [String]
      cells lo hi = [printf "%2d" n | (i, n) <- assocs (board gs), lo <= i, i <= hi]
      row :: String -> Int -> [String] -> String
      row mark pts bowlCells = printf "%s %d | %s" mark pts (unwords bowlCells)
-- | Two states are equal when the board, the side to move, and the first
-- player agree. 'score' (a function) and 'value' are not compared.
instance Eq GameState where
  a == b = and
    [ player a == player b
    , firstPlayer a == firstPlayer b
    , board a == board b
    ]
-- | States are ordered by their evaluation 'value' alone.
--
-- NOTE(review): this is not consistent with the 'Eq' instance, which compares
-- boards instead; @compare a b == EQ@ does not imply @a == b@.
instance Ord GameState where
  compare = compare `on` value
-- | A game tree whose children are generated on demand by 'completeNode'.
data GameTree
  = Node
      { state      :: !GameState             -- ^ position at this node
      , nextStates :: !(Array Int GameTree)  -- ^ children, indexed by the move (1..12) leading to them
      , complete   :: !Bool                  -- ^ whether 'nextStates' has been filled in by 'completeNode'
      }
  | Invalid  -- ^ child slot for a move that is not legal
  | Empty    -- ^ child slot that has not been generated yet
  deriving (Eq)
-- | A total order on trees: 'Invalid' < 'Empty' < any 'Node'.
--
-- Nodes are compared by their 'state' (i.e. by 'value'). The previous
-- definition returned 'LT' whenever the left argument was 'Empty' or
-- 'Invalid', so @compare Empty Empty@ was 'LT' and @compare Empty Invalid@
-- and @compare Invalid Empty@ were both 'LT'. That broke reflexivity and
-- antisymmetry.
instance Ord GameTree where
  compare Invalid Invalid = EQ
  compare Invalid _       = LT
  compare _       Invalid = GT
  compare Empty   Empty   = EQ
  compare Empty   _       = LT
  compare _       Empty   = GT
  compare a       b       = compare (state a) (state b)
-- | A fresh, not-yet-expanded node: all twelve child slots are 'Empty'.
newNode :: GameState -> GameTree
newNode gs = Node
  { state      = gs
  , nextStates = listArray (1, 12) (replicate 12 Empty)
  , complete   = False
  }
-- | True for real positions, False for the 'Invalid' / 'Empty' placeholders.
isNode :: GameTree -> Bool
isNode t = case t of
  Node {} -> True
  _       -> False
-- | Short label for a child slot: "X" for an illegal move, "_" for an
-- unexpanded slot, otherwise the node's evaluation.
getTreeValueString :: GameTree -> String
getTreeValueString t = case t of
  Invalid -> "X"
  Empty   -> "_"
  _       -> show (value (state t))
-- | Renders a node as its current board (tab-indented) followed by the
-- values of its children, laid out like the board: moves 12..7 on the first
-- row, 1..6 on the second.
instance Show GameTree where
  show Empty   = "-?-"
  show Invalid = "-X-"
  show node = unlines
    [ "Current:"
    , indentLines (show (state node))
    , "Next:"
    , '\t' : unwords (reverse (slots 7 12))
    , '\t' : unwords (slots 1 6)
    ]
    where
      indentLines = intercalate "\n" . map ('\t' :) . lines
      slots :: Int -> Int -> [String]
      slots lo hi =
        [ printf "%2s" (getTreeValueString t)
        | (i, t) <- assocs (nextStates node), lo <= i, i <= hi ]
---------------------------------------- | |
-- CONSTANTS | |
---------------------------------------- | |
-- | Starting position: four stones in each of the twelve bowls, nothing
-- captured, 'Max' both first and to move.
initialState :: GameState
initialState = GameState
  { board       = listArray (1, 12) (replicate 12 4)
  , score       = \_ -> 0
  , value       = 0
  , firstPlayer = Max
  , player      = Max
  }
---------------------------------------- | |
-- OWARI LOGIC | |
---------------------------------------- | |
-- | Pass the turn without touching the board, recomputing 'value' from the
-- point of view of the player who now has the move.
doEmptyMove :: GameState -> GameState
doEmptyMove gs = gs { player = next, value = score gs next - score gs current }
  where
    current = player gs
    next    = otherPlayer current
-- | Apply move @m@ (a bowl index, 1..12) for the player to move.
--
-- Returns 'Nothing' when 'isValidMove' rejects the move. Otherwise:
--
-- * The stones in bowl @m@ are sown one per bowl, starting at @m + 1@ and
--   wrapping from 12 back to 1.
-- * Every bowl hit by the sowing (other than @m@) that held exactly one stone
--   before the move is captured.
-- * A captured bowl is emptied, and the mover scores its original stone plus
--   every stone sown into it.
--
-- The new 'value' is from the perspective of the new player to move:
--
-- * 1000 if that player already has at least 24;
-- * -1000 if the mover has just reached 24 (previously this winning case was
--   not detected, because only the mover's score can change in a move);
-- * otherwise the score difference.
doMove :: GameState -> Int -> Maybe GameState
doMove gs m
  | isValidMove gs m = Just $ gs { player = newPlayer
                                 , board  = newBoard'
                                 , score  = newScore
                                 , value  = newStateValue
                                 }
  | otherwise = Nothing
  where
    mover     = player gs
    newPlayer = otherPlayer mover
    -- indices of the bowls hit while sowing @stones@ stones starting at bowl @i@
    hits :: Int -> Int -> [Int]
    hits 0 _      = []
    hits stones i = i : hits (stones - 1) (normalizedIndex (i + 1))
    hit = hits (board gs ! m) (normalizedIndex (m + 1))
    -- indices of the bowls that were captured
    captured = nub [i | i <- hit, board gs ! i == 1, i /= m]
    newScore p
      | p == newPlayer = score gs p
      | otherwise      = score gs p + length captured + length (filter (`elem` captured) hit)
    newBoard  = board gs // ((m, 0) : [(i, 0) | i <- captured])
    newBoard' = accum (+) newBoard [(i, 1) | i <- hit, i `notElem` captured]
    normalizedIndex i = (i - 1) `mod` 12 + 1
    newValue = newScore newPlayer - newScore mover
    newStateValue
      | newScore newPlayer >= 24 = 1000
      | newScore mover >= 24     = -1000
      | otherwise                = newValue
-- | Follow the edge for move @m@ to the corresponding child slot.
-- Fails on 'Empty' / 'Invalid' (no 'nextStates' field).
advanceTree :: GameTree -> Int -> GameTree
advanceTree gt = (nextStates gt !)
-- | Whether bowl @m@ is a legal move for the player to move.
--
-- The bowl must lie on that player's side (1..6 for 'firstPlayer', 7..12 for
-- the other player) and be non-empty. The range test runs first, so the
-- array lookup is never out of bounds.
isValidMove :: GameState -> Int -> Bool
isValidMove gs m = lo <= m && m <= hi && board gs ! m /= 0
  where
    -- a trailing 'otherwise' makes the guards visibly exhaustive; the old
    -- '==' / '/=' pair triggered -Wincomplete-patterns
    (lo, hi)
      | player gs == firstPlayer gs = (1, 6)
      | otherwise                   = (7, 12)
-- | True when at least one bowl is a legal move for the player to move.
canMove :: GameState -> Bool
canMove gs = or [isValidMove gs m | m <- [1 .. 12]]
---------------------------------------- | |
-- AI | |
---------------------------------------- | |
-- | Generate all children of a node and mark it complete.
--
-- Child @m@ is the result of 'doMove' for bowl @m@, or 'Invalid' if the move
-- is illegal. When no move is legal, child 1 is a pass ('doEmptyMove') and
-- the rest are 'Invalid'. Any children already present are discarded.
completeNode :: GameTree -> GameTree
completeNode Empty        = Empty
completeNode Invalid      = Invalid
completeNode (Node s _ _) = Node s children True
  where
    children
      | canMove s = listArray (1, 12) [maybe Invalid newNode (doMove s m) | m <- [1 .. 12]]
      | otherwise = listArray (1, 12) (newNode (doEmptyMove s) : replicate 11 Invalid)
-- | Expand the tree (via 'completeNode') so that every node above depth @d@
-- has its children generated.
ensureDepth :: Int -> GameTree -> GameTree
ensureDepth d gt = case gt of
  Empty   -> Empty
  Invalid -> Invalid
  _ | d == 0      -> gt
    | complete gt -> gt { nextStates = amap (ensureDepth (d - 1)) (nextStates gt) }
    | otherwise   -> ensureDepth d (completeNode gt)
-- | Number of real 'Node's in the tree; placeholders count as zero.
nodesIn :: GameTree -> Int
nodesIn Empty          = 0
nodesIn Invalid        = 0
nodesIn (Node _ ns _)  = foldr (\t acc -> nodesIn t + acc) 1 (elems ns)
-- | Plain negamax search to depth @depth@.
--
-- Each searched node's 'value' is replaced by the maximum of its children's
-- negated values. Leaves at depth 0 keep the value assigned by 'doMove'.
minimax :: Int -> GameTree -> GameTree
minimax depth tree = case tree of
  Empty   -> Empty
  Invalid -> Invalid
  _ | depth == 0          -> tree
    | not (complete tree) -> minimax depth (completeNode tree)
    | otherwise ->
        let children = amap (minimax (depth - 1)) (nextStates tree)
            best     = maximum [negate (value (state c)) | c <- elems children, isNode c]
        in Node ((state tree) { value = best }) children True
-- | The move whose child has the lowest value.
--
-- Child values are from the opponent's point of view, so the lowest is best
-- for the player to move. Ties go to the lowest move index.
bestMove :: GameTree -> Int
bestMove gt = fst (minimumBy (compare `on` snd) candidates)
  where
    candidates = [(i, t) | (i, t) <- assocs (nextStates gt), isNode t]
-- | Fail-soft alpha-beta search (negamax form) to depth @depth@.
--
-- Each searched node's 'value' is replaced by its backed-up score: the
-- maximum of its searched children's negated values. The root window is
-- (-1000, 1000), matching the magnitude of the win score used by 'doMove'.
-- Children skipped after a beta cut-off are left untouched in 'nextStates'.
--
-- Fixes relative to the previous version:
--
-- * Alpha is raised with @max alpha v@. Previously it was overwritten with
--   the last child's score, so it could decrease.
-- * The backed-up value only considers children searched in this call, so
--   stale siblings skipped by a cut-off cannot inflate it.
alphabeta :: Int -> GameTree -> GameTree
alphabeta depth root = go depth root (-1000) 1000
  where
    go :: Int -> GameTree -> Int -> Int -> GameTree
    go _ Empty _ _   = Empty
    go _ Invalid _ _ = Invalid
    go 0 gt _ _      = gt
    go d gt a b
      | complete gt =
          let searched = goChildren (filter (isNode . snd) . assocs $ nextStates gt) a
              nts      = accum (\_ x -> x) (nextStates gt) searched
              -- non-empty: a complete node always has at least one Node child
              newValue = maximum [negate . value . state $ t | (_, t) <- searched]
          in Node ((state gt) { value = newValue }) nts True
      | otherwise = go d (completeNode gt) a b
      where
        -- search children left to right, stopping at the first beta cut-off
        goChildren :: [(Int, GameTree)] -> Int -> [(Int, GameTree)]
        goChildren [] _ = []
        goChildren ((i, t) : ts) alpha =
          let nt     = go (d - 1) t (-b) (-alpha)
              v      = negate . value . state $ nt
              alpha' = max alpha v
          in (i, nt) : if b <= alpha' then [] else goChildren ts alpha'
---------------------------------------- | |
-- UI | |
---------------------------------------- | |
-- | Entry point: ask who goes first, then start the advisory game loop with a
-- base search depth of 11.
main :: IO ()
main = do
  putStrLn "Shall I go first? [y/n]"
  -- NOTE(review): both branches of the original answer check started the
  -- same game, so the answer is read but currently does not affect play.
  _ <- getLine
  gameLoop 1 11 (newNode initialState { player = Max })
-- | Interactive loop.
--
-- Stops and prints the final board once either player has 24 or more. On
-- each other turn it searches with 'alphabeta' (the depth grows by one every
-- four moves), prints the tree, recommends a move, reads the user's move, and
-- recurses on the chosen child. When no move is legal, it follows child 1
-- (the pass built by 'completeNode').
--
-- Input that is not a number is now re-prompted instead of crashing, as the
-- partial 'read' did.
gameLoop :: Int -> Int -> GameTree -> IO ()
gameLoop moveNum depth gt
  | score s (firstPlayer s) >= 24                = announce "First player WON\n\nFinal state:\n"
  | score s (otherPlayer $ firstPlayer s) >= 24  = announce "Second player WON\n\nFinal state:\n"
  | otherwise = do
      printf "Computing (%d)...\n" searchDepth
      let gt' = alphabeta searchDepth gt
      print gt'
      if canMove s
        then do
          printf "I recommend you move %d\n\n" (bestMove gt')
          m <- getMove
          gameLoop (moveNum + 1) depth $ advanceTree gt' m
        else do
          printf "You can't move\n"
          gameLoop (moveNum + 1) depth $ advanceTree gt' 1 -- empty move
  where
    s = state gt
    searchDepth = depth + moveNum `div` 4
    announce msg = putStr msg >> print s
    -- keep asking until the user enters a legal bowl number
    getMove :: IO Int
    getMove = do
      printf "What is your move?\n"
      line <- getLine
      case readMaybe line of
        Just m | isValidMove s m -> return m
        Just m  -> printf "Invalid move: %d\n" m >> getMove
        Nothing -> printf "Invalid move: %s\n" line >> getMove
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.