-
-
Save justinlovinger/49b81dc83284732c05e4b657670b57c0 to your computer and use it in GitHub Desktop.
`inconsistent valuation @ shared 'Acc'` when trying to lift non-`Acc` function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts #-} | |
module Main where | |
import qualified Data.Array.Accelerate as A | |
import qualified Data.Array.Accelerate.Interpreter | |
as A | |
import qualified Data.Array.Accelerate.System.Random.MWC | |
as MWC | |
import qualified Data.Array.Accelerate.System.Random.SFC | |
as SFC | |
type State a = A.Acc (A.Vector a, SFC.Gen) | |
data StepHyperparameters a = StepHyperparameters | |
{ sampleSize :: A.Exp Int | |
, adjustRate :: A.Exp a | |
} | |
main :: IO () | |
main = do | |
xs0 <- initialState 2 | |
print $ fst $ A.run $ step defaultStepHyperparameters liftedSumBools xs0 | |
-- | Return recommended initial state. | |
initialState | |
:: Int -- ^ Number of bits in each sample | |
-> IO (State Double) | |
initialState nb = do | |
g <- SFC.createWith . A.use <$> MWC.randomArray MWC.uniform sh | |
pure $ A.T2 (A.fill (A.constant sh) $ A.constant 0.5) g | |
where sh = A.Z A.:. nb | |
-- | Return default 'step' hyperparameters. | |
defaultStepHyperparameters | |
:: (A.Fractional a, A.Ord a) => StepHyperparameters a | |
defaultStepHyperparameters = | |
StepHyperparameters 20 0.1 | |
-- | Take 1 step towards a 'State' with a higher objective value. | |
-- by adjusting probabilities towards the best bits | |
-- in a set of samples. | |
step | |
:: (A.Num a, A.Ord a, SFC.Uniform a, A.Ord b) | |
=> StepHyperparameters a | |
-> (A.Acc (A.Vector Bool) -> A.Acc (A.Scalar b)) -- ^ Objective function. Maximize. | |
-> State a | |
-> State a | |
step (StepHyperparameters n ar) f (A.T2 ps g0) = A.T2 ps' g1 where | |
(A.T3 _ bsStar g1) = aiterate | |
(n - 1) | |
(\(A.T3 fbs bs g) -> | |
let (A.T2 bs' g') = sample ps g | |
fbs' = f bs' | |
in A.acond (A.the fbs A.< A.the fbs') (A.T3 fbs' bs' g') (A.T3 fbs bs g') | |
) | |
(let (A.T2 bs g) = sample ps g0 in A.T3 (f bs) bs g) | |
-- `adjust` from a `Probability` to a `Bit` | |
-- will always be a valid `Probability`, | |
-- because `Bit` is 0 or 1 | |
-- and `adjust` will return a value between that range. | |
ps' = adjustArray ar ps (A.map fromBool bsStar) | |
-- | Repeatedly apply a function a fixed number of times. | |
aiterate | |
:: (A.Arrays a) | |
=> A.Exp Int -- ^ number of times to apply function | |
-> (A.Acc a -> A.Acc a) -- ^ function to apply | |
-> A.Acc a -- ^ initial value | |
-> A.Acc a | |
aiterate n f xs0 = A.asnd $ A.awhile | |
(A.unit . (A.< n) . A.the . A.afst) | |
(\(A.T2 i xs) -> A.T2 (A.map (+ 1) i) (f xs)) | |
(A.lift (A.unit $ A.constant (0 :: Int), xs0)) | |
sample | |
:: (A.Ord a, SFC.Uniform a) => A.Acc (A.Vector a) -> A.Acc SFC.Gen -> A.Acc (A.Vector Bool, SFC.Gen) | |
sample ps g = A.lift (A.zipWith (A.<=) rs ps, g') | |
where (rs, g') = SFC.runRandom g SFC.randomVector | |
fromBool :: (A.Num a) => A.Exp Bool -> A.Exp a | |
fromBool x = A.cond x 1 0 | |
-- | Adjust each number in `a` to corresponding number in `b` | |
-- at given rate. | |
adjustArray | |
:: (A.Shape sh, A.Num a) | |
=> A.Exp a -- ^ Adjustment rate | |
-> A.Acc (A.Array sh a) -- ^ From | |
-> A.Acc (A.Array sh a) -- ^ To | |
-> A.Acc (A.Array sh a) | |
adjustArray rate = A.zipWith (adjust rate) | |
-- | Adjust a number from `a` to `b` | |
-- at given rate. | |
adjust | |
:: (Num a) | |
=> a -- ^ Adjustment rate | |
-> a -- ^ From | |
-> a -- ^ To | |
-> a | |
adjust rate a b = a + rate * (b - a) | |
liftedSumBools :: A.Acc (A.Vector Bool) -> A.Acc (A.Scalar Double) | |
liftedSumBools = A.use . A.fromList A.Z . (: []) . sumBools . A.toList . A.run | |
sumBools :: [Bool] -> Double | |
sumBools = sum . fmap (\b -> if b then 1 else 0) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment