Last active
September 23, 2022 16:35
-
-
Save ekmett/04595b489961768a3e5085ea6db979a8 to your computer and use it in GitHub Desktop.
Backtracking "effects" generically with constraint kinds
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env cabal | |
{- cabal: | |
build-depends: base, constraints, ghc-prim | |
-} | |
{-# language AllowAmbiguousTypes, ConstrainedClassMethods, ConstraintKinds, | |
DefaultSignatures, FlexibleInstances, ImplicitParams, RankNTypes, | |
ScopedTypeVariables, TypeApplications, TypeFamilies, UndecidableSuperClasses #-} | |
import Control.Applicative | |
import Control.Concurrent.MVar | |
import Control.Exception | |
import Data.Constraint | |
import Data.IORef | |
import Data.Kind | |
import GHC.Base (IP(..)) | |
-- | @Cap p@ pairs the constraint @p@ itself (e.g. an implicit parameter)
-- with the ability to snapshot and roll back the state it refers to.
type Cap p = (p, Cap' p)
-- | Constraints whose underlying mutable state can be snapshotted into a
-- pure value and later written back.
--
-- Each method additionally requires @p@ itself (hence
-- @ConstrainedClassMethods@): an instance needs the constraint's evidence,
-- such as the value bound to an implicit parameter, to reach the state.
class Cap' p where
  -- | The pure snapshot produced by 'save' and consumed by 'restore'.
  type Captured p :: Type
  -- | Read the current state into a snapshot.
  save :: p => IO (Captured p)
  -- | Overwrite the current state with a previously taken snapshot.
  restore :: p => Captured p -> IO ()
-- | An implicit parameter @?x :: IORef a@ is snapshotted by reading the
-- reference's contents and restored by writing a snapshot back into it.
instance Cap' (IP x (IORef a)) where
  type Captured (IP x (IORef a)) = a
  save = readIORef (ip @x)
  restore = writeIORef (ip @x)
-- | An implicit parameter @?x :: MVar a@ is snapshotted by reading (without
-- taking) its contents and restored by replacing them.
--
-- Both 'save' and 'restore' block while the 'MVar' is empty. Snapshots can
-- therefore only be taken or restored while no thread is holding the 'MVar'.
instance Cap' (IP x (MVar a)) where
  type Captured (IP x (MVar a)) = a
  save = readMVar (ip @x)
  -- 'swapMVar' runs its take/put pair under 'mask_'. An asynchronous
  -- exception therefore cannot strike between the two steps and leave the
  -- 'MVar' empty, as a bare takeMVar/putMVar sequence could when 'restore'
  -- is called outside a masked region.
  restore a = () <$ swapMVar (ip @x) a
-- | The empty constraint carries no state: its snapshot is @()@ and
-- restoring it does nothing.
instance Cap' () where
  type Captured () = ()
  save = return ()
  restore unit = return unit
-- | A pair of capturable constraints is snapshotted component-wise.
-- Restoring writes back the first component before the second.
instance (Cap' p, Cap' q) => Cap' (p, q) where
  type Captured (p, q) = (Captured p, Captured q)
  save = liftA2 (,) (save @p) (save @q)
  restore (a, b) = do
    restore @p a
    restore @q b
-- | A triple of capturable constraints is snapshotted component-wise.
-- Restoring writes the components back from left to right.
instance (Cap' p, Cap' q, Cap' r) => Cap' (p, q, r) where
  type Captured (p, q, r) = (Captured p, Captured q, Captured r)
  save = liftA3 (,,) (save @p) (save @q) (save @r)
  restore (a, b, c) = do
    restore @p a
    restore @q b
    restore @r c
-- | Run an action.
--
-- If it throws, the state described by @p@ is rolled back to the snapshot
-- taken just before the action started, and the exception is rethrown.
-- If it succeeds, its effects are kept.
onErrorRestore :: forall p a. Cap p => IO a -> IO a
onErrorRestore action =
  bracketOnError (save @p) (restore @p) (\_snapshot -> action)
-- | Try @primary@. If it fails, its effects on the state described by @p@
-- are rolled back before @fallback@ runs.
--
-- NOTE: the fallback is chosen via IO's '<|>', which only catches
-- 'IOException's. Any other exception still rolls the state back, but it
-- then propagates instead of running @fallback@.
orElseRestore :: forall p a. Cap p => IO a -> IO a -> IO a
orElseRestore primary fallback =
  onErrorRestore @p primary <|> fallback
-- | Like 'try', but the state described by @p@ is rolled back whenever the
-- action throws.
--
-- Exceptions of type @e@ are returned as 'Left'. Other exceptions propagate,
-- still after the rollback.
tryWith :: forall p e a. (Cap p, Exception e) => IO a -> IO (Either e a)
tryWith action = try $ onErrorRestore @p action
-- | Run a computation and keep its answer, but always backtrack its effects.
--
-- The state described by @p@ is rolled back to how it was before the
-- action started, whether the action returned normally or threw.
consider :: forall p a. Cap p => IO a -> IO a
consider action = bracket (save @p) (restore @p) $ \_snapshot -> action
-- | The implicit parameter @?foo@: a mutable 'Int' used by the demo 'main'.
type GivenFoo = (?foo :: IORef Int)
-- | The implicit parameter @?bar@: a mutable 'String' used by the demo 'main'.
type GivenBar = (?bar :: IORef String)
-- | Demo: snapshot and roll back two 'IORef's passed as implicit parameters.
main :: IO ()
main = do
  fooRef <- newIORef (12 :: Int)
  barRef <- newIORef "hello"
  let ?foo = fooRef; ?bar = barRef
  -- Direct usage:
  --   1. snapshot
  --   2. mutate, then snapshot again
  --   3. roll back to the first snapshot, then snapshot once more
  p <- save @(GivenFoo, GivenBar)
  writeIORef ?foo 34
  q <- save @(GivenFoo, GivenBar)
  restore @(GivenFoo, GivenBar) p
  r <- save @(GivenFoo, GivenBar)
  print (p, q, r) -- ((12,"hello"),(34,"hello"),(12,"hello"))
  -- Exception-safe usage: a failing action has its write to ?foo undone...
  _ <- tryWith @GivenFoo @SomeException (writeIORef ?foo 100 *> error "wut")
  s <- save @GivenFoo
  print s -- 12
  -- ...while a successful action's write is kept.
  _ <- tryWith @GivenFoo @SomeException (writeIORef ?foo 200)
  t <- save @GivenFoo
  print t -- 200
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment