Skip to content

Instantly share code, notes, and snippets.

@mstksg
Created August 28, 2019 00:37
Show Gist options
  • Save mstksg/b92956c17da4026b876f5b218b9ed6e1 to your computer and use it in GitHub Desktop.
Save mstksg/b92956c17da4026b876f5b218b9ed6e1 to your computer and use it in GitHub Desktop.
backprop, but using mutable variables
#!/usr/bin/env stack
-- stack --install-ghc ghci --package ad --package lens --package vinyl --package reflection --package tagged --package transformers --package vector
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -Werror=incomplete-patterns #-}
{-# OPTIONS_GHC -Wredundant-constraints #-}
import Control.Applicative.Backwards
import Control.Exception
import Control.Lens hiding (Identity(..), Const(..))
import Control.Monad.ST
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Coerce
import Data.Foldable
import Data.IORef
import Data.Kind
import Data.Proxy
import Data.Reflection
import Data.STRef
import Data.Tagged
import Data.Type.Equality
import Data.Vinyl
import Data.Vinyl.Functor
import Numeric.AD
import Numeric.AD.Internal.Reverse (Tape, Reverse)
import Numeric.AD.Rank1.Forward (Forward)
import System.IO.Unsafe
import Type.Reflection
import qualified Data.Vector as V
import qualified Data.Vinyl.Recursive as VR
-- | A differentiable operation from the inputs @as@ to a result @a@.
--
-- 'opFunc' returns the result together with a function from a gradient on
-- the result to a gradient on each input.
data Op :: [Type] -> Type -> Type where
    Op :: { opFunc :: HList as -> (a, a -> HList as)
          }
       -> Op as a
-- | Initialize to zero: allocate a fresh reference of type @r q@ in any
-- 'ST' state thread @q@. The @a@ parameter is a phantom naming the value
-- type associated with the reference.
newtype InitFunc (r :: Type -> Type) (a :: Type) = IF { runIF :: forall q. ST q (r q) }
-- | Set a ref to be one. 'gradRunner' uses this to seed the gradient of the
-- differentiated function's output before the backward pass.
newtype OneFunc (r :: Type -> Type) (a :: Type) = OF { runOF :: forall q. r q -> ST q () }
-- | Add a value to a ref. During the backward pass this accumulates each
-- gradient contribution into the ref it is sent to.
newtype AddFunc (r :: Type -> Type) (a :: Type) = AF { runAF :: forall q. r q -> a -> ST q () }
-- | Read a ref into a value, e.g. to read back an accumulated gradient.
newtype ReadFunc (r :: Type -> Type) (a :: Type) = RF { runRF :: forall q. r q -> ST q a }
-- | An 'Op' paired with an 'InitFunc' for its result's reference type @r@.
-- NOTE(review): 'OpR' is not referenced anywhere else in this file.
data OpR :: [Type] -> (Type -> Type, Type) -> Type where
    OpR :: { opRInit :: InitFunc r a
           , opROp   :: Op as a
           }
        -> OpR as '(r, a)
-- | Where a 'BVar' came from. @s@ is a phantom matching the @s@ of the
-- owning 'BVar'.
data BRef (s :: Type) = BRInp !Int -- ^ Input number (numbered from 0, left to right, by 'fillWengert')
                      | BRIx !Int -- ^ Number in tape (the index assigned by 'insertNode')
                      | BRC -- ^ no source: a constant; gradients sent to it are dropped
  deriving Show
-- | A value tracked for backpropagation: where it came from ('BRef') and its
-- value.
--
-- @s@ is the tag of the tape it belongs to (the @s@ in @Reifies s W@). @r@ is
-- the reference type its gradient is accumulated in. Both fields are strict.
data BVar s r a = BV { _bvRef :: !(BRef s)
                     , _bvVal :: !a
                     }
-- | Evaluate a 'BVar' to weak head normal form.
--
-- Both fields of 'BV' are strict, so reaching the constructor is enough to
-- evaluate the reference and the value as well.
forceBVar :: BVar s r a -> ()
forceBVar v = case v of
    BV {} -> ()
-- | An @f a@ for some hidden @a@, stored together with @a@'s runtime
-- 'TypeRep' so it can be recovered at a known type with 'coerceSomeF'.
data SomeF f = forall a. SomeF (TypeRep a) !(f a)
-- | View a two-argument type constructor as indexed by a promoted pair.
data Uncur :: (a -> b -> Type) -> (a, b) -> Type where
    Uncur :: { getUncur :: !(f a b) } -> Uncur f '(a, b)
-- | One input of a tape node: the input's 'BVar' (value type @b@, ref type
-- @r@) and the 'AddFunc' used to add a gradient of type @a@ into that
-- input's ref. The input's tape tag @s@ is hidden.
data InpRef :: Type -> (Type -> Type, Type) -> Type where
    IR :: { _irVar :: !(BVar s r b)
          , _irAdd :: !(AddFunc r a)
          } -> InpRef a '(r, b)
