Skip to content

Instantly share code, notes, and snippets.

@sjoerdvisscher
Last active November 11, 2023 22:45
Show Gist options
  • Save sjoerdvisscher/5fe3c3cba928c4b0c112c29860894ed8 to your computer and use it in GitHub Desktop.
Save sjoerdvisscher/5fe3c3cba928c4b0c112c29860894ed8 to your computer and use it in GitHub Desktop.
Deriving differentiation with linear generics
-- https://twitter.com/paf31/status/1362207106703630338
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
import Generics.Linear hiding (D)
import Generics.Linear.TH
import Data.Functor ((<&>))
-- | A one-hole context inside an @f@ whose hole originally held an @a@.
--
-- To rebuild the structure, supply the value @x@ that goes into the hole
-- (consumed exactly once, as the @%1@ arrow requires) together with a
-- function that converts each remaining @a@ in the structure to @x@.
newtype D f a = D {unD :: forall x. x %1 -> (a -> x) -> f x}
  deriving (Functor)
-- | An element paired with the one-hole context it was taken from
-- (a value in its one-hole context).
data InContext f a = InContext a (D f a)
  deriving (Functor)
-- | Change the outer functor of a context with a natural transformation.
-- The focused element is untouched; the transformation is applied to the
-- structure produced when the hole is plugged. It must be linear because
-- that structure contains the linearly-used hole value.
hmap :: (forall x. f x %1 -> g x) -> InContext f a -> InContext g a
hmap nat ctx = case ctx of
  InContext focus (D plug) ->
    InContext focus (D \hole rest -> nat (plug hole rest))
-- | Functors whose elements can each be paired with their one-hole context.
class Functor f => Diff f where
  -- | Replace every element with itself together with the context it sits in.
  contexts :: f a -> f (InContext f a)
  -- Default: go through the linear 'Generic1' representation, compute the
  -- contexts there, then convert both the outer structure and each context's
  -- rebuilt structure back with 'to1'.
  default contexts :: (Generic1 f, Diff (Rep1 f)) => f a -> f (InContext f a)
  contexts fa = hmap to1 <$> to1 (contexts (from1 fa))
-- | Contexts for functions: the element at argument @r@ is @f r@.
--
-- NOTE(review): plugging @x@ back into this context yields the constant
-- function returning @x@ (the @a -> x@ mapper and the argument are ignored).
-- So @plugIn <$> contexts f@ gives @const (f r)@ at each @r@ rather than @f@.
-- A faithful version would need @Eq r@ and would use @x@ in only one branch.
-- The @%1@ hole argument of 'D' forbids that.
instance Diff ((->) r) where
  contexts f = \r -> InContext (f r) (D \x _ _ -> x)
-- | A bare parameter has exactly one position. Its context just rewraps
-- whatever is plugged in.
instance Diff Par1 where
  contexts = \(Par1 val) -> Par1 (InContext val (D \hole _ -> Par1 hole))
-- | Chain rule: an element of @f (g a)@ sits in an inner @g@ context, and that
-- @g@ sits in an outer @f@ context. Plugging fills the inner hole, maps the
-- outer structure's other @g a@ values with @fmap rest@, then fills the outer
-- hole.
instance (Diff f, Diff g) => Diff (f :.: g) where
  contexts (Comp1 outer) =
    Comp1
      ( fmap
          ( \(InContext inner (D plugOuter)) ->
              fmap
                ( \(InContext a (D plugInner)) ->
                    InContext
                      a
                      (D \hole rest -> Comp1 (plugOuter (plugInner hole rest) (fmap rest)))
                )
                (contexts inner)
          )
          (contexts outer)
      )
-- | Product rule: a hole is in either the left or the right component. The
-- other component is kept and has all its elements mapped with @rest@.
instance (Diff f, Diff g) => Diff (f :*: g) where
  contexts (left :*: right) = leftHoles :*: rightHoles
    where
      leftHoles =
        fmap
          (\(InContext a (D plug)) -> InContext a (D \hole rest -> plug hole rest :*: fmap rest right))
          (contexts left)
      rightHoles =
        fmap
          (\(InContext a (D plug)) -> InContext a (D \hole rest -> fmap rest left :*: plug hole rest))
          (contexts right)
-- | Sum rule: compute the contexts of whichever branch is present. Each
-- rebuilt structure is re-injected into the same branch.
instance (Diff f, Diff g) => Diff (f :+: g) where
  contexts = \case
    L1 inl -> L1 (fmap (hmap L1) (contexts inl))
    R1 inr -> R1 (fmap (hmap R1) (contexts inr))
-- | Metadata wrappers are transparent: recurse and rewrap with 'M1'.
instance Diff f => Diff (M1 i c f) where
  contexts (M1 inner) = M1 (hmap M1 <$> contexts inner)
-- | A constant field holds no elements, so there is nothing to pair up.
-- The value is rewrapped at the new parameter type.
instance Diff (K1 i c) where
  contexts = \(K1 k) -> K1 k
-- | The empty type has no values, so the match has no alternatives.
instance Diff V1 where
  contexts v = case v of {}
-- | A nullary constructor has no elements. The argument is still matched on,
-- as in a strict pattern.
instance Diff U1 where
  contexts = \U1 -> U1
-- | Small example type used to exercise the generic default of 'contexts'.
-- It has a direct element, a field that holds no element, and a list of
-- elements.
data Example a = Example a Bool [a]
  deriving (Show, Functor)

-- Linear 'Generic1' instance generated with Template Haskell.
deriveGeneric1 ''Example

-- Both instances rely on the generic default implementation of 'contexts'.
instance Diff []

instance Diff Example
-- ghci> plugIn <$> contexts (Example 1 True [2, 3])
-- Example (Example 1 True [2,3]) True [Example 1 True [2,3],Example 1 True [2,3]]

-- | Put the focused element back into its hole. All other elements are
-- left unchanged (mapped with 'id').
plugIn :: InContext f a -> f a
plugIn (InContext focus (D plug)) = plug focus id
-- http://blog.sigfpe.com/2006/09/infinitesimal-types.html
-- F[x + d] = F[x] + d F'[x]

-- | Expand a structure over @x + d@ to first order in @d@.
--
-- If no position holds a @Left@, the result is @Right@ with the plain
-- structure (the @F[x]@ term). Otherwise it is @Left@ for the first @Left@
-- found in traversal order (the @d F'[x]@ term): that @d@ paired with its
-- context. In that context, any further @Left@ positions are sent through
-- the @d2void@ argument, encoding @d^2 = 0@.
infinitesimal
  :: (Diff f, Traversable f)
  => (forall void. d -> d -> void) -- d^2=0
  -> f (Either d a)
  -> Either (d, D f a) (f a)
infinitesimal d2void fx = traverse classify (contexts fx)
  where
    classify (InContext slot (D plug)) = case slot of
      Right a -> Right a
      Left d -> Left (d, D \hole rest -> plug hole (either (d2void d) rest))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment