@gelisam
Created March 10, 2019 17:47
dynamic programming using recursion schemes
-- Solving a dynamic programming problem in many ways, including using existing
-- recursion schemes and by defining new ones. The challenge of solving this
-- particular problem using recursion schemes was posed by Sandy Maguire.
{-# LANGUAGE FlexibleContexts, RankNTypes, TypeApplications, TypeFamilies, ScopedTypeVariables #-}
{-# OPTIONS -Wno-orphans #-}
module Dyna where
import Test.DocTest
import Data.Functor.Foldable (Base, Fix, Recursive(project), Corecursive(embed, ana), hylo, cataA)
import Control.Comonad.Trans.Env (EnvT(..), ask)
import Control.Monad.State (State, get, modify, evalState)
import Data.Array (Array, (!))
import Data.Ix (Ix(inRange))
import Data.List (maximumBy, transpose)
import Data.Map (Map)
import Data.Maybe (catMaybes, fromMaybe, listToMaybe)
import Data.Ord (comparing)
import Data.Tree (Tree(Node))
import qualified Data.Array as Array
import qualified Data.Map as Map
--------------------------------------------------------------------------------
-- type synonyms --
--------------------------------------------------------------------------------
type Pos = (Int,Int)
type Path = [Int]
type Result = Maybe Path -- Nothing if there are no decreasing paths to the goal
type Cache = Map Pos Result
--------------------------------------------------------------------------------
-- Array helpers --
--------------------------------------------------------------------------------
-- | partial if the [[a]] isn't rectangular
--
-- >>> array2D []
-- array ((1,1),(0,0)) []
-- >>> array2D [[],[],[]]
-- array ((1,1),(0,3)) []
-- >>> array2D [[1,2],[3,4],[5,6]]
-- array ((1,1),(2,3)) [((1,1),1),((1,2),3),((1,3),5),((2,1),2),((2,2),4),((2,3),6)]
array2D :: [[a]] -> Array Pos a
array2D [] = Array.listArray ((1,1), (0,0)) []
array2D xss = Array.listArray ((1,1), (w,h))
            $ concat
            $ transpose
            $ xss
  where
    w = length $ fromMaybe [] $ listToMaybe xss
    h = length xss
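
-- | Whether an index falls within an array's bounds. A quick illustration,
-- reusing 'array2D' from above:
--
-- >>> inArray (array2D [[1,2],[3,4]]) (1,2)
-- True
-- >>> inArray (array2D [[1,2],[3,4]]) (3,1)
-- False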
inArray :: Ix i
        => Array i a -> i -> Bool
inArray = inRange . Array.bounds
--------------------------------------------------------------------------------
-- Decreasing paths --
--------------------------------------------------------------------------------
-- The task is to find the longest decreasing path from one corner of a grid to
-- the other. "Path" means a sequence of cells which are next to each other,
-- either vertically or horizontally. "Decreasing" means that the contents of
-- each cell are strictly smaller than the contents of the previous cell along
-- the path.
-- Instead of parameterizing everything by a grid, we hardcode the grid here in
-- order to make the code shorter. Note that the longest path moves in all
-- cardinal directions, so unlike with the longest-common-subsequence problem,
-- the sub-problems aren't always to the bottom and to the right of the current
-- position. If they were, we could use a 'histo' on the Peano encoding of the
-- number of positions, as that would give us access to the answers to all our
-- sub-problems. But this problem is harder, so we will use memoization instead.
-- (Alternatively, we could have sorted the positions by their contents and
-- solved the positions with the smaller contents first; a sketch of that
-- approach, 'ldpSorted', follows 'extendResult' below.)
grid :: Array Pos Int
grid = array2D
  [ [30,29,28,27,26]
  , [ 9,10,11,12,25]
  , [ 8,17,16,13,24]
  , [ 7,18,15,14,23]
  , [ 6,19,20,21,22]
  , [ 5, 4, 3, 2, 1]
  ]
start, goal :: Pos
(start, goal) = Array.bounds grid
decreasing :: Pos -> Pos -> Bool
decreasing ij ij' = (grid ! ij) > (grid ! ij')
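
-- | The in-bounds neighbours of a position whose contents are strictly smaller
-- than the position's own contents. For illustration, with the grid above:
--
-- >>> neighbours start
-- [(2,1),(1,2)]
-- >>> neighbours goal
-- []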
neighbours :: Pos -> [Pos]
neighbours (i,j) = filter (decreasing (i,j))
                 $ filter (inArray grid)
                 $ [(i-1,j), (i+1,j), (i,j-1), (i,j+1)]
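
-- | Prepend the current position's contents to the best of the sub-results, or
-- succeed immediately at the goal. For illustration, with the grid above:
--
-- >>> extendResult start [Just [29], Nothing]
-- Just [30,29]
-- >>> extendResult goal []
-- Just [1]
-- >>> extendResult (3,3) [Nothing, Nothing]
-- Nothing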
extendResult :: Pos -> [Result] -> Result
extendResult ij results | ij == goal = Just [grid ! ij]
                        | otherwise = case xss of
                            [] -> Nothing
                            _  -> Just (x:xs)
  where
    x :: Int
    x = grid ! ij

    xss :: [Path]
    xss = catMaybes results

    -- partial if xss is empty
    xs :: Path
    xs = maximumBy (comparing length) xss
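
-- As an aside, here is a rough sketch of the alternative mentioned above: sort
-- the positions by their contents and solve the ones holding the smaller
-- values first, so that every neighbour's answer is already in the table by
-- the time we need it. ('ldpSorted' is only an illustration; the
-- recursion-scheme versions below do not use it.)

-- |
-- >>> ldpSorted
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpSorted :: Result
ldpSorted = finalTable Map.! start
  where
    -- positions in increasing order of their contents; since the contents are
    -- all distinct, a value-keyed Map gives us the sort for free via 'Map.elems'
    sortedPositions :: [Pos]
    sortedPositions = Map.elems
                    $ Map.fromList [(grid ! ij, ij) | ij <- Array.indices grid]

    finalTable :: Map Pos Result
    finalTable = foldl solveOne Map.empty sortedPositions

    -- every neighbour holds a strictly smaller value, so it was solved earlier
    solveOne :: Map Pos Result -> Pos -> Map Pos Result
    solveOne table ij
      = Map.insert ij (extendResult ij [table Map.! ij' | ij' <- neighbours ij]) table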
--------------------------------------------------------------------------------
-- Without recursion-schemes nor caching --
--------------------------------------------------------------------------------
-- | Longest Decreasing Path
--
-- >>> ldp
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldp :: Result
ldp = go start
  where
    -- Check every single path, keeping the longest one. Note that since the
    -- paths are strictly-decreasing, there can be no cycles, so this will
    -- terminate.
    go :: Pos -> Result
    go ij = extendResult ij $ fmap go $ neighbours ij
--------------------------------------------------------------------------------
-- Without recursion-schemes, with caching --
--------------------------------------------------------------------------------
-- |
-- >>> ldpCached
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpCached :: Result
ldpCached = flip evalState Map.empty
          $ cachedGo start
  where
    -- Depth-first search, caching the longest decreasing path for the nodes we
    -- have already visited.
    go :: Pos -> State Cache Result
    go ij = do
      results <- traverse cachedGo (neighbours ij)
      pure $ extendResult ij results

    cachedGo :: Pos -> State Cache Result
    cachedGo ij = do
      cache <- get
      case Map.lookup ij cache of
        Just result -> pure result
        Nothing -> do
          result <- go ij
          modify (Map.insert ij result)
          pure result
--------------------------------------------------------------------------------
-- With recursion-schemes, no caching --
--------------------------------------------------------------------------------
-- ldp's recursion structure is shaped like a Rose tree, so we can use a 'hylo'
-- on that recursive type.
type TreeF a = EnvT a []
type instance Base (Tree a) = TreeF a
instance Recursive (Tree a) where
  project (Node a subtrees) = EnvT a subtrees

instance Corecursive (Tree a) where
  embed (EnvT a subtrees) = Node a subtrees
-- |
-- >>> ldpRs
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpRs :: Result
ldpRs = hylo conquer divide start
-- or, equivalently:
-- cata conquer $ ana @(Tree Pos) divide start
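
-- | Unfold one layer of the call tree: the current position, paired with its
-- decreasing neighbours. For illustration, with the grid above, @divide start@
-- is @EnvT (1,1) [(2,1),(1,2)]@: one child per decreasing neighbour of the
-- starting corner.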
divide :: Pos -> TreeF Pos Pos
divide ij = EnvT ij (neighbours ij)
conquer :: TreeF Pos Result -> Result
conquer (EnvT ij results) = extendResult ij results
--------------------------------------------------------------------------------
-- With recursion-schemes and caching --
--------------------------------------------------------------------------------
-- Instead of using 'cata' to combine results, we can use 'cata' to combine
-- 'State' computations, so that the final computation computes the desired
-- result. The recently-added 'cataA' recursion scheme guides us in that
-- direction by allowing us to choose in which order we want to run the
-- sub-computations; or, in the case of caching, whether we want to run the
-- sub-computations at all!
-- |
-- >>> ldpRsCached
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpRsCached :: Result
ldpRsCached = flip evalState Map.empty
            $ cataA cachedConquer
            $ ana @(Tree Pos) divide
            $ start
cachedConquer :: TreeF Pos (State Cache Result)
              -> State Cache Result
cachedConquer (EnvT ij subcomputations) = do
  cache <- get
  case Map.lookup ij cache of
    Just result -> do
      -- note that we do not run the sub-computations!
      pure result
    Nothing -> do
      subresults <- sequenceA subcomputations
      let result = conquer (EnvT ij subresults)
      modify (Map.insert ij result)
      pure result
--------------------------------------------------------------------------------
-- Capturing the pattern in a new recursion scheme --
--------------------------------------------------------------------------------
-- The idea of caching the results of a 'cata' in order to skip some of the
-- sub-trees is hardly unique to this problem, so it might be useful to capture
-- it in a new recursion scheme. I should probably add it to recursion-schemes!
-- |
-- >>> ldpCachedCata
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpCachedCata :: Result
ldpCachedCata = cachedCata ask sequenceA conquer
              $ ana @(Tree Pos) divide
              $ start
-- I generalize from @Tree k@ to some arbitrary 't', and so I need to ask the
-- caller how they want to compute the key. They are not allowed to look at the
-- sub-results as they do so.
--
-- I also ask in which order they want to run the sub-computations if the key
-- isn't found in the cache; most of the time, this will be 'sequenceA', but I
-- don't want to presume.
cachedCata :: forall t k a. (Recursive t, Ord k)
           => (forall x. Base t x -> k)
           -> (forall f x. Applicative f => Base t (f x) -> f (Base t x))
           -> (Base t a -> a)
           -> t -> a
cachedCata getKey sequenceEffects fAlgebra = flip evalState Map.empty
                                           . cataA go
  where
    go :: Base t (State (Map k a) a)
       -> State (Map k a) a
    go fsa = do
      let k = getKey fsa
      cache <- get
      case Map.lookup k cache of
        Just a -> pure a
        Nothing -> do
          fa <- sequenceEffects fsa
          let a = fAlgebra fa
          modify (Map.insert k a)
          pure a
--------------------------------------------------------------------------------
-- Capturing dynamic programming in a new recursion scheme --
--------------------------------------------------------------------------------
-- Dynamic programming is a specific use case for caching. So we can write a
-- specialized recursion-scheme which captures the idea that we can divide a
-- problem into sub-problems, and we can cache the result for all the
-- sub-problems in order to get better performance.
--
-- This works out nicely, as I implemented 'extendResult' and 'neighbours'
-- because I wanted to reduce duplication in the other implementations, not
-- because I knew in advance that I wanted to implement a recursion scheme with
-- dyna's type!
--
-- I should probably add 'dyna' to recursion-schemes as well, but under a
-- different name, as "dynamorphism" is already an established recursion scheme
-- (which isn't provided by recursion-schemes yet).
-- |
-- >>> ldpDyna
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpDyna :: Result
ldpDyna = dyna extendResult neighbours start
dyna :: forall f a b. (Traversable f, Ord a)
     => (a -> f b -> b)
     -> (a -> f a)
     -> a -> b
dyna solve fCoalgebra = cachedCata ask sequenceA fAlgebra
                      . ana @(Fix (EnvT a f)) fCoalgebra'
  where
    fAlgebra :: EnvT a f b -> b
    fAlgebra (EnvT problem subsolutions) = solve problem subsolutions

    fCoalgebra' :: a -> EnvT a f a
    fCoalgebra' a = EnvT a (fCoalgebra a)
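
-- As a quick check that 'dyna' is not tied to the grid problem, here is a
-- sketch of the textbook example: Fibonacci, where each sub-problem has two
-- sub-problems of its own and the naive recursion takes exponential time.
-- ('fibDyna' is only an illustration.)

-- |
-- >>> fibDyna 50
-- 12586269025
fibDyna :: Int -> Integer
fibDyna = dyna solve subproblems
  where
    subproblems :: Int -> [Int]
    subproblems n | n < 2     = []
                  | otherwise = [n-1, n-2]

    solve :: Int -> [Integer] -> Integer
    solve n []     = fromIntegral n  -- base cases: fib 0 = 0, fib 1 = 1
    solve _ [a, b] = a + b
    solve _ _      = error "subproblems always returns 0 or 2 entries"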
--------------------------------------------------------------------------------
-- Running the doctests --
--------------------------------------------------------------------------------
main :: IO ()
main = doctest ["src/Dyna.hs"]
@isovector

This is wicked cool, Sam! Does any of it work if you don't have start or goal available statically?