-- | InputRef to some source: an 'InpRef' whose ref and value types are
-- hidden behind 'SomeF', so inputs of different types fit in one 'Rec'.
data SomeInpRef a = SIR (SomeF (InpRef a))
-- | One node on the tape.
--
-- * '_tnInputs': where to send gradients for each input.
-- * '_tnGrad': maps the gradient of this node's output to a gradient for
--   each input.
-- * '_tnInit' / '_tnRead': allocate and read the ref that accumulates this
--   node's own gradient during the backward pass.
data TapeNode :: [Type] -> (Type -> Type, Type) -> Type where
    TN :: { _tnInputs :: !(Rec SomeInpRef as)
          , _tnGrad   :: !(a -> HList as)
          , _tnInit   :: !(InitFunc r a)
          , _tnRead   :: !(ReadFunc r a)
          }
       -> TapeNode as '(r, a)
-- | Render a tape node's inputs for debugging. Each input contributes three
-- lines: the 'TypeRep' of its ref type, of its value type, and its 'BRef'.
inspectTN :: TapeNode as ra -> String
inspectTN TN{ _tnInputs = inputs } = unlines (VR.rfoldMap describeInput inputs)
  where
    describeInput :: SomeInpRef x -> [String]
    describeInput (SIR (SomeF (TupRep refRep valRep) (IR var _))) =
        [show refRep, show valRep, show (_bvRef var)]
-- | The Wengert tape: the number of nodes inserted so far and the nodes
-- themselves, most recently inserted first.
newtype W = W { wRef :: IORef (Int, [SomeF (Uncur TapeNode)]) }
-- | Append a node to the tape held in @w@ and return a 'BVar' that refers
-- to it.
--
-- The node's index is the number of nodes already on the tape. The tape list
-- is stored newest-first, so once it is reversed, index @i@ is position @i@.
insertNode
    :: forall a as r s. (Typeable as, Typeable r, Typeable a)
    => TapeNode as '(r, a)
    -> a -- ^ val
    -> W
    -> IO (BVar s r a)
insertNode !tn !x !w = mkVar <$> atomicModifyIORef' (wRef w) push
  where
    -- Use the strict variant. The lazy 'atomicModifyIORef' stores an
    -- unevaluated application in the IORef, so a long chain of
    -- @(n + 1, node : t)@ thunks can build up between reads.
    push :: (Int, [SomeF (Uncur TapeNode)]) -> ((Int, [SomeF (Uncur TapeNode)]), Int)
    push (n, t) =
        let !n' = n + 1
            !t' = SomeF typeRep (Uncur tn) : t
        in  ((n', t'), n)
    mkVar :: Int -> BVar s r a
    mkVar i = BV (BRIx i) x
-- | Wrap a value as a 'BVar' with no source ('BRC'). Gradients sent to such
-- a 'BVar' are dropped during the backward pass.
constVar :: a -> BVar s r a
constVar x = BV { _bvRef = BRC, _bvVal = x }
{-# INLINE constVar #-}
-- | The second components of a type-level list of promoted pairs.
type family Snds as where
    Snds '[] = '[]
    Snds ('(a, b) ': abs) = b ': Snds abs
-- | Return the value of a 'BVar' if it has no source ('BRC'); return
-- 'Nothing' if it refers to an input or a tape node.
bvConst :: BVar s r a -> Maybe a
bvConst v = case v of
    BV BRC x -> Just x
    _        -> Nothing
{-# INLINE bvConst #-}
-- | Run an 'Op' forwards only, discarding its gradient function.
evalOp :: Op as a -> HList as -> a
evalOp o xs = fst (opFunc o xs)
-- | Map over a 'Rec' indexed by pairs, keeping only what the function
-- produces for each second component.
getSnds
    :: (forall a b. f a b -> g b)
    -> Rec (Uncur f) as
    -> Rec g (Snds as)
getSnds _ RNil              = RNil
getSnds f (Uncur x :& rest) = f x :& getSnds f rest
-- | Wrap an @f b@ and ignore the extra index @a@. Combined with 'Uncur',
-- this keeps only the second component of a pair index (as in 'liftOp_').
data TaggedF f a b = TaggedF { unTaggedF :: f b }
-- | Zip three 'Rec's of the same shape element-wise.
rzipWith3
    :: (forall a. f a -> g a -> h a -> j a)
    -> Rec f as
    -> Rec g as
    -> Rec h as
    -> Rec j as
rzipWith3 _ RNil      RNil      RNil      = RNil
rzipWith3 f (x :& xs) (y :& ys) (z :& zs) = f x y z :& rzipWith3 f xs ys zs
-- | IO worker for 'liftOp': apply an 'Op' to some 'BVar's on the tape
-- reified as @s@.
--
-- If every input is a constant ('BRC'), the op is only evaluated and the
-- result is a constant; nothing is recorded. Otherwise a 'TapeNode' is
-- inserted. The node pairs each input 'BVar' with its 'AddFunc' and uses the
-- op's gradient function.
liftOp_
    :: forall s ras r b. (Reifies s W, Typeable b, Typeable r, Typeable ras)
    => Rec (Uncur AddFunc) ras   -- ^ how to add a gradient into each input's ref
    -> Op (Snds ras) b
    -> Rec (Uncur (BVar s)) ras
    -> (b -> InitFunc r b)       -- ^ allocator for the result's gradient ref, given the result value
    -> ReadFunc r b              -- ^ reader for the result's gradient ref
    -> IO (BVar s r b)
liftOp_ afs o vs ifb rfb = case rtraverse seekOutConst vs of
    -- All inputs are constants: fold the op away.
    Just xs -> return . constVar $ evalOp o (getSnds (Identity . unTagged) xs)
    Nothing ->
      let ras = typeRepRec $ typeRep @ras
          -- The bangs force the result value and the gradient function
          -- before the node is inserted.
          !(!y, !g) = opFunc o (getSnds (Identity . _bvVal) vs)
          !tn = TN
            { _tnInputs = getSnds unTaggedF $ rzipWith3 combineAfs ras afs vs
            , _tnGrad = g
            , _tnInit = ifb y
            , _tnRead = rfb
            }
      -- insertNode needs Typeable (Snds ras). Rebuild it at runtime from
      -- the second components of ras's TypeRep.
      in  withTypeable (recTypeRep (typeRepSnds ras)) $
            insertNode tn y (reflect (Proxy @s))
  where
    -- Just the input's value if it is a constant.
    seekOutConst :: Uncur (BVar s) x -> Maybe (Uncur Tagged x)
    seekOutConst (Uncur x) = Uncur . Tagged <$> bvConst x
    -- Package one input and its AddFunc as a type-erased SomeInpRef,
    -- tagged so getSnds can strip the pair index.
    combineAfs
        :: TypeRep x
        -> Uncur AddFunc x
        -> Uncur (BVar s) x
        -> Uncur (TaggedF SomeInpRef) x
    combineAfs tr (Uncur af) (Uncur bv) = Uncur . TaggedF . SIR $
        SomeF tr $ IR bv af
-- | 'Numeric.Backprop.liftOp', but with explicit 'add' and 'zero'.
--
-- A pure wrapper around 'liftOp_'. The tape insertion runs through
-- 'unsafePerformIO' when the resulting 'BVar' is forced.
liftOp
    :: forall s ras r b. (Reifies s W, Typeable b, Typeable r, Typeable ras)
    => Rec (Uncur AddFunc) ras
    -> Op (Snds ras) b
    -> Rec (Uncur (BVar s)) ras
    -> (b -> InitFunc r b)
    -> ReadFunc r b
    -> BVar s r b
liftOp afs o !vs ifb = \rfb -> unsafePerformIO (liftOp_ afs o vs ifb rfb)
{-# INLINE liftOp #-}
-- | 'liftOp' specialised to an op with a single input.
liftOp1
    :: (Reifies s W, Typeable ra, Typeable a, Typeable rb, Typeable b)
    => AddFunc ra a
    -> Op '[a] b
    -> BVar s ra a
    -> (b -> InitFunc rb b)
    -> ReadFunc rb b
    -> BVar s rb b
liftOp1 af o v = liftOp adders o inputs
  where
    adders = Uncur af :& RNil
    inputs = Uncur v :& RNil
-- | 'liftOp' specialised to an op with two inputs.
liftOp2
    :: (Reifies s W, Typeable ra, Typeable a, Typeable rb, Typeable b, Typeable rc, Typeable c)
    => AddFunc ra a
    -> AddFunc rb b
    -> Op '[a, b] c
    -> BVar s ra a
    -> BVar s rb b
    -> (c -> InitFunc rc c)
    -> ReadFunc rc c
    -> BVar s rc c
liftOp2 af1 af2 o v1 v2 = liftOp adders o inputs
  where
    adders = Uncur af1 :& Uncur af2 :& RNil
    inputs = Uncur v1 :& Uncur v2 :& RNil
-- | IO worker for 'partVar': record a tape node whose value is @getA@
-- applied to the input's value.
--
-- The node's gradient is passed through unchanged to the single input, and
-- @afa@ adds it into the input's @rb@ ref.
partVar_
    :: forall a b ra rb s. (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
    => InitFunc ra a
    -> ReadFunc ra a
    -> AddFunc rb a
    -> (b -> a)
    -> BVar s rb b
    -> IO (BVar s ra a)
partVar_ ifa rfa afa getA v = insertNode node (getA (_bvVal v)) (reflect (Proxy @s))
  where
    node :: TapeNode '[a] '(ra, a)
    node = TN
        { _tnInputs = SIR (SomeF typeRep (IR v afa)) :& RNil
        , _tnGrad   = \d -> Identity d :& RNil
        , _tnInit   = ifa
        , _tnRead   = rfa
        }
-- | Pure wrapper around 'partVar_'. The tape insertion runs through
-- 'unsafePerformIO' when the result is forced.
partVar
    :: forall a b ra rb s. (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
    => InitFunc ra a
    -> ReadFunc ra a
    -> AddFunc rb a
    -> (b -> a)
    -> BVar s rb b
    -> BVar s ra a
partVar ifa rfa afa getA = \v -> unsafePerformIO (partVar_ ifa rfa afa getA v)
-- | IO worker for 'viewVar': focus on the part of a @b@ selected by a lens.
--
-- A gradient @da@ for the part becomes a gradient for the whole. A fresh
-- @rb@ ref is initialised with @ifb@ and read back with @rfb@. The lens
-- target of that value is set to @da@, and the result is added into the
-- input's ref.
viewVar_
    :: (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
    => InitFunc rb b
    -> InitFunc ra a
    -> ReadFunc rb b
    -> ReadFunc ra a
    -> AddFunc rb b
    -> Lens' b a
    -> BVar s rb b
    -> IO (BVar s ra a)
viewVar_ ifb ifa rfb rfa afb l v = partVar_ ifa rfa addViaLens (view l) v
  where
    addViaLens = AF $ \ref da -> do
        freshB <- runIF ifb >>= runRF rfb
        runAF afb ref (set l da freshB)
-- | Pure wrapper around 'viewVar_'. The tape insertion runs through
-- 'unsafePerformIO' when the result is forced.
viewVar
    :: (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
    => InitFunc rb b
    -> InitFunc ra a
    -> ReadFunc rb b
    -> ReadFunc ra a
    -> AddFunc rb b
    -> Lens' b a
    -> BVar s rb b
    -> BVar s ra a
viewVar ifb ifa rfb rfa afb l = \v -> unsafePerformIO (viewVar_ ifb ifa rfb rfa afb l v)
-- | A fresh, empty tape: no nodes inserted yet.
initWengert :: IO W
initWengert = do
    ref <- newIORef (0, [])
    return (W ref)
{-# INLINE initWengert #-}
-- | Run @f@ on fresh input 'BVar's against a new, empty tape, and force its
-- output.
--
-- Inputs are numbered @0, 1, ...@ from left to right as 'BRInp' refs.
-- Returns the tape in insertion order (tape index @i@ is vector position
-- @i@) together with the output's value.
fillWengert
    :: forall ras rb b. ()
    => (forall s. Reifies s W => Rec (Uncur (BVar s)) ras -> BVar s rb b)
    -> Rec (Uncur Tagged) ras
    -> IO (V.Vector (SomeF (Uncur TapeNode)), b)
fillWengert f xs = do
    wengert <- initWengert
    reify wengert $ \(Proxy :: Proxy s) -> do
        let !out = f (inputVars @s)
        -- Forcing the output runs the unsafePerformIO inside every
        -- liftOp/partVar it depends on, which appends those nodes to the tape.
        evaluate (forceBVar out)
        (_, newestFirst) <- readIORef (wRef wengert)
        return (V.fromList (reverse newestFirst), _bvVal out)
  where
    inputVars :: forall s. Rec (Uncur (BVar s)) ras
    inputVars = evalState (rtraverse (state . numberInput) xs) 0
      where
        numberInput :: Uncur Tagged x -> Int -> (Uncur (BVar s) x, Int)
        numberInput (Uncur (Tagged val)) ix = (Uncur (BV (BRInp ix) val), ix + 1)
-- | A ref @r s@ in 'ST' thread @s@, tagged with the value type @a@ whose
-- gradient it holds.
newtype RecFor s (r :: Type -> Type) (a :: Type) = RecFor (r s)
-- | Mutable state for one backward pass: one gradient ref per tape node, in
-- tape order, and one per input, in input order.
data Runner s = R { _rDelta :: !(V.Vector (SomeF (Uncur (RecFor s))))
                  , _rInputs :: !(V.Vector (SomeF (Uncur (RecFor s))))
                  }
-- | Allocate the mutable state for one backward pass.
--
-- First a gradient ref is allocated for every tape node (with its
-- '_tnInit'), then one for every input (with the given 'InitFunc's). Order is
-- preserved in both vectors.
initRunner
    :: forall s. ()
    => V.Vector (SomeF (Uncur TapeNode))
    -> V.Vector (SomeF (Uncur InitFunc))
    -> ST s (Runner s)
initRunner tapeNodes inputInits = do
    deltas <- V.mapM allocNode tapeNodes
    inputs <- V.mapM allocInput inputInits
    return R { _rDelta = deltas, _rInputs = inputs }
  where
    allocNode :: SomeF (Uncur TapeNode) -> ST s (SomeF (Uncur (RecFor s)))
    allocNode (SomeF (TupRep _ outRep) (Uncur TN{ _tnInit = initF })) =
        SomeF outRep . Uncur . RecFor <$> runIF initF

    allocInput :: SomeF (Uncur InitFunc) -> ST s (SomeF (Uncur (RecFor s)))
    allocInput (SomeF inRep (Uncur initF)) =
        SomeF inRep . Uncur . RecFor <$> runIF initF
-- | Backpropagate through the whole tape.
--
-- The gradient ref of the last tape node is seeded with one, then the tape
-- is walked from last node to first. For each node, its accumulated gradient
-- is read, its gradient function is applied, and each resulting input
-- gradient is added into that input's ref. The target is an earlier tape
-- node ('BRIx') or an input ('BRInp'); gradients for a constant ('BRC') are
-- dropped.
--
-- NOTE(review): the output of the differentiated function is assumed to be
-- the last node inserted on the tape. An empty tape means the output is an
-- input or a constant; this is now reported with an explicit error instead
-- of an opaque vector-index failure. If the output were some earlier node,
-- the seed would land on the wrong node, or 'coerceSomeF' would fail on a
-- type mismatch. TODO: pass the output's 'BRef' in and seed that ref instead.
gradRunner
    :: forall rb b s. (Typeable rb, Typeable b)
    => OneFunc rb b -- ^ set to be one
    -> Runner s
    -> V.Vector (SomeF (Uncur TapeNode))
    -> ST s ()
gradRunner so R{..} stns = do
    runOF so rO
    -- 'Backwards' runs the effects in reverse, so nodes are processed from
    -- the end of the tape to the start.
    forwards . traverse_ Backwards $ V.zipWith go _rDelta stns
  where
    -- Replaces a bare 'V.last', which is partial on an empty tape.
    outDelta :: SomeF (Uncur (RecFor s))
    outDelta
      | V.null _rDelta = error
          "gradRunner: empty tape, so the output is not a tape node (it is an input or a constant)"
      | otherwise = V.last _rDelta
    Uncur (RecFor rO) = coerceSomeF (typeRep @'(rb, b)) "gradRunner init" outDelta
    go :: SomeF (Uncur (RecFor s))
       -> SomeF (Uncur (TapeNode))
       -> ST s ()
    go rf (SomeF (TupRep _ trrx) (Uncur TN{..})) = do
        d <- runRF _tnRead rx
        let gs = _tnGrad d
        rzipWithM_ propagate _tnInputs gs
      where
        Uncur (RecFor rx) = coerceSomeF trrx "gradRunner tape" rf
        -- Send one input's gradient to wherever that input came from.
        propagate :: SomeInpRef x -> Identity x -> ST s ()
        propagate (SIR (SomeF trb (IR irv ira))) (Identity x) =
            case _bvRef irv of
              BRInp i ->
                let Uncur (RecFor rb) = coerceSomeF trb "propagate input" (_rInputs V.! i)
                in  runAF ira rb x
              BRIx i ->
                let Uncur (RecFor rb) = coerceSomeF trb "propagate deltas" (_rDelta V.! i)
                in  runAF ira rb x
              BRC -> return ()
-- | Recover the payload of a 'SomeF' at an expected index type.
--
-- If the stored 'TypeRep' differs from the expected one, calls 'error'. The
-- message is the given context string followed by the expected and actual
-- types.
coerceSomeF
    :: forall a f. ()
    => TypeRep a
    -> String
    -> SomeF f
    -> f a
coerceSomeF expected context (SomeF actual payload) =
    case actual `eqTypeRep` expected of
      Just HRefl -> payload
      Nothing    -> error $ context ++ " <" ++ show expected ++ "> vs. <" ++ show actual ++ ">"
-- | Run a function on 'BVar' inputs. Return its result and the gradient with
-- respect to every input.
--
-- Steps:
--
-- 1. Build the tape by running @f@ ('fillWengert').
-- 2. Allocate one gradient ref per tape node and per input ('initRunner').
-- 3. Seed the output and propagate backwards ('gradRunner').
-- 4. Read each input's ref back in input order.
backpropN
    :: forall ras rb b. (Typeable ras, Typeable rb, Typeable b)
    => (forall s. Reifies s W => Rec (Uncur (BVar s)) ras -> BVar s rb b)
    -> OneFunc rb b                 -- ^ sets the output's gradient ref to one
    -> Rec (Uncur InitFunc) ras     -- ^ allocators for each input's gradient ref
    -> Rec (Uncur ReadFunc) ras     -- ^ readers for each input's gradient ref
    -> Rec (Uncur Tagged) ras       -- ^ input values
    -> (b, Rec (Uncur Tagged) ras)
backpropN f sf ifs rfs !xs = (y, g')
  where
    -- Because of the bang patterns, evaluating the result pair runs the
    -- forward pass and builds the tape.
    !(!tp,!y) = unsafePerformIO $ fillWengert f xs
    g' :: Rec (Uncur Tagged) ras
    g' = runST $ do
        -- Type-erase the input InitFuncs, in input order, for initRunner.
        r <- initRunner tp
           . V.fromList
           . VR.recordToList
           . VR.rzipWith (\tr ifx -> Const (SomeF tr ifx)) (typeRepRec typeRep)
           $ ifs
        gradRunner sf r tp
        -- The state counter is the index of the input being read.
        evalStateT (rzipWithM (pullOut (_rInputs r)) (typeRepRec typeRep) rfs) 0
    -- Read the next input's gradient ref at its expected type.
    pullOut
        :: V.Vector (SomeF (Uncur (RecFor s)))
        -> TypeRep x
        -> Uncur ReadFunc x
        -> StateT Int (ST s) (Uncur Tagged x)
    pullOut inps trx (Uncur rf) = do
        i <- state $ \i' -> (i', i' + 1)
        let Uncur (RecFor r) = coerceSomeF trx "pullOut" $ inps V.! i
        Uncur . Tagged <$> lift (runRF rf r)
-- | 'backpropN' for a single input: return the function's result and the
-- gradient with respect to the input.
backprop
    :: (Typeable ra, Typeable a, Typeable rb, Typeable b)
    => (forall s. Reifies s W => BVar s ra a -> BVar s rb b)
    -> OneFunc rb b
    -> InitFunc ra a
    -> ReadFunc ra a
    -> a
    -> (b, a)
backprop f sfb ifa rfa x = (y, unTagged (getUncur (rHead grads)))
  where
    (y, grads) = backpropN (f . getUncur . rHead) sfb
                   (Uncur ifa :& RNil)
                   (Uncur rfa :& RNil)
                   (Uncur (Tagged x) :& RNil)
-- | 'backpropN' for two inputs: return the function's result and the
-- gradients with respect to both inputs.
backprop2
    :: forall a b c ra rb rc. (Typeable ra, Typeable a, Typeable rb, Typeable b, Typeable rc, Typeable c)
    => (forall s. Reifies s W => BVar s ra a -> BVar s rb b -> BVar s rc c)
    -> OneFunc rc c
    -> InitFunc ra a
    -> InitFunc rb b
    -> ReadFunc ra a
    -> ReadFunc rb b
    -> a
    -> b
    -> (c, (a, b))
backprop2 f sfc ifa ifb rfa rfb x y =
    second unpackGrads (backpropN packArgs sfc initFuncs readFuncs inputs)
  where
    initFuncs = Uncur ifa :& Uncur ifb :& RNil
    readFuncs = Uncur rfa :& Uncur rfb :& RNil
    inputs    = Uncur (Tagged x) :& Uncur (Tagged y) :& RNil

    unpackGrads :: Rec (Uncur Tagged) '[ '(ra, a), '(rb, b) ] -> (a, b)
    unpackGrads (Uncur (Tagged dx) :& Uncur (Tagged dy) :& RNil) = (dx, dy)

    packArgs :: Reifies s W => Rec (Uncur (BVar s)) '[ '(ra, a), '(rb, b) ] -> BVar s rc c
    packArgs (Uncur vx :& Uncur vy :& RNil) = f vx vy
-- | Like 'backpropN', but return only the input gradients.
gradBPN
    :: forall ras rb b. (Typeable ras, Typeable rb, Typeable b)
    => (forall s. Reifies s W => Rec (Uncur (BVar s)) ras -> BVar s rb b)
    -> OneFunc rb b
    -> Rec (Uncur InitFunc) ras
    -> Rec (Uncur ReadFunc) ras
    -> Rec (Uncur Tagged) ras
    -> Rec (Uncur Tagged) ras
gradBPN f sf ifs rfs = \xs -> snd (backpropN f sf ifs rfs xs)
-- | Like 'backprop', but return only the input gradient.
gradBP
    :: (Typeable ra, Typeable a, Typeable rb, Typeable b)
    => (forall s. Reifies s W => BVar s ra a -> BVar s rb b)
    -> OneFunc rb b
    -> InitFunc ra a
    -> ReadFunc ra a
    -> a
    -> a
gradBP f sfb ifa rfa = \x -> snd (backprop f sfb ifa rfa x)
-- | Like 'backprop2', but return only the two input gradients.
gradBP2
    :: forall a b c ra rb rc. (Typeable ra, Typeable a, Typeable rb, Typeable b, Typeable rc, Typeable c)
    => (forall s. Reifies s W => BVar s ra a -> BVar s rb b -> BVar s rc c)
    -> OneFunc rc c
    -> InitFunc ra a
    -> InitFunc rb b
    -> ReadFunc ra a
    -> ReadFunc rb b
    -> a
    -> b
    -> (a, b)
gradBP2 f sfc ifa ifb rfa rfb x = \y -> snd (backprop2 f sfc ifa ifb rfa rfb x y)
-- | Placeholder entry point; it only prints a fixed line.
main :: IO ()
main = do
    putStrLn "hi"
-- | Build a unary 'Op' from a function differentiated with forward-mode AD
-- ('diff''). The gradient function multiplies the derivative by the incoming
-- gradient.
op1 :: forall a. Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> Op '[a] a
op1 f = Op run
  where
    run :: HList '[a] -> (a, a -> HList '[a])
    run (Identity x :& RNil) =
        let (fx, dfdx) = diff' f x
        in  (fx, \dy -> Identity (dfdx * dy) :& RNil)
-- | Build a binary 'Op' from a function differentiated with reverse-mode AD
-- ('grad''). The gradient function scales both partial derivatives by the
-- incoming gradient @dZ@.
--
-- 'grad'' works over a 'Traversable', so the two arguments travel as a
-- two-element list. The list patterns are matched totally, with explicit
-- errors, instead of through irrefutable list patterns that -Wall flags as
-- incomplete.
op2 :: forall a. Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a -> Reverse s a) -> Op '[a,a] a
op2 f = Op run
  where
    onPair :: forall s. Reifies s Tape => [Reverse s a] -> Reverse s a
    onPair [x', y'] = f x' y'
    onPair _        = error "op2: expected exactly two arguments"

    run :: HList '[a, a] -> (a, a -> HList '[a, a])
    run (Identity x :& Identity y :& RNil) =
        let (z, partials) = grad' onPair [x, y]
            -- Lazy binding: the partials are only inspected when a gradient
            -- is requested.
            (dX, dY) = case partials of
              [dx, dy] -> (dx, dy)
              _        -> error "op2: expected exactly two partial derivatives"
        in  (z, \dZ -> Identity (dZ * dX) :& Identity (dZ * dY) :& RNil)
-- | A gradient ref that holds the whole value of type @a@ in a single
-- 'STRef'.
newtype WholeRef a s = WR { getWR :: STRef s a }
-- | Read the value held in a 'WholeRef'.
readWR :: forall a. ReadFunc (WholeRef a) a
readWR = RF (\(WR ref) -> readSTRef ref)
-- | Add a value into a 'WholeRef' (the new content is @old + x@).
--
-- Uses the strict 'modifySTRef'' so that repeated gradient accumulation does
-- not build a chain of unevaluated additions inside the ref.
addWR :: forall a. Num a => AddFunc (WholeRef a) a
addWR = AF $ \(WR r) x -> modifySTRef' r (+ x)
-- | Allocate a 'WholeRef' holding zero.
initWR :: forall a. Num a => InitFunc (WholeRef a) a
initWR = IF (fmap WR (newSTRef (0 :: a)))
-- | Overwrite a 'WholeRef' with one.
oneWR :: forall a. Num a => OneFunc (WholeRef a) a
oneWR = OF (\wr -> writeSTRef (getWR wr) 1)
-- | Arithmetic on 'BVar's whose gradient lives in a single 'STRef'.
--
-- Each operation goes through 'liftOp1' / 'liftOp2', using 'op1' / 'op2' for
-- derivatives and 'WholeRef' refs for gradients. A tape node is recorded
-- unless every input is a constant. 'fromInteger' produces a constant.
instance (Num a, Typeable a, Reifies s W) => Num (BVar s (WholeRef a) a) where
    (+) = \x y -> liftOp2 addWR addWR (op2 (+)) x y (const initWR) readWR
    (*) = \x y -> liftOp2 addWR addWR (op2 (*)) x y (const initWR) readWR
    (-) = \x y -> liftOp2 addWR addWR (op2 (-)) x y (const initWR) readWR
    negate = \x -> liftOp1 addWR (op1 negate) x (const initWR) readWR
    abs    = \x -> liftOp1 addWR (op1 abs) x (const initWR) readWR
    signum = \x -> liftOp1 addWR (op1 signum) x (const initWR) readWR
    fromInteger n = constVar (fromIntegral n)
-- | Division on 'BVar's whose gradient lives in a single 'STRef'.
-- 'fromRational' produces a constant.
instance (Fractional a, Typeable a, Reifies s W) => Fractional (BVar s (WholeRef a) a) where
    (/)   = \x y -> liftOp2 addWR addWR (op2 (/)) x y (const initWR) readWR
    recip = \x -> liftOp1 addWR (op1 recip) x (const initWR) readWR
    fromRational q = constVar (fromRational q)
-- | A ref for a pair: one ref per component, both in the same thread @s@.
data TupleRef ra rb (s :: Type) = TR (ra s) (rb s)
-- | Read both halves of a 'TupleRef' into a pair, first component first.
readTR :: ReadFunc ra a -> ReadFunc rb b -> ReadFunc (TupleRef ra rb) (a, b)
readTR ra rb = RF $ \(TR rx ry) -> do
    x <- runRF ra rx
    y <- runRF rb ry
    return (x, y)
-- | Add a pair into a 'TupleRef' component-wise, first component first.
addTR :: AddFunc ra a -> AddFunc rb b -> AddFunc (TupleRef ra rb) (a, b)
addTR aa ab = AF $ \(TR rx ry) (x, y) -> do
    runAF aa rx x
    runAF ab ry y
-- | Allocate a 'TupleRef' by allocating each component's ref, first
-- component first.
initTR :: InitFunc ra a -> InitFunc rb b -> InitFunc (TupleRef ra rb) (a, b)
initTR ia ib = IF $ do
    rx <- runIF ia
    ry <- runIF ib
    return (TR rx ry)
-- | Set both component refs of a 'TupleRef' to one, first component first.
oneTR :: OneFunc ra a -> OneFunc rb b -> OneFunc (TupleRef ra rb) (a, b)
oneTR oa ob = OF $ \(TR rx ry) -> do
    runOF oa rx
    runOF ob ry
-- | Project the first component out of a pair 'BVar'. Its gradient is added
-- into the first ref of the pair's 'TupleRef'.
fstVar
    :: (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
    => InitFunc ra a
    -> ReadFunc ra a
    -> AddFunc ra a
    -> BVar s (TupleRef ra rb) (a, b)
    -> BVar s ra a
fstVar ifa rfa afa = partVar ifa rfa addToFst fst
  where
    addToFst = AF $ \(TR rx _) -> runAF afa rx
-- | Project the second component out of a pair 'BVar'. Its gradient is added
-- into the second ref of the pair's 'TupleRef'.
sndVar
    :: (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
    => InitFunc rb b
    -> ReadFunc rb b
    -> AddFunc rb b
    -> BVar s (TupleRef ra rb) (a, b)
    -> BVar s rb b
sndVar ifb rfb afb = partVar ifb rfb addToSnd snd
  where
    addToSnd = AF $ \(TR _ ry) -> runAF afb ry
-- | Split the 'TypeRep' of a promoted type-level list into a 'Rec' holding
-- the 'TypeRep' of each element.
--
-- The fallthrough is not expected to be reached. It previously returned a
-- bare 'undefined'; it now reports the offending type.
typeRepRec :: forall k (as :: [k]). Typeable k => TypeRep as -> Rec TypeRep as
typeRepRec tr
    | Just Refl <- testEquality tr (typeRep @'[]) = RNil
    | App (App c x) xs <- tr
    , Just HRefl <- eqTypeRep c (typeRep @('(:) :: k -> [k] -> [k]))
    = x :& typeRepRec xs
    | otherwise = error $ "typeRepRec: not a promoted list type: " ++ show tr
-- | Rebuild the 'TypeRep' of a promoted list from the 'TypeRep's of its
-- elements (the inverse of 'typeRepRec').
recTypeRep :: forall k (as :: [k]). Typeable k => Rec TypeRep as -> TypeRep as
recTypeRep = go
  where
    go :: forall (bs :: [k]). Rec TypeRep bs -> TypeRep bs
    go RNil      = typeRep
    go (x :& xs) = App (App consRep x) (go xs)

    consRep :: TypeRep ('(:) :: k -> [k] -> [k])
    consRep = typeRep
-- | Evidence that @xy@ is a promoted pair @'(x, y)@, carrying the
-- 'TypeRep's of both components.
data SplitTup :: (a, b) -> Type where
    SplitTup :: TypeRep x -> TypeRep y -> SplitTup '(x, y)
-- | Decompose the 'TypeRep' of a promoted pair into its two components.
--
-- A 'TypeRep' of kind @(a, b)@ is presumably always an application of the
-- promoted @'(,)@, so the failure branch should be unreachable. If it is
-- reached, the error now names the offending type instead of saying "what".
splitTup
    :: forall a b (xy :: (a, b)). (Typeable a, Typeable b)
    => TypeRep xy
    -> SplitTup xy
splitTup tr = case tr of
    App (App tup x) y
      | Just HRefl <- eqTypeRep tup (typeRep @('(,) :: a -> b -> (a, b)))
      -> SplitTup x y
    _ -> errorWithoutStackTrace $ "splitTup: not a promoted pair type: " ++ show tr
-- | Match the 'TypeRep' of a promoted pair @'(x, y)@. A match exposes the
-- 'TypeRep's of both components and the equality @xy ~ '(x, y)@. Matching
-- goes through 'splitTup'; construction builds the type application directly.
pattern TupRep
    :: forall a b xy.
       (Typeable a, Typeable b)
    => forall (x :: a) (y :: b). (xy ~ '(x, y))
    => TypeRep x
    -> TypeRep y
    -> TypeRep xy
pattern TupRep x y <- (splitTup->SplitTup x y)
  where
    TupRep x y = App (App (typeRep @'(,)) x) y
-- 'splitTup' either succeeds or calls error, so this pattern on its own is
-- declared complete.
{-# COMPLETE TupRep #-}
-- | Keep the 'TypeRep' of each pair's second component.
typeRepSnds
    :: forall (abs :: [(a,b)]). (Typeable a, Typeable b)
    => Rec TypeRep abs
    -> Rec TypeRep (Snds abs)
typeRepSnds RNil                 = RNil
typeRepSnds (TupRep _ y :& rest) = y :& typeRepSnds rest
-- | Run an action on each pair of corresponding elements of two 'Rec's, from
-- left to right, and discard the results.
rzipWithM_
    :: forall h f g as. Applicative h
    => (forall a. f a -> g a -> h ())
    -> Rec f as
    -> Rec g as
    -> h ()
rzipWithM_ _ RNil      RNil      = pure ()
rzipWithM_ f (x :& xs) (y :& ys) = f x y *> rzipWithM_ f xs ys
-- | Run an action on each pair of corresponding elements of two 'Rec's, from
-- left to right, and collect the results into a new 'Rec'.
rzipWithM
    :: forall h f g j as. Applicative h
    => (forall a. f a -> g a -> h (j a))
    -> Rec f as
    -> Rec g as
    -> h (Rec j as)
rzipWithM _ RNil      RNil      = pure RNil
rzipWithM f (x :& xs) (y :& ys) = (:&) <$> f x y <*> rzipWithM f xs ys
-- | Run an action on each triple of corresponding elements of three 'Rec's,
-- from left to right, and discard the results.
rzipWithM3_
    :: forall h f g j as. Applicative h
    => (forall a. f a -> g a -> j a -> h ())
    -> Rec f as
    -> Rec g as
    -> Rec j as
    -> h ()
rzipWithM3_ _ RNil      RNil      RNil      = pure ()
rzipWithM3_ f (x :& xs) (y :& ys) (z :& zs) = f x y z *> rzipWithM3_ f xs ys zs
-- | The only element of a single-element 'Rec'.
rHead :: Rec f '[a] -> f a
rHead = \case
    x :& RNil -> x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment