Skip to content

Instantly share code, notes, and snippets.

@KaneTW
Created May 1, 2024 13:05
Show Gist options
  • Save KaneTW/c223d4da042f5d66e41ee27a46bcefc5 to your computer and use it in GitHub Desktop.
Save KaneTW/c223d4da042f5d66e41ee27a46bcefc5 to your computer and use it in GitHub Desktop.
{-# language ApplicativeDo #-}
{-# language LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE InstanceSigs #-}
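-- NOTE: this pragma list was pasted as-is from the original file. It may be
-- incomplete, and not every extension listed here is necessarily required.
-- The code below also uses syntax from OverloadedStrings, OverloadedRecordDot,
-- DuplicateRecordFields, RecordWildCards, NamedFieldPuns, BlockArguments,
-- DerivingStrategies, DeriveAnyClass, StandaloneDeriving and TypeApplications,
-- none of which appear above (unless they are enabled via the cabal file or
-- the default language edition).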
module VerifySchema where
import Prelude hiding ( filter )
import qualified Data.List as L
import Rel8 hiding (run)
import qualified Rel8
import Hasql.Connection
import Hasql.Session
import Data.Functor.Contravariant ( (>$<) )
import Data.Int ( Int64 )
import Data.Text ( Text )
import qualified Data.Text as T
import GHC.Generics
import qualified Data.List.NonEmpty as NonEmpty
import qualified Data.Map as M
import Rel8.Schema.Null hiding (nullable)
import qualified Rel8.Schema.Null as Null
import Rel8.Schema.Name ( Name(Name) )
import Rel8.Schema.Spec
import Data.Functor.Const
import Rel8.Schema.HTable
import Control.Monad
import Control.Monad.Accum
import Control.Monad.Reader
-- | Value of the @relkind@ column of @pg_class@. Only ordinary tables, stored
-- as @"r"@, are represented; decoding any other value fails.
data Relkind
  = RTable -- ^ stored as @"r"@
  deriving stock (Show)
  deriving anyclass (DBEq)

-- | Encoded through the 'Text' type information: 'RTable' is written as
-- @"r"@, and reading anything other than @"r"@ yields a decoding error.
instance DBType Relkind where
  typeInformation = parseTypeInformation fromDb toDb typeInformation
    where
      fromDb :: Text -> Either String Relkind
      fromDb raw
        | raw == "r" = Right RTable
        | otherwise = Left ("Unknown relkind: " ++ show raw)

      toDb :: Relkind -> Text
      toDb RTable = "r"
-- | Object identifier used by the catalog columns below (e.g. @pg_class.oid@,
-- @pg_attribute.attrelid@, @pg_attribute.atttypid@). It wraps an 'Int64', and
-- every instance is inherited from it, so 'show' prints just the number.
newtype Oid = Oid Int64
  deriving newtype DBType
  deriving newtype DBEq
  deriving newtype Show
-- | The subset of @pg_catalog.pg_class@ columns used by this module.
data PGClass f = PGClass
  { oid :: Column f Oid -- ^ identifier of the relation
  , relname :: Column f Text -- ^ relation name
  , relkind :: Column f Relkind -- ^ relation kind, see 'Relkind'
  , relnamespace :: Column f Oid -- ^ oid of the relation's namespace
  }
  deriving stock Generic
  deriving anyclass Rel8able

deriving stock instance Show (PGClass Result)

-- | Schema of @pg_catalog.pg_class@. Each column name is the last component
-- of the corresponding field label.
pgclass :: TableSchema (PGClass Name)
pgclass = TableSchema {name = catalogName, columns = namesFromLabelsWith NonEmpty.last}
  where
    catalogName = QualifiedName "pg_class" (Just "pg_catalog")
-- | The subset of @pg_catalog.pg_attribute@ columns used by this module.
data PGAttribute f = PGAttribute
  { attrelid :: Column f Oid -- ^ oid of the relation that owns the column
  , attname :: Column f Text -- ^ column name
  , atttypid :: Column f Oid -- ^ oid of the column's type (joined to 'PGType')
  , attnum :: Column f Int64 -- ^ column number; only positive ones are fetched
  , attnotnull :: Column f Bool -- ^ whether the column is declared NOT NULL
  }
  deriving stock Generic
  deriving anyclass Rel8able

deriving stock instance Show (PGAttribute Result)

-- | Schema of @pg_catalog.pg_attribute@. Each column name is the last
-- component of the corresponding field label.
pgattribute :: TableSchema (PGAttribute Name)
pgattribute = TableSchema {name = catalogName, columns = namesFromLabelsWith NonEmpty.last}
  where
    catalogName = QualifiedName "pg_attribute" (Just "pg_catalog")
-- | The subset of @pg_catalog.pg_type@ columns used by this module.
data PGType f = PGType
  { oid :: Column f Oid -- ^ identifier of the type
  , typname :: Column f Text -- ^ unqualified type name
  }
  deriving stock Generic
  deriving anyclass Rel8able

deriving stock instance Show (PGType Result)

-- | Schema of @pg_catalog.pg_type@. Each column name is the last component
-- of the corresponding field label.
pgtype :: TableSchema (PGType Name)
pgtype = TableSchema {name = catalogName, columns = namesFromLabelsWith NonEmpty.last}
  where
    catalogName = QualifiedName "pg_type" (Just "pg_catalog")
-- | The subset of @pg_catalog.pg_namespace@ columns used by this module.
data PGNamespace f = PGNamespace
  { oid :: Column f Oid -- ^ identifier of the namespace
  , nspname :: Column f Text -- ^ namespace name
  }
  deriving stock Generic
  deriving anyclass Rel8able

deriving stock instance Show (PGNamespace Result)

-- | Schema of @pg_catalog.pg_namespace@. Each column name is the last
-- component of the corresponding field label.
pgnamespace :: TableSchema (PGNamespace Name)
pgnamespace = TableSchema {name = catalogName, columns = namesFromLabelsWith NonEmpty.last}
  where
    catalogName = QualifiedName "pg_namespace" (Just "pg_catalog")
-- | The subset of @pg_catalog.pg_cast@ columns used by this module.
data PGCast f = PGCast
  { oid :: Column f Oid -- ^ identifier of the cast
  , castsource :: Column f Oid -- ^ oid of the source type
  , casttarget :: Column f Oid -- ^ oid of the target type
  , castfunc :: Column f Oid
  , castcontext :: Column f Char -- ^ cast context; @'i'@ is treated as implicit
  , castmethod :: Column f Char
  }
  deriving stock Generic
  deriving anyclass Rel8able

deriving stock instance Show (PGCast Result)

-- | Schema of @pg_catalog.pg_cast@. Each column name is the last component
-- of the corresponding field label.
pgcast :: TableSchema (PGCast Name)
pgcast = TableSchema {name = catalogName, columns = namesFromLabelsWith NonEmpty.last}
  where
    catalogName = QualifiedName "pg_cast" (Just "pg_catalog")
-- | A table as assembled by 'fetchTables': its name plus a list of its
-- columns, each paired with its resolved type.
data PGTable f = PGTable
  { name :: Column f Text -- ^ table name (@pg_class.relname@)
  , columns :: HList f (Attribute f) -- ^ the table's columns
  }
  deriving stock Generic
  deriving anyclass Rel8able

deriving stock instance Show (PGTable Result)
-- | A @pg_attribute@ row together with the @pg_type@ row of its type.
data Attribute f = Attribute
  { attribute :: PGAttribute f -- ^ the column itself
  , typ :: PGType f -- ^ the column's type
  }
  deriving stock Generic
  deriving anyclass Rel8able

deriving stock instance Show (Attribute Result)
-- | A @pg_cast@ row with its source and target types resolved to their
-- @pg_type@ rows.
data Cast f = Cast
  { source :: PGType f -- ^ type converted from
  , target :: PGType f -- ^ type converted to
  , context :: Column f Char -- ^ @castcontext@; only @'i'@ casts are used as implicit
  }
  deriving stock Generic
  deriving anyclass Rel8able

deriving stock instance Show (Cast Result)
-- | Fetch every @pg_class@ row whose kind is 'RTable', together with its
-- columns (the @pg_attribute@ rows with @attnum > 0@) and each column's
-- @pg_type@ row.
--
-- The @pg_class@ subquery is ordered by @relname@. No ordering is applied to
-- each table's aggregated column list.
--
-- A column whose @atttypid@ has no matching @pg_type@ row is dropped, because
-- the type lookup is a filter (inner join).
--
-- NOTE(review): tables are not filtered by @relnamespace@. Tables from every
-- namespace are returned, so two tables with the same name in different
-- schemas both appear in the result.
fetchTables :: Connection -> IO (Either QueryError [PGTable Result])
fetchTables c = do
  flip run c $ statement () $ Rel8.run $ select do
    -- Only relations whose relkind is 'RTable'.
    PGClass{ oid = tableOid, relname } <- orderBy (relname >$< asc) do
      each pgclass
        >>= filter ((lit RTable ==.) . relkind)
    -- Collect this table's columns (attnum > 0) into a list.
    columns <- many do
      attribute@PGAttribute{ atttypid } <-
        each pgattribute
          >>= filter ((tableOid ==.) . attrelid)
          >>= filter ((>. 0) . attnum)
      -- Resolve the column's type by oid.
      typ <-
        each pgtype
          >>= filter (\PGType{ oid = typoid } -> atttypid ==. typoid)
      return Attribute{ attribute, typ }
    return PGTable
      { name = relname
      , ..
      }
-- | Fetch every @pg_cast@ row with its source and target types resolved to
-- their @pg_type@ rows. A cast whose source or target oid has no matching
-- @pg_type@ row is dropped, because the type lookups are filters.
fetchCasts :: Connection -> IO (Either QueryError [Cast Result])
fetchCasts conn = run session conn
  where
    session = statement () (Rel8.run (select castQuery))

    castQuery = do
      PGCast {castsource, casttarget, castcontext} <- each pgcast
      src <- typeWithOid castsource
      tgt <- typeWithOid casttarget
      pure Cast {source = src, target = tgt, context = castcontext}

    -- All pg_type rows whose oid equals the given expression.
    typeWithOid wanted =
      each pgtype >>= filter (\PGType {oid = typoid} -> typoid ==. wanted)
-- | Read-only environment for the schema checks.
data CheckEnv = CheckEnv
  { ctx :: [String]
    -- ^ Context labels attached to every warning/error recorded.
  , schemaMap :: M.Map String [Attribute Result]
    -- ^ Columns of each database table, keyed by table name.
  , casts :: [(String, String)]
    -- ^ Implicit casts as @(source type name, target type name)@ pairs.
  }
  deriving (Show)
-- | A non-fatal issue found while checking a schema.
data Warning = Warning
  { ctx :: [String] -- ^ context labels in effect when it was recorded
  , warning :: String -- ^ human-readable description
  }
  deriving (Show)
-- | A fatal mismatch found while checking a schema.
--
-- NOTE(review): the field name @error@ clashes with 'Prelude.error'; any
-- unqualified use of @error@ in this module would be ambiguous.
data Error = Error
  { ctx :: [String] -- ^ context labels in effect when it was recorded
  , error :: String -- ^ human-readable description
  }
  deriving (Show)
-- | All issues collected while checking. Results combine by concatenating
-- the warning and error lists separately.
data CheckResults = CheckResults
  { warnings :: [Warning]
  , errors :: [Error]
  }
  deriving (Show)

instance Semigroup CheckResults where
  -- Lazy patterns keep (<>) non-strict in both arguments, matching
  -- field-wise concatenation through record projections.
  ~(CheckResults w1 e1) <> ~(CheckResults w2 e2) =
    CheckResults {warnings = w1 <> w2, errors = e1 <> e2}

instance Monoid CheckResults where
  mempty = CheckResults {warnings = [], errors = []}
-- | Record a warning tagged with the current context path.
addWarning :: (MonadAccum CheckResults m, MonadReader CheckEnv m) => String -> m ()
addWarning warn = do
  path <- asks (.ctx)
  add CheckResults {warnings = [Warning path warn], errors = []}
-- | Record an error tagged with the current context path.
addError :: (MonadAccum CheckResults m, MonadReader CheckEnv m) => String -> m ()
addError err = do
  path <- asks (.ctx)
  add CheckResults {warnings = [], errors = [Error path err]}
-- | Run an action with @str@ appended to the end of the context path, so
-- every warning/error it records carries that label.
withCtx :: MonadReader CheckEnv m => String -> m a -> m a
withCtx str = local pushLabel
  where
    -- Matching on the constructor avoids ambiguity with the other
    -- record types that also have a 'ctx' field.
    pushLabel :: CheckEnv -> CheckEnv
    pushLabel CheckEnv {ctx = path, ..} = CheckEnv {ctx = path <> [str], ..}
-- | 'True' when Rel8's nullity witness for @t@ is 'Null', 'False' when it is
-- 'NotNull'. The type is ambiguous, so call it with a type application.
nulled :: forall t. Nullable t => Bool
nulled =
  case Null.nullable @t of
    Null -> True
    NotNull -> False
-- | Collapse Rel8's nullity witness to a flag: 'True' for 'Null', 'False'
-- for 'NotNull'.
nullableToBool :: Nullity a -> Bool
nullableToBool = \case
  Null -> True
  NotNull -> False
-- | Index a table's columns by column name. If a name occurs more than once,
-- the first occurrence in the list is kept.
attrsToMap :: [Attribute Result] -> M.Map String (Attribute Result)
attrsToMap attrs =
  M.fromListWith keepFirst [(T.unpack a.attribute.attname, a) | a <- attrs]
  where
    -- fromListWith passes (new, old); keep the earlier entry.
    keepFirst _later earlier = earlier
-- | Haskell-side description of one column of a Rel8 table schema.
data TypeInfo = TypeInfo
  { label :: String -- ^ label identifying the column, from its Rel8 spec
  , isNull :: Bool -- ^ whether the Haskell column type is nullable
  , typeName :: QualifiedName -- ^ PostgreSQL type name the Rel8 type maps to
  }
  deriving (Show, Eq)
-- | Describe every column of a Rel8 table schema, keyed by the column name
-- the schema assigns to it.
--
-- Each 'TypeInfo' records:
--
-- * the first field label of the column, falling back to the column name when
--   the spec carries no labels (previously a partial 'head');
-- * whether the Haskell type is nullable;
-- * the PostgreSQL type name.
--
-- For array columns (@arrayDepth > 0@) the element type name is prefixed with
-- @_@. This is how PostgreSQL names array types in @pg_type.typname@ (e.g.
-- @_text@ for @text[]@). Without the prefix, array columns would never match
-- their catalog type in 'checkTypeEquality'.
--
-- If two columns share a name, the later one wins ('M.fromList').
schemaToTypeMap :: forall k. Rel8able k => k Name -> M.Map String TypeInfo
schemaToTypeMap cols = M.fromList . uncurry zip . getConst $
  htabulateA @(Columns (k Name)) $ \field ->
    case (hfield hspecs field, hfield (toColumns cols) field) of
      (Spec {..}, Name colName) ->
        let firstLabel = case labels of
              (l : _) -> l
              [] -> colName
            pgTypeName
              | info.typeName.arrayDepth > 0 = arrayOf info.typeName.name
              | otherwise = info.typeName.name
         in Const
              ( [colName]
              , [ TypeInfo
                    { label = firstLabel
                    , isNull = nullableToBool nullity
                    , typeName = pgTypeName
                    }
                ]
              )
  where
    -- Name of the PostgreSQL array type whose element type is the argument.
    arrayOf :: QualifiedName -> QualifiedName
    arrayOf (QualifiedName base nsp) = QualifiedName ('_' : base) nsp
-- | Check that a database column's type agrees with the Haskell-side type.
--
-- Identical type names are accepted. Otherwise the implicit casts in
-- 'CheckEnv' decide. Implicit casts are acceptable only when they exist in
-- both directions: a cast in just one direction produces a warning naming the
-- missing direction, and no cast at all produces an error.
checkTypeEquality :: (MonadAccum CheckResults m, MonadReader CheckEnv m) => Attribute Result -> TypeInfo -> m ()
checkTypeEquality attr ty = do
  implicitCasts <- asks (.casts)
  let dbType = T.unpack attr.typ.typname
      hsType = ty.typeName.name
      castExists from to = (from, to) `elem` implicitCasts
  unless (dbType == hsType) $
    case (castExists dbType hsType, castExists hsType dbType) of
      (True, True) -> pure ()
      (True, False) ->
        addWarning $ "Can't convert from hs type " ++ hsType ++ " to db type " ++ dbType
      (False, True) ->
        addWarning $ "Can't convert from db type " ++ dbType ++ " to hs type " ++ hsType
      (False, False) ->
        addError $ "No conversions between db type " ++ dbType ++ " and hs type " ++ hsType
-- | Compare a column's NOT NULL flag with the nullability of the Haskell type.
--
-- A NOT NULL column mapped to a nullable Haskell type is only a warning. A
-- nullable column mapped to a non-nullable Haskell type is an error. Any other
-- combination passes.
checkNullity :: (MonadAccum CheckResults m, MonadReader CheckEnv m) => Attribute Result -> TypeInfo -> m ()
checkNullity attr ty =
  case (attr.attribute.attnotnull, ty.isNull) of
    (True, True) ->
      addWarning $ "db column " ++ colName ++ " not null but hs type is nullable"
    (False, False) ->
      addError $ "db column " ++ colName ++ " nullable but hs type not nullable"
    _ -> pure ()
  where
    colName = T.unpack attr.attribute.attname
-- | Compare a table's database columns with the Haskell-side column types.
--
-- Pass 1: every Haskell column must exist in the database. Existing ones are
-- checked for type and nullity under their own context label.
-- Pass 2: every NOT NULL database column must be present in the Haskell type.
-- Both passes report in ascending column-name order.
checkTypes :: (MonadAccum CheckResults m, MonadReader CheckEnv m) => M.Map String (Attribute Result) -> M.Map String TypeInfo -> m ()
checkTypes attrMap typeMap = do
  forM_ (M.toList typeMap) $ \(key, ty) ->
    case M.lookup key attrMap of
      Nothing -> addError $ "Entry " ++ key ++ " not present in db"
      Just attr -> withCtx key (checkTypeEquality attr ty >> checkNullity attr ty)
  let notNullCols = M.filter (\attr -> attr.attribute.attnotnull) attrMap
      unmapped = M.keys (notNullCols `M.difference` typeMap)
  forM_ unmapped $ \key ->
    addError $ "Entry " ++ key ++ " not null but not present in hs ty"
-- | Check a Rel8 table schema against the catalog snapshot in 'CheckEnv'.
--
-- A schema is valid if:
--
-- 1. for every column present on both sides, the types match;
-- 2. every NOT NULL database column is present in the Haskell type;
-- 3. the Haskell type has no columns that are missing from the database;
-- 4. the nullity of every shared column matches.
--
-- The table is looked up by its unqualified name only. The schema part of
-- its 'QualifiedName' is not consulted.
verifySchema :: (Rel8able k, MonadAccum CheckResults m, MonadReader CheckEnv m) => TableSchema (k Name) -> m ()
verifySchema schema = do
  found <- asks (M.lookup tableName . (.schemaMap))
  case found of
    Nothing -> addError $ "Table " ++ tableName ++ " not found"
    Just attrs ->
      withCtx tableName $
        checkTypes (attrsToMap attrs) (schemaToTypeMap schema.columns)
  where
    tableName = schema.name.name
-- | Build a 'CheckEnv' from the live catalog.
--
-- The result has an empty context path. Its table map is keyed by table
-- name; when names repeat, the first table returned wins. Its cast list keeps
-- only casts whose context is @'i'@, as @(source, target)@ type-name pairs.
-- Tables are fetched before casts. A failed query raises an IO error (via
-- 'fail') carrying the shown query error.
fetchCheckEnv :: Connection -> IO CheckEnv
fetchCheckEnv conn = do
  tables <- orFail =<< fetchTables conn
  allCasts <- orFail =<< fetchCasts conn
  pure
    CheckEnv
      { ctx = []
      , schemaMap =
          M.fromListWith
            (\_later earlier -> earlier)
            [(T.unpack tblName, tblCols) | PGTable {name = tblName, columns = tblCols} <- tables]
      , casts =
          [ (T.unpack source.typname, T.unpack target.typname)
          | Cast {source, target, context} <- allCasts
          , context == 'i'
          ]
      }
  where
    orFail :: Show e => Either e a -> IO a
    orFail = either (fail . show) pure
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment