Skip to content

Instantly share code, notes, and snippets.

@MaxGabriel
Created November 22, 2020 20:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MaxGabriel/aae96f6f8a72d0cfb5e8f98f426e29a1 to your computer and use it in GitHub Desktop.
Save MaxGabriel/aae96f6f8a72d0cfb5e8f98f426e29a1 to your computer and use it in GitHub Desktop.
Template Haskell to load all Persistent models, stream them from the database, and validate they deserialize correctly
{-# LANGUAGE AllowAmbiguousTypes #-}
module Mercury.Database.Persist.DeriveLoadAllModels (mkLoadAllModels) where
import ClassyPrelude
import Control.Monad.Logger (MonadLogger, logInfoN)
import Data.Acquire (with)
import Data.Conduit (fuse, runConduit)
import qualified Data.Conduit.Combinators as Conduit
import qualified Data.Kind as K
import qualified Data.Text as T
import Database.Persist (DBName (..), PersistEntity, PersistEntityBackend, PersistQueryRead (..), PersistValue (..), selectSourceRes)
import Database.Persist.Sql (Filter (..), Single (..), SqlBackend, rawSql)
import Database.Persist.Types (EntityDef (..), HaskellName (..))
import Language.Haskell.TH.Syntax
import Mercury.Timing (timeAction)
-- | Entity attribute that opts a model out of the generated load-all function.
--
-- 'mkLoadAllModels' checks each entity's attributes for this exact text to decide whether the table is skipped.
skipLoadAllModelsAttribute :: Text
skipLoadAllModelsAttribute = "!skipLoadAllModels"
-- | Signature for the function generated by 'mkLoadAllModels'.
--
-- The generated function takes a callback that runs a @ReaderT SqlBackend m ()@ action in @m@, and itself returns @m ()@.
--
-- (Declared here as a type synonym, so it does not have to be constructed out of template-haskell primitives.)
type LoadAllModelsSignature = forall (m :: K.Type -> K.Type). (MonadUnliftIO m, MonadLogger m) => (ReaderT SqlBackend m () -> m ()) -> m ()
-- | Runs 'processTable' for @record@ through the supplied transaction runner.
--
-- The record type is chosen by the caller via a type application; the remaining
-- arguments are forwarded to 'processTable' unchanged.
--
-- (This could probably be merged into 'processTable', but keeping the two apart avoided several hours of compiler errors.)
processTableIO ::
  forall record (m :: K.Type -> K.Type).
  (MonadLogger m, MonadUnliftIO m, PersistEntity record, PersistEntityBackend record ~ SqlBackend) =>
  HaskellName ->
  DBName ->
  Bool ->
  (ReaderT SqlBackend m () -> m ()) ->
  m ()
processTableIO haskellName dbName shouldSkip runTrx =
  -- The runner already yields @m ()@, so its result is returned directly.
  runTrx (processTable @record haskellName dbName shouldSkip)
-- | In a transaction, do the following:
--
-- * Give an estimate of row count for a table
-- * Stream loading every row from that table, validating the model deserializes
-- * Report the results
--
-- Arguments, in order: the entity's Haskell name (used in the "Skipping" and
-- "Starting" log lines), its database table name (used for the row estimate and
-- the final "Loaded" log line), and a flag saying whether to skip the table.
-- When the flag is 'True' nothing is loaded; only a "Skipping" message is logged.
processTable ::
  forall record (m :: K.Type -> K.Type).
  (MonadUnliftIO m, PersistEntity record, PersistEntityBackend record ~ SqlBackend, MonadLogger m) =>
  HaskellName ->
  DBName ->
  Bool ->
  ReaderT SqlBackend m ()
processTable haskellName tableName skip = do
  let name = unHaskellName haskellName
      tableNameText = unDBName tableName
  if skip
    then -- Spaces around the attribute keep it readable in the log output.
      logInfoN ("Skipping " <> name <> "; it had the " <> skipLoadAllModelsAttribute <> " attribute.")
    else do
      estimate <- getEstimatedRowCount tableName
      logInfoN ("Starting to load rows for " <> name <> ". Estimated count: " <> tshow estimate)
      -- Potential improvement: Catch exceptions when loading, so if there are issues with multiple tables you find them all.
      (actualRowCount :: Int64, seconds) <- timeAction (streamRowCount @record)
      logInfoN $
        T.concat
          [ "Loaded ",
            tshow actualRowCount,
            " rows from ",
            tableNameText,
            " in ",
            tshow seconds, -- Future improvement: format to 3 decimal places or something
            " seconds."
          ]
-- | Fetches a cheap estimate of how many rows a table holds.
--
-- Reads @reltuples@ from @pg_class@ for the given table name instead of running COUNT(*),
-- which is dramatically faster. The number is only meant to give the user an idea of how
-- long loading will take.
--
-- Calls 'error' when the query does not return exactly one row holding an integer.
getEstimatedRowCount :: forall (m :: K.Type -> K.Type). (MonadIO m) => DBName -> ReaderT SqlBackend m Int64
getEstimatedRowCount tableName = do
  rows <- rawSql estimateQuery [PersistText (unDBName tableName)]
  case rows :: [Single PersistValue] of
    [Single (PersistInt64 estimatedRows)] -> pure estimatedRows
    unexpected -> error ("Expected a single row containing an integer; got: " <> show unexpected)
  where
    estimateQuery :: Text
    estimateQuery = "SELECT (reltuples :: bigint) FROM pg_class WHERE relname = ?"
-- | Streams every row of the @record@ table and returns how many rows were seen.
--
-- Rows are consumed one at a time and only counted, so this /should/ run in constant memory
-- on the Haskell side, and is probably gentler on the database too. In practice that is
-- unconfirmed (loading all our tables like this takes about a gigabyte, so maybe?).
streamRowCount :: forall record m. (MonadIO m, PersistEntity record, PersistEntityBackend record ~ SqlBackend) => ReaderT SqlBackend m Int64
streamRowCount = do
  acquireRows <- selectSourceRes ([] :: [Filter record]) []
  liftIO (with acquireRows countRows)
  where
    -- Drain the source, folding it down to a row count.
    countRows rows = runConduit (rows `fuse` Conduit.foldl tally 0)
    -- The row itself is ignored; only the running total matters.
    tally seen _row = seen + 1
-- | Creates a function of the given name, that loads every model from our database
--
-- The goal of the generated function is to check that our deserialization code is valid for every model.
-- You can ignore a given model by adding the !skipLoadAllModels attribute to the entity
--
-- The generated function has the type 'LoadAllModelsSignature': it takes a transaction runner
-- and calls 'processTableIO' once per entity, in the order the entities are given.
--
-- An alternative approach to flagging tables to skip, is to flag tables based on how slow they are to load.
-- So e.g. QueuedJob might load at "Glacial" speed, FrontEventWebhook might load at "VerySlow" speed,
-- and the user can choose what threshold to run this for.
mkLoadAllModels :: String -> [EntityDef] -> Q [Dec]
mkLoadAllModels fnName entityDefs = do
  let typ = ConT ''LoadAllModelsSignature
  runTrxName <- newName "runTransaction"
  let runTrxPat = VarP runTrxName
  body <- body' runTrxName
  return
    [ SigD (mkName fnName) typ,
      FunD (mkName fnName) [Clause [runTrxPat] (NormalB body) []]
    ]
  where
    -- Builds the function body: one 'processTableIO' call per entity, run in sequence.
    body' :: Name -> Q Exp
    body' runTrxName =
      case entityDefs of
        [] -> [|return ()|]
        _ -> do
          exps <- mapM (loadAllForTable runTrxName) entityDefs
          sequence_E <- [|sequence_|]
          pure $ sequence_E `AppE` ListE exps

    -- Generates the expression that loads all rows for a single entity's table.
    loadAllForTable :: Name -> EntityDef -> Q Exp
    loadAllForTable runTrxName entityDef = do
      let name = entityHaskell entityDef
          dbName = entityDB entityDef
          recordType = ConT $ mkName $ T.unpack $ unHaskellName name
          shouldSkip = skipLoadAllModelsAttribute `elem` entityAttrs entityDef
      -- I tried to do this all in one [| ... |] section
      -- But I got the error that Type (what recordType is) is not an instance of Lift
      -- Not sure if there's a workaround. I had it working before with selectList ([] :: [Filter $(return recordType)]) []
      fn <- [|processTableIO|]
      -- Lift only the projected names, never the whole EntityDef: quoting
      -- @entityHaskell entityDef@ would embed the entire definition in the
      -- generated code and select the field at runtime (EntityDef contains a lot of data).
      arg1 <- [|name|]
      arg2 <- [|dbName|]
      arg3 <- [|shouldSkip|]
      let arg4 = VarE runTrxName
      pure $ (fn `AppTypeE` recordType) `AppE` arg1 `AppE` arg2 `AppE` arg3 `AppE` arg4
-- | Functions to help time actions
module Mercury.Timing
( StartTime (..),
getStartTime,
getElapsedSeconds,
timeAction,
)
where
import ClassyPrelude.Yesod
import Data.Ratio ((%))
import System.Clock (Clock (..), TimeSpec, diffTimeSpec, getTime, toNanoSecs)
-- | Newtype wrapper designating a 'TimeSpec' as a starting time.
-- Obtain one with 'getStartTime' and pass it to 'getElapsedSeconds' to see how long an action took.
newtype StartTime = StartTime TimeSpec
-- | Reads the monotonic clock and tags the reading as a 'StartTime'.
getStartTime :: MonadIO m => m StartTime
getStartTime =
  -- TODO: would using CoarseMonotonic on linux be good for a speedup?
  StartTime <$> liftIO (getTime Monotonic)
-- | Returns how many seconds have elapsed since the given 'StartTime'.
--
-- The monotonic clock is read again and the nanosecond difference is converted to seconds.
getElapsedSeconds :: MonadIO m => StartTime -> m Double
getElapsedSeconds (StartTime start) = do
  -- TODO: would using CoarseMonotonic on linux be good for a speedup?
  now <- liftIO (getTime Monotonic)
  let elapsedNanos = toNanoSecs (now `diffTimeSpec` start)
  -- Copied from https://github.com/fimad/prometheus-haskell/blob/ec1e3d30bd59113b0184869fc12e7d6fb7251248/wai-middleware-prometheus/src/Network/Wai/Middleware/Prometheus.hs#L154
  pure (fromRational (elapsedNanos % nanosPerSecond))
  where
    nanosPerSecond :: Integer
    nanosPerSecond = 1000000000
-- | Runs an action and pairs its result with the time it took to complete, in seconds.
timeAction :: MonadIO m => m a -> m (a, Double)
timeAction action = do
  startedAt <- getStartTime
  outcome <- action
  (,) outcome <$> getElapsedSeconds startedAt
-- Example usage: hands 'mkLoadAllModels' to @share@ so that a function named
-- @loadAllPersistentModels@ is generated from the entity definitions produced by the
-- @persistManyFileWith@ splice.
-- NOTE(review): @allModelFiles@ is defined elsewhere; presumably the list of model definition files — confirm.
share
[mkLoadAllModels "loadAllPersistentModels"]
$( persistManyFileWith
lowerCaseSettings
allModelFiles
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment