@ejconlon
Created November 5, 2022 17:32
Backtracking state with LogicT
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- | 'LogicT' is a great monad transformer for backtracking control,
-- but if you just layer it over a 'State' monad, the state won't backtrack.
-- This module makes part of the state backtrack: at every choice point
-- ('<|>' or 'interleave') we save that part of the state and restore it
-- before trying the alternative branch.
module BacktrackingStateSearch
  ( TrackSt (..)
  , Track
  , observeManyTrack
  , runManyTrack
  ) where

import Control.Applicative (Alternative (..))
import Control.Monad.Logic (LogicT, MonadLogic (..), observeManyT)
import Control.Monad.State.Strict (MonadState (..), State, gets, modify', runState)
import Data.Bifunctor (second)

-- | Backtracking state: the @x@ component only goes forward, while the @y@
-- component backtracks. All mentions of state below are really about the
-- backtracking component. The forward component is pretty boring.
data TrackSt x y = TrackSt
  { tsFwd :: !x
  , tsBwd :: !y
  } deriving stock (Eq, Show)

-- | Backtracking search monad. Take care not to expose the constructor!
-- The major issue with backtracking is that the final state is that of
-- the last branch that has executed. In order for the 'msplit' law to hold
-- (`msplit m >>= reflect = m`) we have to ensure that the same state
-- is observable on all exit points. Basically the only way to do this is to
-- not make the state visible at all externally, which requires that we
-- protect the constructor here and only allow elimination of this type
-- with 'observeManyTrack', which resets the state for us.
newtype Track x y a = Track { unTrack :: LogicT (State (TrackSt x y)) a }
  deriving newtype (Functor, Applicative, Monad, MonadState (TrackSt x y))
-- | Wraps logict's 'observeManyT' and forces us to 'reset' the backtracking state.
observeManyTrack :: Int -> Track x y a -> State (TrackSt x y) [a]
observeManyTrack n = observeManyT n . unTrack . reset
-- | A nicer way to run the search.
runManyTrack :: Int -> Track x y a -> TrackSt x y -> ([a], TrackSt x y)
runManyTrack n m = runState (observeManyTrack n m)
-- | At many points below we'll need to restore a saved state before
-- continuing the search.
restore :: y -> Track x y a -> Track x y a
restore saved x = modify' (\st -> st { tsBwd = saved }) *> x
-- | Restores the backtracked state after all results have been enumerated.
finalize :: y -> Track x y a -> Track x y a
finalize saved x = Track (unTrack x <|> unTrack (restore saved empty))
-- | Ensures the backtracking state is returned to the value it had when
-- the search started. This is run on the outside of the search so the
-- backtracked state is not externally observable.
reset :: Track x y a -> Track x y a
reset x = do
  saved <- gets tsBwd
  finalize saved x
instance Alternative (Track x y) where
  empty = Track empty
  x <|> y = do
    saved <- gets tsBwd
    -- Restore the current state before going down the right branch.
    Track (unTrack x <|> unTrack (restore saved y))
instance MonadLogic (Track x y) where
  -- This is just newtype noise - we have to define this, but we really
  -- need to override 'interleave'. (Unless I missed a case, I don't think
  -- we need to reset in the tail...)
  msplit x = Track (fmap (fmap (second Track)) (msplit (unTrack x)))
  interleave x y = do
    saved <- gets tsBwd
    -- Again restore the current state before going down the right branch.
    Track (interleave (unTrack x) (unTrack (restore saved y)))
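
A small usage sketch may help show the difference the module is meant to make. It is not part of the original gist: it assumes the module above is available as `BacktrackingStateSearch` alongside the `logict` and `mtl` packages, and the names `naive`, `visit`, and `search` are hypothetical, introduced only for illustration. The first search uses plain `LogicT` over `State`, where a write in the left branch leaks into the right branch; the second uses `Track`, where the `tsBwd` component is restored at the choice point while `tsFwd` keeps counting.

```haskell
module Main (main) where

import BacktrackingStateSearch (Track, TrackSt (..), runManyTrack)
import Control.Applicative ((<|>))
import Control.Monad.Logic (LogicT, observeAllT)
import Control.Monad.State.Strict (State, get, gets, modify', put, runState)

-- Plain 'LogicT' over 'State' (hypothetical example): the write made in the
-- left branch is still visible when the right branch runs.
naive :: LogicT (State Int) Int
naive = (put 1 *> get) <|> get

-- Hypothetical helper: count visited branches in the forward component.
visit :: Track Int Int ()
visit = modify' (\st -> st { tsFwd = tsFwd st + 1 })

-- The left branch overwrites the backtracking component; the right branch
-- only reads it, so it sees whatever '<|>' restored.
search :: Track Int Int Int
search = left <|> right
  where
    left = visit *> modify' (\st -> st { tsBwd = 1 }) *> gets tsBwd
    right = visit *> gets tsBwd

main :: IO ()
main = do
  print (runState (observeAllT naive) 0)
  -- prints ([1,1],1)
  print (runManyTrack 10 search (TrackSt 0 0))
  -- prints ([1,0],TrackSt {tsFwd = 2, tsBwd = 0})
```

The limit of 10 is deliberately larger than the number of results, so the search runs to exhaustion before the final state is inspected.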