Skip to content

Instantly share code, notes, and snippets.

@justinlovinger
Created February 9, 2021 22:24
Show Gist options
  • Save justinlovinger/49b81dc83284732c05e4b657670b57c0 to your computer and use it in GitHub Desktop.
Save justinlovinger/49b81dc83284732c05e4b657670b57c0 to your computer and use it in GitHub Desktop.
`inconsistent valuation @ shared 'Acc'` when trying to lift non-`Acc` function
{-# LANGUAGE FlexibleContexts #-}
module Main where
import qualified Data.Array.Accelerate as A
import qualified Data.Array.Accelerate.Interpreter
as A
import qualified Data.Array.Accelerate.System.Random.MWC
as MWC
import qualified Data.Array.Accelerate.System.Random.SFC
as SFC
type State a = A.Acc (A.Vector a, SFC.Gen)
data StepHyperparameters a = StepHyperparameters
{ sampleSize :: A.Exp Int
, adjustRate :: A.Exp a
}
main :: IO ()
main = do
xs0 <- initialState 2
print $ fst $ A.run $ step defaultStepHyperparameters liftedSumBools xs0
-- | Return recommended initial state.
initialState
:: Int -- ^ Number of bits in each sample
-> IO (State Double)
initialState nb = do
g <- SFC.createWith . A.use <$> MWC.randomArray MWC.uniform sh
pure $ A.T2 (A.fill (A.constant sh) $ A.constant 0.5) g
where sh = A.Z A.:. nb
-- | Return default 'step' hyperparameters.
defaultStepHyperparameters
:: (A.Fractional a, A.Ord a) => StepHyperparameters a
defaultStepHyperparameters =
StepHyperparameters 20 0.1
-- | Take 1 step towards a 'State' with a higher objective value.
-- by adjusting probabilities towards the best bits
-- in a set of samples.
step
:: (A.Num a, A.Ord a, SFC.Uniform a, A.Ord b)
=> StepHyperparameters a
-> (A.Acc (A.Vector Bool) -> A.Acc (A.Scalar b)) -- ^ Objective function. Maximize.
-> State a
-> State a
step (StepHyperparameters n ar) f (A.T2 ps g0) = A.T2 ps' g1 where
(A.T3 _ bsStar g1) = aiterate
(n - 1)
(\(A.T3 fbs bs g) ->
let (A.T2 bs' g') = sample ps g
fbs' = f bs'
in A.acond (A.the fbs A.< A.the fbs') (A.T3 fbs' bs' g') (A.T3 fbs bs g')
)
(let (A.T2 bs g) = sample ps g0 in A.T3 (f bs) bs g)
-- `adjust` from a `Probability` to a `Bit`
-- will always be a valid `Probability`,
-- because `Bit` is 0 or 1
-- and `adjust` will return a value between that range.
ps' = adjustArray ar ps (A.map fromBool bsStar)
-- | Repeatedly apply a function a fixed number of times.
aiterate
:: (A.Arrays a)
=> A.Exp Int -- ^ number of times to apply function
-> (A.Acc a -> A.Acc a) -- ^ function to apply
-> A.Acc a -- ^ initial value
-> A.Acc a
aiterate n f xs0 = A.asnd $ A.awhile
(A.unit . (A.< n) . A.the . A.afst)
(\(A.T2 i xs) -> A.T2 (A.map (+ 1) i) (f xs))
(A.lift (A.unit $ A.constant (0 :: Int), xs0))
sample
:: (A.Ord a, SFC.Uniform a) => A.Acc (A.Vector a) -> A.Acc SFC.Gen -> A.Acc (A.Vector Bool, SFC.Gen)
sample ps g = A.lift (A.zipWith (A.<=) rs ps, g')
where (rs, g') = SFC.runRandom g SFC.randomVector
fromBool :: (A.Num a) => A.Exp Bool -> A.Exp a
fromBool x = A.cond x 1 0
-- | Adjust each number in `a` to corresponding number in `b`
-- at given rate.
adjustArray
:: (A.Shape sh, A.Num a)
=> A.Exp a -- ^ Adjustment rate
-> A.Acc (A.Array sh a) -- ^ From
-> A.Acc (A.Array sh a) -- ^ To
-> A.Acc (A.Array sh a)
adjustArray rate = A.zipWith (adjust rate)
-- | Adjust a number from `a` to `b`
-- at given rate.
adjust
:: (Num a)
=> a -- ^ Adjustment rate
-> a -- ^ From
-> a -- ^ To
-> a
adjust rate a b = a + rate * (b - a)
liftedSumBools :: A.Acc (A.Vector Bool) -> A.Acc (A.Scalar Double)
liftedSumBools = A.use . A.fromList A.Z . (: []) . sumBools . A.toList . A.run
sumBools :: [Bool] -> Double
sumBools = sum . fmap (\b -> if b then 1 else 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment