-- Sparse dot product (SDP) gist by ajtulloch, created February 2, 2015 19:03.
-- Source: GitHub gist ajtulloch/46e8ca8a3a8b723f307c.
module SDP where | |
import Control.Applicative | |
import Data.Function | |
import Data.List | |
import Test.QuickCheck | |
-- | Position of an entry inside a sparse vector.
newtype Index = Index Int
  deriving (Eq, Ord, Show)

-- | Payload stored at an 'Index'.
newtype Value a = Value a
  deriving (Eq, Ord, Show)

-- | A sparse vector as an association list of index/value pairs.
-- The merge walk in 'sdp' relies on the entries being sorted by ascending
-- index with no repeated index.
newtype SV a = SV [(Index, Value a)]
  deriving (Show)
-- | Sparse dot product of two sparse vectors.
--
-- Both inputs are consumed in a single merge pass, so each must be sorted by
-- ascending index with no repeated index (the 'Arbitrary' instance for 'SV'
-- normalises generated vectors to that shape). An index present in only one
-- vector contributes nothing, because the other vector is implicitly zero
-- there. Every step drops at least one entry, so the work is linear in the
-- combined number of entries.
sdp :: Num a => SV a -> SV a -> a
sdp (SV xs) (SV ys) = go xs ys 0
  where
    go [] _ acc = acc
    go _ [] acc = acc
    go xa@((Index x, Value xv) : xrest) ya@((Index y, Value yv) : yrest) acc =
      -- A three-way 'compare' is visibly exhaustive, unlike separate
      -- ==, > and < guards, which trigger -Wincomplete-patterns.
      case compare x y of
        EQ ->
          -- Force the running sum so a long walk does not build a thunk chain.
          let acc' = acc + xv * yv
           in acc' `seq` go xrest yrest acc'
        GT -> go xa yrest acc -- ys is behind: advance it
        LT -> go xrest ya acc -- xs is behind: advance it
-- | Densify a sparse vector into a list of exactly @n@ elements (no elements
-- when @n <= 0@).
--
-- Position @i@ holds the value stored at @Index i@, or 0 when there is none.
-- Entries whose index lies outside @[0, n)@ are ignored. When an index occurs
-- more than once, the occurrence that comes last in the input wins.
--
-- The entries are sorted once and the output is produced in a single
-- left-to-right pass, O(n + k log k) for k entries. This replaces repeatedly
-- rewriting the whole output list once per entry.
stod :: (Num a) => Int -> SV a -> [a]
stod n (SV xs) = fill 0 entries
  where
    -- 'sortBy' is stable. Reversing first means that among duplicate indices
    -- the entry that was last in @xs@ is met first; 'fill' then skips the rest.
    entries = sortBy (compare `on` fst) (reverse xs)

    fill i _
      | i >= n = []
    fill i ((Index j, Value v) : rest)
      | j < i = fill i rest -- negative index, or an already-used duplicate
      | j == i = v : fill (i + 1) rest
    fill i rest = 0 : fill (i + 1) rest
-- | Densify a sparse vector, picking the length automatically.
--
-- The length is one more than the largest index present: 1 for a vector with
-- no entries, and no elements at all when every index is negative.
stod' :: (Num a) => SV a -> [a]
stod' sv@(SV entries) = stod (1 + largest) sv
  where
    largest = case [ix | (Index ix, _) <- entries] of
      [] -> 0
      ixs -> maximum ixs
-- | Random 'Value's wrap a random payload.
-- Shrinking shrinks the payload, so failing properties report small
-- counterexamples.
instance (Arbitrary a) => Arbitrary (Value a) where
  arbitrary = Value <$> arbitrary
  shrink (Value v) = Value <$> shrink v
-- | Random 'Index'es are strictly positive.
-- Shrinking moves towards smaller indices while staying strictly positive,
-- so shrunk inputs remain inside the generator's range.
instance Arbitrary Index where
  arbitrary = Index . getPositive <$> arbitrary
  shrink (Index v) = [Index v' | Positive v' <- shrink (Positive v)]
-- | Random sparse vectors are normalised to the shape 'sdp' requires:
-- entries sorted by ascending index, with at most one entry per index.
instance (Arbitrary a) => Arbitrary (SV a) where
  arbitrary = normalizeSV <$> arbitrary
  -- Shrink the raw entry list, then normalise again: shrinking an index can
  -- break the ordering or produce a duplicate index.
  shrink (SV entries) = normalizeSV <$> shrink entries

-- | Sort entries by index (stably) and keep only the first entry for each
-- index.
normalizeSV :: [(Index, Value a)] -> SV a
normalizeSV = SV . nubBy ((==) `on` fst) . sortBy (compare `on` fst)
-- | Densifying to a requested positive length yields exactly that many
-- elements.
prop_stod :: Num a => Positive Int -> SV a -> Bool
prop_stod (Positive n) = (n ==) . length . stod n
-- | The sparse dot product of a vector with itself equals the sum of squares
-- of its dense form.
prop_sdp2 :: (Num a, Eq a) => SV a -> Bool
prop_sdp2 sv = sdp sv sv == sumOfSquares
  where
    sumOfSquares = sum [x * x | x <- stod' sv]
-- | The sparse dot product agrees with the dot product of the dense forms.
-- The shorter dense list bounds the pairing; the missing positions are zero
-- in that vector and so contribute nothing.
prop_sdp :: (Num a, Eq a) => SV a -> SV a -> Bool
prop_sdp sv sv' = sdp sv sv' == denseDot
  where
    denseDot = sum [a * b | (a, b) <- zip (stod' sv) (stod' sv')]
-- | Print two hand-worked dot products, then run the QuickCheck properties.
main :: IO ()
main = do
  -- No shared index, so the product is 0.
  print (sdp lhsDisjoint rhs)
  -- Shared index 2 gives 2 * 3 = 6.
  print (sdp lhsShared rhs)
  quickCheck (prop_sdp2 :: SV Int -> Bool)
  quickCheck (prop_stod :: Positive Int -> SV Int -> Bool)
  quickCheck (prop_sdp :: SV Int -> SV Int -> Bool)
  where
    rhs, lhsDisjoint, lhsShared :: SV Float
    rhs = SV [(Index 2, Value 3)]
    lhsDisjoint = SV [(Index 1, Value 2)]
    lhsShared = SV [(Index 2, Value 2)]
-- Sign up for free to join this conversation on GitHub.
-- Already have an account? Sign in to comment.