Skip to content

Instantly share code, notes, and snippets.

@effectfully
Created May 13, 2018 19:05
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ViewPatterns #-}
-- This module will only define functions in Accelerate, so sometimes it is
-- useful to use the following two extensions. In particular, this allows us to
-- use the usual if-then-else syntax with Accelerate Acc and Exp terms, rather
-- than `(?|)` and `(?)` (or `acond` and `cond`) respectively.
--
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE NoImplicitPrelude #-}
module Solve where
import qualified Prelude
import Data.Array.Accelerate
import Data.Array.Accelerate.Control.Lens
-- Useful type synonyms
--
-- | Real-number element type used by the fields and scalars in this module.
type R = Double
-- | A two-dimensional (rank-2, 'DIM2') array with elements of type @e@.
type Field e = Array DIM2 e
-- | Advance the solution by a single time step. Call this repeatedly until the
-- final time is reached.
--
-- The result pairs a flag, which is True only when the outer iteration reached
-- the residual tolerance, with the last approximation that was computed.
--
-- This corresponds to the main loop of the program in 'c/main.c' line 155.
--
step :: Acc (Scalar R)
     -> Acc (Field R)
     -> Acc (Scalar Bool, Field R)
step dt x0 = lift (final ^. _1, final ^. _3)
  where
    -- Upper bound on the number of iterations of the loop below.
    maxiters  = 50

    -- The loop stops successfully once the residual norm drops below this.
    tolerance = 1.0e-3

    -- Boolean flags wrapped up as scalar arrays.
    yes = unit (constant True)
    no  = unit (constant False)

    -- Loop-carried state: (has the solution converged?, number of
    -- iterations taken so far, current approximation of the solution).
    initial :: Acc (Scalar Bool, Scalar Int, Field R)
    initial = lift (no, unit 0, x0)

    -- Loop condition: keep executing while this yields True, i.e. while the
    -- solution has not converged and the iteration limit is not yet reached.
    keepGoing :: Acc (Scalar Bool, Scalar Int, Field R)
              -> Acc (Scalar Bool)
    keepGoing st =
      zipWith (\done n -> not done && n < maxiters) (st ^. _1) (st ^. _2)

    -- Loop body.
    --
    -- Thanks to RebindableSyntax the if-then-else below works on Accelerate
    -- terms; it is translated into `acond` (for `Acc`) or `cond` (for `Exp`),
    -- whose infix forms are `(?|)` and `(?)`.
    --
    -- Each branch uses `lift` to turn a tuple-of-Acc into an Acc-tuple:
    --
    -- > lift :: (Acc (Scalar Bool), Acc (Scalar Int), Acc (Field R)) -> Acc (Scalar Bool, Scalar Int, Field R)
    --
    refine :: Acc (Scalar Bool, Scalar Int, Field R)
           -> Acc (Scalar Bool, Scalar Int, Field R)
    refine st =
      if the residual < tolerance
        then lift (yes, count, guess)               -- solution converged; exit with success
        else if the ok
               then lift (no, map (+1) count, guess')   -- keep iterating to solution
               else lift (no, unit maxiters, guess')    -- break out of the loop
      where
        -- Components of the loop state, projected out with the lenses from
        -- `lens-accelerate`. (The alternative is `unlift`, which converts
        -- the Acc-tuple into a tuple-of-Acc, but needs a type annotation
        -- whenever one of the components is left unused.)
        count = st ^. _2
        guess = st ^. _3

        -- Residual of the current approximation.
        b        = diffusion dt x0 guess
        residual = norm2 (flatten b)

        -- Solve the linear system to get -delta_x. The flag `ok` tells us
        -- whether the linear solver converged; if it did not, the iteration
        -- counter is forced to the limit so that the loop exits and the
        -- step is reported as failed. No annotation is needed on this
        -- `unlift` because both components are used.
        (ok, dx) = unlift (solve dt x0 b)
        guess'   = zipWith (-) guess dx

    -- Run the solver loop.
    final = awhile keepGoing refine initial
-- | Conjugate gradient solver routine.
--
-- Solve the linear system \( A * x = b \) for @x@.
--
-- The matrix A is implicit in the objective function for the diffusion
-- equation. The second argument is used both as the solution at the previous
-- time step (passed on to 'diffusion') and as the initial guess at the
-- solution; the third argument is the right-hand side @b@.
--
-- The result pairs a convergence flag with the final iterate. The flag is
-- False when the iteration limit was reached before the residual dropped
-- below the tolerance; callers (see 'step') are expected to inspect it.
--
-- This corresponds to the 'ss_cg' function in 'c/linalg.c' line 169.
--
solve :: Acc (Scalar R)
      -> Acc (Field R)
      -> Acc (Field R)
      -> Acc (Scalar Bool, Field R)
solve dt x0 b0 =
  let
      -- epsilon value used for matrix-vector approximation
      eps     = 1.0e-8
      eps_inv = 1.0 / eps

      -- maximum number of iterations for convergence
      maxiters :: Exp Int
      maxiters = 200

      -- convergence tolerance
      tolerance = 1.0e-3

      -- matrix-vector multiplication is approximated with:
      --
      -- A*v = 1/epsilon * ( F( x+epsilon*v ) - F(x) )
      --     = 1/epsilon * ( F( x+epsilon*v ) - Fx_old )
      --
      fx_old = diffusion dt x0 x0
      v0     = map ((1+eps) *) x0
      fx0    = diffusion dt x0 v0

      -- r = b - A*x, and <r,r> for the initial residual
      r0     = zipWith3 (\b l r -> b - eps_inv * (l-r)) b0 fx0 fx_old
      rnew0  = let r0_ = flatten r0 in dot r0_ r0_
      converged0 = map (\u -> sqrt u < tolerance) rnew0

      -- Loop-carried state:
      -- (converged?, iteration counter, p, r, x, <r,r> of the current r)
      start :: Acc (Scalar Bool, Scalar Int, Field R, Field R, Field R, Scalar R)
      start = lift (converged0, unit 1, r0, r0, x0, rnew0)

      -- Loop condition. The loop has to be expressed with `awhile`: the
      -- convergence test is an embedded term, not a Haskell Bool, so
      -- recursion at the Haskell level could never terminate and would build
      -- an unbounded program. The counter starts at 1, hence at most
      -- `maxiters` iterations are executed.
      cont :: Acc (Scalar Bool, Scalar Int, Field R, Field R, Field R, Scalar R)
           -> Acc (Scalar Bool)
      cont acc =
        zipWith (\converged i -> not converged && i <= maxiters) (acc ^. _1) (acc ^. _2)

      -- One conjugate gradient iteration (c/linalg.c:223).
      body :: Acc (Scalar Bool, Scalar Int, Field R, Field R, Field R, Scalar R)
           -> Acc (Scalar Bool, Scalar Int, Field R, Field R, Field R, Scalar R)
      body acc =
        let
            -- The flag is not needed inside the body, so the result of
            -- `unlift` requires an explicit type annotation.
            (_, i, p, r, x, rold) = unlift acc ::
              (Acc (Scalar Bool), Acc (Scalar Int), Acc (Field R), Acc (Field R), Acc (Field R), Acc (Scalar R))

            i' = map (+ 1) i

            -- Ap = A * p
            --
            -- The perturbed field is evaluated against the same time-k
            -- solution (x0) as `fx_old`, so that their difference
            -- approximates A*p only.
            v  = zipWith (\u w -> u + eps * w) x0 p
            fv = diffusion dt x0 v
            ap = zipWith (\u w -> eps_inv * (u - w)) fv fx_old

            -- alpha = rold / p * Ap
            alpha = the rold / the (dot (flatten p) (flatten ap))

            -- x += alpha * p
            x' = zipWith (\u w -> u + alpha * w) x p

            -- r -= alpha * Ap
            r' = zipWith (\u w -> u - alpha * w) r ap

            -- rnew = ss_dot(r, r, N);
            rnew = let rf = flatten r' in dot rf rf

            -- p = r + rnew / rold * p
            p' = zipWith (\u w -> u + the rnew / the rold * w) r' p

            converged' = map (\u -> sqrt u < tolerance) rnew
        in
        lift (converged', i', p', r', x', rnew)

      -- solver loop
      result = awhile cont body start
  in
  lift (result ^. _1, result ^. _5)
-- | Inner product of the vectors @xs@ and @ys@: the sum of their element-wise
-- products, returned as a scalar array.
--
dot :: Num e => Acc (Vector e) -> Acc (Vector e) -> Acc (Scalar e)
dot xs ys = fold (+) 0 products
  where
    products = zipWith (*) xs ys
-- | The L^2 (Euclidean) norm of a vector: the square root of the inner product
-- of the vector with itself.
--
norm2 :: Floating e => Acc (Vector e) -> Acc (Scalar e)
norm2 xs = map sqrt sumOfSquares
  where
    sumOfSquares = dot xs xs
-- | The problem is over a rectangular grid of (nx * ny) points. A finite volume
-- discretisation and method of lines gives the following ordinary differential
-- equation for each grid point:
--
-- \[
-- \frac{d s_{ij}}{d t} = \frac{D}{\Delta x^2} [ -4s_{ij} + s_{i-1,j} + s_{i+1,j} + s_{i,j-1} + s_{i,j+1} ] + R s_{ij}(1 - s_{ij})
-- \]
--
-- Which is expressed as the following nonlinear problem:
--
-- \[
-- f_{ij} = [ -(4 + \alpha)s_{ij} + s_{i-1,j} + s_{i+1,j} + s_{i,j-1} + s_{i,j+1} + \beta s_{ij}(1 - s_{ij}) ]^{k+1} + \alpha s_{ij}^k = 0
-- \]
--
-- where s^k is the field at the previous timestep, and s^(k+1) is the current
-- guess of the field at the new time step.
--
-- This corresponds to the 'diffusion' function in 'c/operators.c'. Note that
-- the C implementation takes the previous grid solution (called x0 here, and
-- X, a macro for x_old, there) from a global variable.
--
diffusion
    :: Acc (Scalar R)   -- ^ time step
    -> Acc (Field R)    -- ^ solution at time k
    -> Acc (Field R)    -- ^ current approximation of solution at time (k+1)
    -> Acc (Field R)    -- ^ new approximation
diffusion dtArr x0 x1 = stencil2 f dirichlet x0 dirichlet x1
  where
    -- The time step as a scalar expression.
    dt = the dtArr

    -- Stencil function applied at every grid point.
    --
    -- Accelerate describes stencil patterns with nested tuples, which for 1D
    -- and 2D stencils visually mirrors the elements of the array that are
    -- read. Any subset of the elements may be used: only the centre point of
    -- the time-k field is needed, together with the centre point and its
    -- four direct neighbours from the current approximation.
    --
    f :: Stencil3x3 R -> Stencil3x3 R -> Exp R
    f (_, (_, previous, _), _)
      ( (_   , above , _    )
      , (left, centre, right)
      , (_   , below , _    )
      ) = neighbourhood + alpha * previous + reaction
      where
        neighbourhood = -(4 + alpha) * centre + left + right + above + below
        reaction      = beta * centre * (1 - centre)

    -- The stencil operator needs to know what to do when a neighbouring
    -- point lies out-of-bounds.
    --
    -- Here every out-of-bounds point is simply treated as zero. The function
    -- receives the out-of-bounds index, so it could be examined to determine
    -- which boundary was crossed; in that way one could, for example, apply
    -- different conditions along different edges, or read the boundary
    -- values from a separate array.
    --
    -- Other inbuilt boundary conditions include 'wrap', 'mirror', and 'clamp'.
    --
    dirichlet :: Boundary (Field R)
    dirichlet = function (\_ -> 0)

    -- Simulation constants, derived from the extent of the grid.
    Z :. ny :. nx = unlift (shape x0) :: Z :. Exp Int :. Exp Int

    dx  = 1.0 / (fromIntegral (nx - 1))
    dy  = 1.0 / (fromIntegral (ny - 1))
    dxy = dx * dy

    alpha = dxy / (1.0 * dt)    -- diffusion coefficient D = 1.0

    -- NOTE (Roman): `beta` is `dxs` from the C code, but they're defined
    -- differently.
    beta = 1000.0 * dxy
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment