@ekmett
Created July 20, 2019 17:47
checkpointing with 'ad'
{-# language RankNTypes #-}
{-# language FlexibleContexts #-}
{-# language ScopedTypeVariables #-}
import Numeric.AD.Internal.Reverse
import Numeric.AD.Mode
import Numeric.AD.Mode.Reverse (grad)
import Data.Reflection (Reifies)
import Data.Foldable (toList)
import Data.Proxy
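
-- based on an implementation by Sofus Mortensen
--
-- grad_cp f g xs computes the gradient of (g . f) at xs in three steps:
-- evaluate f at xs without recording a tape, take the gradient z' of g at
-- that point, then pull z' back through f by differentiating the scalar
-- sum (zipWith (*) z' (f x)), which yields the vector-Jacobian product.
grad_cp
  :: forall f g a. (Traversable f, Traversable g, Num a)
  => (forall s. Reifies s Tape => f (Reverse s a) -> g (Reverse s a))
  -> (forall s. Reifies s Tape => g (Reverse s a) -> Reverse s a)
  -> f a -> f a
grad_cp f g xs = grad (sum . zipWith ((*) . auto) z' . toList . f) xs where
  -- gradient of g, taken at the point f xs
  z' = toList $ grad g fxs
  -- f applied to constants ('auto') records no tape entries, so this is a
  -- plain forward pass; the tape should be empty when done
  fxs = reifyTape 0 $ \(_ :: Proxy s) ->
    primal <$> f (auto <$> xs :: f (Reverse s a))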
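grad_cp relies on the chain rule in vector-Jacobian form: the gradient of g . f at x is J_f(x)^T applied to the gradient of g at f x. The sketch below checks that split on a toy pair of functions with hand-written derivatives, so it does not use 'ad' at all. The names toyF, toyVjpF, toyGradG, toyGradCp and toyGradDirect are hypothetical and only illustrate the arithmetic grad_cp performs with real reverse-mode tapes.

```haskell
-- Hypothetical toy functions with hand-written derivatives (no 'ad' needed).
-- toyF : R^2 -> R^2,  toyF [x, y] = [x*y, x+y]
toyF :: [Double] -> [Double]
toyF [x, y] = [x * y, x + y]
toyF _      = error "toyF: expected exactly two inputs"

-- Vector-Jacobian product of toyF: returns (J_f(xs))^T z,
-- where J_f [x, y] = [[y, x], [1, 1]].
toyVjpF :: [Double] -> [Double] -> [Double]
toyVjpF [x, y] [z1, z2] = [z1 * y + z2, z1 * x + z2]
toyVjpF _      _        = error "toyVjpF: dimension mismatch"

-- toyG : R^2 -> R, toyG [u, v] = u*v, whose gradient is [v, u].
toyGradG :: [Double] -> [Double]
toyGradG [u, v] = [v, u]
toyGradG _      = error "toyGradG: expected exactly two inputs"

-- The same split grad_cp performs:
--   1. evaluate f at xs (no derivative information kept),
--   2. take the gradient z of g at that point,
--   3. pull z back through f with a vector-Jacobian product.
toyGradCp :: [Double] -> [Double]
toyGradCp xs = toyVjpF xs (toyGradG (toyF xs))

-- Gradient of (toyG . toyF) [x, y] = x^2*y + x*y^2, derived by hand.
toyGradDirect :: [Double] -> [Double]
toyGradDirect [x, y] = [2 * x * y + y * y, x * x + 2 * x * y]
toyGradDirect _      = error "toyGradDirect: expected exactly two inputs"
```

Here toyGradCp [3, 2] and toyGradDirect [3, 2] both give [16.0, 21.0]. Assuming the 'ad'-based grad_cp above compiles against your version of the library, grad_cp (\[x, y] -> [x * y, x + y]) (\[u, v] -> u * v) [3, 2 :: Double] should return the same vector; the difference is that 'ad' derives J_f and the gradient of g from tapes instead of by hand.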
@sofusmortensen

Nice!