Created
July 23, 2020 21:25
-
-
Save bananabrick/08144fad41a80f089d85eee769b4c99b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Vea( | |
Var(..), Row, Figure, | |
pp, buildFigure, normalize, restrict, | |
multiply, sumOut, keepBest | |
) where | |
import Data.List (intercalate, nub)
import Data.Maybe (mapMaybe)
import Data.Monoid ((<>))

import qualified Data.Map as Map
import qualified Data.Set as Set

import Utils
-- | A discrete variable identified by its 'name'.
--
-- 'buildFigure' rejects figures in which two variables share a name or a
-- domain lists the same value twice.
data Var = Var
  { domain :: [Int]  -- ^ Possible values allowed by the variable.
  , name   :: String -- ^ Identifier used to look the variable up.
  } deriving (Eq, Show)

-- | Orders variables by 'name' first, then by 'domain'.
--
-- The domain takes part in the comparison so that this instance agrees with
-- the derived 'Eq': @compare v1 v2 == EQ@ exactly when @v1 == v2@. Comparing
-- by name alone would let two variables that are not equal collapse into a
-- single key of a Map or Set.
instance Ord Var where
  compare v1 v2 = compare (name v1, domain v1) (name v2, domain v2)
-- | One row of a figure: an assignment of values to variables together with
-- the number attached to that assignment.
data Row = Row
  { value    :: Double          -- ^ Number attached to this assignment.
  , mappings :: Map.Map Var Int -- ^ Value assigned to each variable.
  } deriving (Eq, Show)

-- | Combines two rows by multiplying their values and merging their
-- assignments.
--
-- 'Map.union' is left-biased: if both rows assign the same variable, the
-- left row's value is kept. 'multiply' only combines rows that agree on the
-- variables the two figures share, so nothing is lost there.
--
-- Since base-4.11 'Semigroup' is a superclass of 'Monoid', so this instance
-- is required for the 'Monoid' instance below (and for '<>' in 'multiply').
instance Semigroup Row where
  Row val1 map1 <> Row val2 map2 = Row (val1 * val2) (Map.union map1 map2)

-- | The identity row: value 1 and no assignments.
instance Monoid Row where
  mempty = Row { value = 1, mappings = Map.empty }
-- | A table over a list of variables.
data Figure = Figure
  { variables :: [Var] -- ^ Variables, in the order they are printed.
  , rows      :: [Row] -- ^ The table's rows.
  } deriving (Eq, Show)
-- | Cross product of the variables' domains, delegated to 'crossProduct'
-- from "Utils".
crossVars :: [Var] -> [[Int]]
crossVars = crossProduct . map domain
-- | Pairs each variable with the matching entry of an assignment (by
-- position) and attaches the given value.
rowValueToRow :: [Var] -> ([Int], Double) -> Row
rowValueToRow vars (assignment, weight) =
  Row { value = weight, mappings = Map.fromList (zip vars assignment) }
-- | Builds one 'Row' per assignment, pairing the i-th assignment with the
-- i-th value. As with 'zip', surplus elements of the longer list are ignored.
valuesToRow :: [Var] -> [Double] -> [[Int]] -> [Row]
valuesToRow vars values rows = zipWith toRow rows values
  where
    toRow assignment v = rowValueToRow vars (assignment, v)
-- | Renders a row as space-separated fields: the value assigned to each
-- listed variable, in list order, followed by the row's own value.
-- Variables the row does not assign are skipped.
pprintRow :: [Var] -> Row -> String
pprintRow vars row = unwords (cells ++ [show (value row)])
  where
    cells = [show assigned | Just assigned <- map (`Map.lookup` mappings row) vars]
-- | Renders a figure as a header line of variable names followed by one line
-- per row, joined with newlines (no trailing newline).
pprintFigure :: Figure -> String
pprintFigure figure = intercalate "\n" (header : body)
  where
    vars   = variables figure
    header = unwords (map name vars)
    body   = map (pprintRow vars) (rows figure)
-- | Smart constructor for 'Figure'.
--
-- Expects one value per element of @crossVars vars@, in that order.
-- Returns 'Nothing' when the number of values does not match, when two
-- variables share a name, or when a variable's domain repeats a value.
buildFigure :: [Var] -> [Double] -> Maybe Figure
buildFigure vars values
  | not sizesMatch      = Nothing
  | not namesDistinct   = Nothing
  | not domainsDistinct = Nothing
  | otherwise           = Just (Figure vars (valuesToRow vars values combos))
  where
    combos          = crossVars vars
    sizesMatch      = length values == length combos
    namesDistinct   = allDistinct (map name vars)
    domainsDistinct = all (allDistinct . domain) vars
    allDistinct xs  = Set.size (Set.fromList xs) == length xs
-- Operations on Figures: | |
-- Normalize a figure. | |
-- Restrict a figure to a value. | |
-- Keep best rows in a figure. | |
-- Sum out a figure. | |
-- Multiply two figures. | |
-- | Rescales a figure so that its row values sum to 1.
--
-- When the values sum to 0 there is no meaningful rescaling: dividing would
-- fill the table with NaN/Infinity. In that case the figure is returned
-- unchanged. An empty figure also sums to 0 and is returned unchanged, as
-- before.
normalize :: Figure -> Figure
normalize figure@(Figure vars rs)
  | total == 0 = figure
  | otherwise  = Figure vars (map scale rs)
  where
    total   = sum (map value rs)
    scale r = r { value = value r / total }
-- | True when the row assigns exactly @varVal@ to @var@. False when it
-- assigns a different value or does not mention @var@ at all.
varValInRow :: Var -> Int -> Row -> Bool
varValInRow var varVal row = Map.lookup var (mappings row) == Just varVal
-- | Looks a variable up by name.
--
-- Returns 'Nothing' when no variable has that name, and also when several
-- do. Several matches would break the unique-names invariant that
-- 'buildFigure' checks.
varFromName :: String -> [Var] -> Maybe Var
varFromName varName vars =
  case [v | v <- vars, name v == varName] of
    [match] -> Just match
    _       -> Nothing
-- | True if the value is one of the variable's allowed values.
verifyDomain :: Var -> Int -> Bool
verifyDomain var varVal = any (== varVal) (domain var)
-- | Removes a variable's assignment from every row. Row values are untouched.
dropFromRows :: Var -> [Row] -> [Row]
dropFromRows var = map dropVar
  where
    dropVar r = r { mappings = Map.delete var (mappings r) }
-- | Removes a variable from a figure's variable list and from all of its rows.
-- Rows are not merged, so rows whose assignments now coincide stay separate.
dropFromFigure :: Var -> Figure -> Figure
dropFromFigure var (Figure vars rs) = Figure remaining (dropFromRows var rs)
  where
    remaining = [v | v <- vars, v /= var]
-- | Keeps only the rows in which the named variable takes the given value.
-- It then drops that variable, since it now has exactly one value.
--
-- Returns 'Nothing' if no single variable has that name, or if the value is
-- not in its domain.
restrict :: String -> Int -> Figure -> Maybe Figure
restrict varName varVal (Figure vars rs) = do
  var <- varFromName varName vars
  if verifyDomain var varVal
    then Just (dropFromFigure var (Figure vars (filter (varValInRow var varVal) rs)))
    else Nothing
-- | Eliminates the named variable. Among rows that agree on every remaining
-- variable, it keeps the largest value. Like 'sumOut', but maximising
-- instead of adding.
--
-- Returns 'Nothing' if no single variable has that name.
keepBest :: String -> Figure -> Maybe Figure
keepBest = eliminateVar maximum

-- | Eliminates the named variable by adding together the values of rows that
-- agree on every remaining variable.
--
-- Returns 'Nothing' if no single variable has that name.
sumOut :: String -> Figure -> Maybe Figure
sumOut = eliminateVar sum

-- | Shared implementation of 'keepBest' and 'sumOut'. It drops the named
-- variable, groups the rows by their remaining assignments with
-- 'groupByProperty', and combines each group's values with @combine@.
eliminateVar :: ([Double] -> Double) -> String -> Figure -> Maybe Figure
eliminateVar combine varName figure@(Figure vars _) =
  collapse <$> varFromName varName vars
  where
    collapse var =
      let Figure newVars newRows = dropFromFigure var figure
          -- NOTE(review): assumes groupByProperty never yields an empty
          -- group, otherwise @maximum@ (keepBest) would fail. TODO confirm.
          merged =
            [ Row { value = combine (map value members), mappings = m }
            | (m, members) <- groupByProperty mappings newRows
            ]
      in Figure newVars merged
-- | True when two rows assign the same values to the listed variables.
-- Row values and unlisted variables are ignored. A listed variable assigned
-- in one row but not the other counts as a difference.
sketchyEquals :: [Var] -> Row -> Row -> Bool
sketchyEquals vars r1 r2 = onVars r1 == onVars r2
  where
    onVars r = Map.filterWithKey (\k _ -> k `elem` vars) (mappings r)
-- | True when, within each group of same-named variables drawn from both
-- figures (as grouped by 'groupByProperty'), all members share one domain.
-- Domains are compared as lists, so value order matters.
validMultiply :: Figure -> Figure -> Bool
validMultiply f1 f2 = all hasSingleDomain sameNameGroups
  where
    sameNameGroups = map snd (groupByProperty name (variables f1 ++ variables f2))
    hasSingleDomain grp = Set.size (Set.fromList (map domain grp)) == 1
-- | Multiplies two figures as a natural join. Every pair of rows that agree
-- on the variables both figures share is combined with '<>', which
-- multiplies the values and merges the assignments.
--
-- Returns 'Nothing' when a variable name appears in both figures with
-- different domains. The result would then hold two distinct variables with
-- the same name, breaking the unique-names invariant.
multiply :: Figure -> Figure -> Maybe Figure
multiply figure1@(Figure vars1 rows1) figure2@(Figure vars2 rows2)
  | validMultiply figure1 figure2 = Just (Figure newVars newRows)
  | otherwise                     = Nothing
  where
    commonVars =
      Set.toList $ Set.intersection (Set.fromList vars1) (Set.fromList vars2)
    -- Matching the two-element list inside the comprehension is total: any
    -- other shape is skipped instead of crashing, unlike the lambda pattern
    -- @\[r1, r2] -> ...@. The order produced by crossProduct is preserved.
    newRows =
      [ r1 <> r2
      | [r1, r2] <- crossProduct [rows1, rows2]
      , sketchyEquals commonVars r1 r2
      ]
    -- Variables of both figures in order, keeping each one's first occurrence.
    newVars = nub (vars1 ++ vars2)
-- | Prints a figure to standard output as rendered by 'pprintFigure'.
pp :: Figure -> IO ()
pp figure = putStrLn (pprintFigure figure)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment