{- GitHub gist page residue (not part of the module):
Skip to content

Instantly share code, notes, and snippets.

@athas
Created August 9, 2016 06:24
Show Gist options
  • Save athas/5c2dd0fdb669eecce892adaa99958bd9 to your computer and use it in GitHub Desktop.
Save athas/5c2dd0fdb669eecce892adaa99958bd9 to your computer and use it in GitHub Desktop.
-}
{-# LANGUAGE OverloadedStrings #-}
-- | This module defines an efficient value representation, together
-- with parsing and comparison functions for it.  It exists because the
-- standard Futhark parser cannot cope with large values (such as arrays
-- that are tens of megabytes in size).  The representation defined
-- here does not support tuples, so do not use them as input/output
-- for your test programs.
module Futhark.Test.Values
( Value
-- * Reading Values
, readValues
-- * Comparing Values
, compareValues
, Mismatch
, explainMismatch
)
where
import Control.Monad
import Control.Arrow ((***))
import Control.Monad.ST
import qualified Data.Array as A
import Data.Maybe
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Char (isSpace)
import qualified Data.Vector.Unboxed.Mutable as UMVec
import qualified Data.Vector.Unboxed as UVec
import Data.Vector.Generic (freeze)
import qualified Data.Text as T
import Prelude
import qualified Futhark.Representation.AST.Syntax.Core as F
import Futhark.Representation.Primitive
(PrimType(..), IntType(..), FloatType(..), PrimValue)
import Language.Futhark.Parser.Lexer
import qualified Futhark.Util.Pretty as PP
import Futhark.Representation.AST.Attributes.Constants (IsValue(..))
import Futhark.Representation.AST.Pretty ()
-- | Shorthand for a mutable unboxed vector in the 'ST' monad with state
-- thread @s@; used as the growable element buffer while parsing.
type STVector s = UMVec.STVector s
-- | Shorthand for an immutable unboxed vector.
type Vector = UVec.Vector
-- | An efficiently represented Futhark value.  Each constructor holds
-- the shape (one length per dimension, so a scalar has an empty shape)
-- followed by all elements stored flat in a single unboxed vector, in
-- the order the parser encounters them.
data Value = Int8Value (Vector Int) (Vector Int8)
           | Int16Value (Vector Int) (Vector Int16)
           | Int32Value (Vector Int) (Vector Int32)
           | Int64Value (Vector Int) (Vector Int64)
           | Float32Value (Vector Int) (Vector Float)
           | Float64Value (Vector Int) (Vector Double)
           | BoolValue (Vector Int) (Vector Bool)
           deriving Show
-- | Values are printed through the core IR's array printer, by way of
-- 'pprAsCoreValue', tagged with the element type of each constructor.
instance PP.Pretty Value where
  ppr v = case v of
    Int8Value shape vs    -> pprAsCoreValue (IntType Int8) shape vs
    Int16Value shape vs   -> pprAsCoreValue (IntType Int16) shape vs
    Int32Value shape vs   -> pprAsCoreValue (IntType Int32) shape vs
    Int64Value shape vs   -> pprAsCoreValue (IntType Int64) shape vs
    Float32Value shape vs -> pprAsCoreValue (FloatType Float32) shape vs
    Float64Value shape vs -> pprAsCoreValue (FloatType Float64) shape vs
    BoolValue shape vs    -> pprAsCoreValue Bool shape vs
-- | Pretty-print a flat value by rebuilding it as a core-IR 'F.ArrayVal'
-- (elements indexed from 0, element type @bt@, the given shape) and
-- reusing that type's printer.
pprAsCoreValue :: (UVec.Unbox v, IsValue v) =>
                  PrimType -> Vector Int -> Vector v -> PP.Doc
pprAsCoreValue bt shape vs = PP.ppr coreVal
  where coreVal    = F.ArrayVal elemsArr bt dims
        elemsArr   = A.listArray (0, UVec.product shape - 1) primVals
        primVals   = map value (UVec.toList vs)
        dims       = UVec.toList shape
-- | The primitive element type of a value.
valueType :: Value -> PrimType
valueType v = case v of
  Int8Value{}    -> IntType Int8
  Int16Value{}   -> IntType Int16
  Int32Value{}   -> IntType Int32
  Int64Value{}   -> IntType Int64
  Float32Value{} -> FloatType Float32
  Float64Value{} -> FloatType Float64
  BoolValue{}    -> Bool
-- | The shape of a value as a list of dimension lengths.
valueShape :: Value -> [Int]
valueShape v = UVec.toList $ case v of
  Int8Value shape _    -> shape
  Int16Value shape _   -> shape
  Int32Value shape _   -> shape
  Int64Value shape _   -> shape
  Float32Value shape _ -> shape
  Float64Value shape _ -> shape
  BoolValue shape _    -> shape
-- The parser
-- | 'dropRestOfLine' discards everything up to and including the next
-- newline (or everything, if there is none).  'dropSpaces' skips
-- leading whitespace and any number of @--@ line comments.
dropRestOfLine, dropSpaces :: T.Text -> T.Text
dropRestOfLine t = T.drop 1 (T.dropWhile (/= '\n') t)
dropSpaces t
  | "--" `T.isPrefixOf` stripped = dropSpaces (dropRestOfLine stripped)
  | otherwise                    = stripped
  where stripped = T.dropWhile isSpace t
-- | A reader for a single scalar element.  It returns the parsed element
-- (or 'Nothing' if the input does not start with one) together with the
-- input that follows it.
type ReadValue v = T.Text -> (Maybe v, T.Text)
-- | @symbol c t@ succeeds when @t@ starts with the character @c@, and
-- returns the input after @c@ with whitespace and comments skipped.
symbol :: Char -> T.Text -> Maybe T.Text
symbol c t = case T.uncons t of
  Just (c', rest) | c' == c -> Just (dropSpaces rest)
  _                         -> Nothing
-- | State threaded through the array readers: (number of buffer slots
-- used so far, shape discovered so far, element buffer, remaining
-- input).
type State s v = (Int, Vector Int, STVector s v, T.Text)
-- | Read the comma-separated elements of an array whose elements have
-- rank @r@; the opening bracket has already been consumed.  @j@ is the
-- 1-based position of the element about to be read.  On success,
-- returns the number of elements read, together with the state
-- positioned after the last element (where the closing bracket is
-- expected).
--
-- If no element can be read at the first position (@j == 1@), the array
-- is empty: the original state is returned with a count of 0.  Failing
-- to read an element after a comma is a parse error.  It used to be
-- reported as a count of 0 as well, which let input such as @[1,2,]@
-- produce a zero-length dimension while elements had already been
-- written to the buffer.
readArrayElemsST :: UMVec.Unbox v =>
                    Int -> Int -> ReadValue v -> State s v
                 -> ST s (Maybe (Int, State s v))
readArrayElemsST j r rv s = do
  ms <- readRankedArrayOfST r rv s
  case ms of
    Just (i, shape, arr, t)
      | Just t' <- symbol ',' t ->
          readArrayElemsST (j+1) r rv (i, shape, arr, t')
      | otherwise -> return $ Just (j, (i, shape, arr, t))
    Nothing
      | j == 1    -> return $ Just (0, s)
      | otherwise -> return Nothing
-- | Record @n@ as the length of the dimension that belongs to arrays of
-- rank @d@ (index @length shape - d@).  A dimension still holding a
-- negative placeholder is set to @n@.  Otherwise @n@ must equal the
-- recorded length, and 'Nothing' signals an irregular array.
updateShape :: Int -> Int -> Vector Int -> Maybe (Vector Int)
updateShape d n shape
  | old_n < 0  = Just $ shape UVec.// [(idx, n)]
  | old_n == n = Just shape
  | otherwise  = Nothing
  where idx   = UVec.length shape - d
        old_n = shape UVec.! idx
-- | Make room for a write at index @i@ by doubling the buffer's capacity
-- when it is full.  Assumes elements are written sequentially, so @i@
-- never exceeds the current capacity.
--
-- The growth amount is at least 1.  Growing a zero-capacity buffer by
-- its capacity (0) would leave it full, and the following write would
-- go out of bounds.
growIfFilled :: UVec.Unbox v => Int -> STVector s v -> ST s (STVector s v)
growIfFilled i arr
  | i >= capacity = UMVec.grow arr (max 1 capacity)
  | otherwise     = return arr
  where capacity = UMVec.length arr
-- | Read an array of rank @r@ (a single scalar when @r@ is 0), appending
-- its elements to the buffer in the state and recording dimension
-- lengths in the shape.  Returns 'Nothing' on a parse error or an
-- irregular array.
--
-- The bracket case is only tried for @r > 0@.  Previously, a rank-0
-- position whose element reader failed fell through to it and recursed
-- with negative ranks.  'closeArray' would then call 'updateShape' with
-- a non-positive rank and index the shape vector out of bounds, so
-- malformed input such as @[1, [[]]]@ crashed instead of failing to
-- parse.
readRankedArrayOfST :: UMVec.Unbox v =>
                       Int -> ReadValue v -> State s v
                    -> ST s (Maybe (State s v))
readRankedArrayOfST 0 rv (i, shape, arr, t)
  | (Just v, t') <- rv t = do
      arr' <- growIfFilled i arr
      UMVec.write arr' i v
      return $ Just (i+1, shape, arr', t')
readRankedArrayOfST r rv (i, shape, arr, t)
  | r > 0, Just t' <- symbol '[' t = do
      ms <- readArrayElemsST 1 (r-1) rv (i, shape, arr, t')
      return $ do
        (j, s) <- ms
        closeArray r j s
readRankedArrayOfST _ _ _ =
  return Nothing
-- | Consume the closing bracket of an array of rank @r@ that held @j@
-- elements, and record @j@ as the length of the matching dimension.
closeArray :: Int -> Int -> State s v -> Maybe (State s v)
closeArray r j (i, shape, arr, t) =
  rebuild <$> symbol ']' t <*> updateShape r j shape
  where rebuild rest newShape = (i, newShape, arr, rest)
-- | Parse an array of rank @r@ (a scalar when @r@ is 0) whose elements
-- are read by @rv@.  Returns its shape, its elements in input order, and
-- the remaining input.
readRankedArrayOf :: UMVec.Unbox v =>
                     Int -> ReadValue v -> T.Text -> Maybe (Vector Int, Vector v, T.Text)
readRankedArrayOf r rv t = runST $ do
  buf <- UMVec.new 1024
  -- Every dimension starts at -1, meaning "length not yet known".
  result <- readRankedArrayOfST r rv (0, UVec.replicate r (-1), buf, t)
  forM result $ \(used, shape, buf', rest) -> do
    elems <- freeze (UMVec.slice 0 used buf')
    return (shape, elems, rest)
-- | Whether a character can be part of a scalar value: anything except
-- whitespace, the element separator @,@ and the closing bracket @]@.
-- This does not work for string and character literals.
constituent :: Char -> Bool
constituent c = c /= ',' && c /= ']' && not (isSpace c)
-- | Read an integral element.  The lexeme is the longest prefix of
-- 'constituent' characters.  A plain 'INTLIT' is accepted for any target
-- type, other tokens are recognised by @f@, and a leading 'MINUS'
-- negates the result.  Whitespace and comments after the lexeme are
-- skipped.
--
-- NOTE(review): 'INTLIT' payloads (presumably 'Integer') go through
-- 'fromIntegral', which wraps out-of-range values for fixed-width types
-- rather than rejecting them.  Confirm this is acceptable.
readIntegral :: Integral int => (Token -> Maybe int) -> ReadValue int
readIntegral f t = (lexIntegral lexeme, dropSpaces rest)
  where (lexeme, rest) = T.span constituent t
        lexIntegral s = case scanTokens "" s of
          Right [L _ MINUS, L _ (INTLIT x)] -> Just $ negate $ fromIntegral x
          Right [L _ (INTLIT x)]            -> Just $ fromIntegral x
          Right [L _ tok]                   -> f tok
          Right [L _ MINUS, L _ tok]        -> negate <$> f tok
          _                                 -> Nothing
-- | Read an 'Int8' element: a plain integer literal or an 'I8LIT' token.
readInt8 :: ReadValue Int8
readInt8 = readIntegral $ \tok -> case tok of
  I8LIT x -> Just x
  _       -> Nothing
-- | Read an 'Int16' element: a plain integer literal or an 'I16LIT' token.
readInt16 :: ReadValue Int16
readInt16 = readIntegral $ \tok -> case tok of
  I16LIT x -> Just x
  _        -> Nothing
-- | Read an 'Int32' element: a plain integer literal or an 'I32LIT' token.
readInt32 :: ReadValue Int32
readInt32 = readIntegral $ \tok -> case tok of
  I32LIT x -> Just x
  _        -> Nothing
-- | Read an 'Int64' element: a plain integer literal or an 'I64LIT' token.
readInt64 :: ReadValue Int64
readInt64 = readIntegral $ \tok -> case tok of
  I64LIT x -> Just x
  _        -> Nothing
-- | Read a floating-point element.  A plain 'REALLIT' is accepted for
-- any target type and re-encoded through 'decodeFloat'/'encodeFloat'.
-- Other tokens are recognised by @f@, and a leading 'MINUS' negates the
-- result.  Whitespace and comments after the lexeme are skipped.
readFloat :: RealFloat float => (Token -> Maybe float) -> ReadValue float
readFloat f t = (lexFloat lexeme, dropSpaces rest)
  where (lexeme, rest) = T.span constituent t
        fromDouble x = let (mantissa, expo) = decodeFloat x
                       in encodeFloat mantissa expo
        lexFloat s = case scanTokens "" s of
          Right [L _ MINUS, L _ (REALLIT x)] -> Just $ negate $ fromDouble x
          Right [L _ (REALLIT x)]            -> Just $ fromDouble x
          Right [L _ tok]                    -> f tok
          Right [L _ MINUS, L _ tok]         -> negate <$> f tok
          _                                  -> Nothing
-- | Read a 'Float' element: a plain real literal or an 'F32LIT' token.
readFloat32 :: ReadValue Float
readFloat32 = readFloat $ \tok -> case tok of
  F32LIT x -> Just x
  _        -> Nothing
-- | Read a 'Double' element: a plain real literal or an 'F64LIT' token.
readFloat64 :: ReadValue Double
readFloat64 = readFloat $ \tok -> case tok of
  F64LIT x -> Just x
  _        -> Nothing
-- | Read a boolean element: a lexeme that lexes to exactly 'TRUE' or
-- 'FALSE'.  Whitespace and comments after the lexeme are skipped.
readBool :: ReadValue Bool
readBool t = (lexBool lexeme, dropSpaces rest)
  where (lexeme, rest) = T.span constituent t
        lexBool s = case scanTokens "" s of
          Right [L _ TRUE]  -> Just True
          Right [L _ FALSE] -> Just False
          _                 -> Nothing
-- | Read one value from the front of the input.
--
-- The rank is the number of opening brackets before the first element.
-- The element type comes from the first reader in 'firstReader' that
-- accepts that element, so a plain 'INTLIT' yields an 'Int32Value' and a
-- plain 'REALLIT' a 'Float64Value'.  The whole value is then re-read
-- from the very start of the input with the chosen reader.
readValue :: T.Text -> Maybe (Value, T.Text)
readValue full_t = countBrackets 0 full_t
  where
    countBrackets :: Int -> T.Text -> Maybe (Value, T.Text)
    countBrackets rank t =
      case symbol '[' t of
        Just t' -> countBrackets (rank + 1) t'
        Nothing -> firstReader rank t

    -- Order matters: the first reader that accepts the first element
    -- decides the element type of the whole value.
    firstReader rank t = msum
      [ tryWith readInt32   Int32Value   rank t
      , tryWith readInt8    Int8Value    rank t
      , tryWith readInt16   Int16Value   rank t
      , tryWith readInt64   Int64Value   rank t
      , tryWith readFloat64 Float64Value rank t
      , tryWith readFloat32 Float32Value rank t
      , tryWith readBool    BoolValue    rank t
      ]

    tryWith :: UMVec.Unbox v
            => ReadValue v -> (Vector Int -> Vector v -> Value)
            -> Int -> T.Text -> Maybe (Value, T.Text)
    tryWith rv mk rank t =
      case rv t of
        (Just _, _) -> do
          (shape, elems, rest) <- readRankedArrayOf rank rv full_t
          return (mk shape elems, rest)
        (Nothing, _) -> Nothing
-- | Read every value in the input.  Values may be separated by
-- whitespace and @--@ comments.  Returns 'Nothing' if any value fails
-- to parse.
readValues :: T.Text -> Maybe [Value]
readValues = go . dropSpaces
  where go t
          | T.null t  = Just []
          | otherwise = case readValue t of
              Nothing        -> Nothing
              Just (v, rest) -> fmap (v :) (go rest)
-- Comparisons
-- | Describes how produced values differed from the expected ones.
-- Value positions are 0-based indices into the list of values.
data Mismatch = PrimValueMismatch (Int,Int) PrimValue PrimValue
                -- ^ (value index, flat element index), got, expected.
              | ArrayShapeMismatch Int [Int] [Int]
                -- ^ Value index, got shape, expected shape.
              | TypeMismatch Int PrimType PrimType
                -- ^ Value index, got element type, expected element type.
              | ValueCountMismatch Int Int
                -- ^ Number of values got, number of values expected.
-- | Human-readable rendering of a mismatch, built with 'explainMismatch'
-- except for a value-count mismatch.
instance Show Mismatch where
  show mismatch = case mismatch of
    PrimValueMismatch idx got expected ->
      explainMismatch idx "" got expected
    ArrayShapeMismatch i got expected ->
      explainMismatch i "array of shape " got expected
    TypeMismatch i got expected ->
      explainMismatch i "value of type " got expected
    ValueCountMismatch got expected ->
      concat ["Expected ", show expected, " values, got ", show got]
-- | Render a mismatch at position @i@ as the sentence
-- @Value \<i\> expected \<what\>\<expected\>, got \<got\>@, where @what@
-- is a short description prefixed only to the expected value.
explainMismatch :: (Show i, PP.Pretty a) => i -> String -> a -> a -> String
explainMismatch i what got expected =
  concat [ "Value ", show i
         , " expected ", what, PP.pretty expected
         , ", got ", PP.pretty got
         ]
-- | Compare produced values against expected ones and return the first
-- mismatch, if any.  A difference in the number of values is reported
-- before any individual value is compared.
compareValues :: [Value] -> [Value] -> Maybe Mismatch
compareValues got expected
  | numGot /= numExpected = Just $ ValueCountMismatch numGot numExpected
  | otherwise =
      listToMaybe $ catMaybes $ zipWith3 compareValue [0..] got expected
  where numGot      = length got
        numExpected = length expected
-- | Compare value number @i@ against its expected counterpart.
--
-- The shapes are compared first, then the element types, then the
-- elements one by one, and the first difference is reported.  Integer
-- elements are compared with a tolerance of 1 and booleans with '=='.
-- Floating-point elements use a tolerance that 'tolerance' derives from
-- the expected values.
compareValue :: Int -> Value -> Value -> Maybe Mismatch
compareValue i got_v expected_v
  | valueShape got_v == valueShape expected_v =
      case (got_v, expected_v) of
        (Int8Value _ got_vs, Int8Value _ expected_vs) ->
          compareNum 1 got_vs expected_vs
        (Int16Value _ got_vs, Int16Value _ expected_vs) ->
          compareNum 1 got_vs expected_vs
        (Int32Value _ got_vs, Int32Value _ expected_vs) ->
          compareNum 1 got_vs expected_vs
        (Int64Value _ got_vs, Int64Value _ expected_vs) ->
          compareNum 1 got_vs expected_vs
        (Float32Value _ got_vs, Float32Value _ expected_vs) ->
          compareNum (tolerance expected_vs) got_vs expected_vs
        (Float64Value _ got_vs, Float64Value _ expected_vs) ->
          compareNum (tolerance expected_vs) got_vs expected_vs
        (BoolValue _ got_vs, BoolValue _ expected_vs) ->
          compareGen compareBool got_vs expected_vs
        _ ->
          Just $ TypeMismatch i (valueType got_v) (valueType expected_v)
  | otherwise =
      Just $ ArrayShapeMismatch i (valueShape got_v) (valueShape expected_v)
  where compareNum tol = compareGen $ compareElement tol
        -- Stop at the first mismatching element.  'msum' is a right fold,
        -- so for 'Maybe' it short-circuits.  The previous lazy left fold
        -- ('foldl mplus Nothing') visited every element and built a
        -- thunk per element even after a mismatch had been found, which
        -- is expensive for the very large arrays this module targets.
        -- Pairing with @zip [0..]@ also avoids allocating an
        -- index/element vector with 'UVec.indexed'.
        compareGen cmp got expected =
          msum $ zipWith cmp (zip [0..] $ UVec.toList got) (UVec.toList expected)
        compareElement tol (j, got) expected
          | comparePrimValue tol got expected = Nothing
          | otherwise = Just $ PrimValueMismatch (i,j) (value got) (value expected)
        compareBool (j, got) expected
          | got == expected = Nothing
          | otherwise = Just $ PrimValueMismatch (i,j) (value got) (value expected)
-- | @comparePrimValue tol x y@ holds when @x@ and @y@ count as equal.
-- That is the case when they compare equal, when both are NaN, or when
-- they differ by strictly less than @tol@.
--
-- The explicit equality test is needed for infinities: for two equal
-- infinities @x - y@ is NaN, and NaN is never less than the tolerance,
-- so equal infinities used to be reported as mismatches.  NaN is
-- detected with @z /= z@, which only holds for NaN.  This keeps the
-- constraint at 'Ord'/'Num', because the function is also used for
-- integer elements, where @tol@ is 1 and only exact equality passes.
comparePrimValue :: (Ord num, Num num) =>
                    num -> num -> num -> Bool
comparePrimValue tol x y =
  x == y || (isNaN' x && isNaN' y) || diff < tol
  where diff = abs $ x - y
        isNaN' z = z /= z
-- | The smallest tolerance used for floating-point comparisons (0.002).
-- 'tolerance' also uses it as the relative factor (0.2%) that scales the
-- expected values.
minTolerance :: Fractional a => a
minTolerance = 2e-3
-- | The tolerance for comparing a floating-point array against these
-- expected values.  It is 'minTolerance' times the largest magnitude
-- among the finite expected values, and never less than 'minTolerance'.
--
-- Magnitudes ('abs') are used so that arrays of large negative numbers
-- get the same relative slack as arrays of large positive numbers.
-- Previously a negative element could never raise the tolerance.
-- Infinities and NaNs are skipped, because an infinite element would
-- otherwise make the tolerance infinite and let every finite comparison
-- pass.  A strict fold avoids building thunks over large arrays.
tolerance :: (Ord a, Fractional a, UVec.Unbox a) => Vector a -> a
tolerance = UVec.foldl' tolerance' minTolerance
  where tolerance' t v
          | isFinite v = max t $ minTolerance * abs v
          | otherwise  = t
        -- x - x is 0 for every finite x but NaN for infinities and NaN.
        -- This avoids demanding 'RealFloat' just to call 'isInfinite'.
        isFinite x = x - x == 0
-- (GitHub gist page residue) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment