Created
May 1, 2024 13:05
-
-
Save KaneTW/c223d4da042f5d66e41ee27a46bcefc5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# language ApplicativeDo #-} | |
{-# language LambdaCase #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE TypeInType #-} | |
{-# LANGUAGE AllowAmbiguousTypes #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE InstanceSigs #-} | |
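-- Extension list copied as-is from the original file. It may be incomplete,
-- and not every extension listed is necessarily required. The code below also
-- visibly relies on (at least) BlockArguments, DeriveAnyClass,
-- DerivingStrategies, DuplicateRecordFields, NamedFieldPuns,
-- OverloadedRecordDot, OverloadedStrings, RecordWildCards,
-- ScopedTypeVariables, StandaloneDeriving and TypeApplications (some of these
-- are already enabled under GHC2021).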
module VerifySchema where | |
import Prelude hiding ( filter ) | |
import qualified Data.List as L | |
import Rel8 hiding (run) | |
import qualified Rel8 | |
import Hasql.Connection | |
import Hasql.Session | |
import Data.Functor.Contravariant ( (>$<) ) | |
import Data.Int ( Int64 ) | |
import Data.Text ( Text ) | |
import qualified Data.Text as T | |
import GHC.Generics | |
import qualified Data.List.NonEmpty as NonEmpty | |
import qualified Data.Map as M | |
import Rel8.Schema.Null hiding (nullable) | |
import qualified Rel8.Schema.Null as Null | |
import Rel8.Schema.Name ( Name(Name) ) | |
import Rel8.Schema.Spec | |
import Data.Functor.Const | |
import Rel8.Schema.HTable | |
import Control.Monad | |
import Control.Monad.Accum | |
import Control.Monad.Reader | |
-- | The @relkind@ codes of @pg_catalog.pg_class@ modelled by this module.
-- Only ordinary tables (code @"r"@) are represented.
data Relkind = RTable
  deriving stock (Show)
  deriving anyclass (DBEq)

-- | Stored via the 'Text' representation: @"r"@ maps to 'RTable'. Any other
-- code fails to decode with an @"Unknown relkind: ..."@ message.
instance DBType Relkind where
  typeInformation = parseTypeInformation decodeRelkind encodeRelkind typeInformation
    where
      decodeRelkind :: Text -> Either String Relkind
      decodeRelkind code
        | code == "r" = Right RTable
        | otherwise = Left ("Unknown relkind: " ++ show code)

      encodeRelkind :: Relkind -> Text
      encodeRelkind RTable = "r"
-- | A PostgreSQL object identifier, wrapped around an 'Int64'. Its 'DBType',
-- 'DBEq' and 'Show' instances are reused from 'Int64'.
--
-- NOTE(review): PostgreSQL's @oid@ type is an unsigned 32-bit integer, while
-- the 'Int64' 'DBType' is presumably @int8@. Confirm that decoding and
-- comparisons against catalog @oid@ columns behave as intended.
newtype Oid = Oid Int64
  deriving newtype (DBType, DBEq, Show)
-- | The subset of @pg_catalog.pg_class@ columns this module reads.
-- The field order and labels determine the column mapping (see 'pgclass').
data PGClass f = PGClass
  { oid :: Column f Oid
    -- ^ row identifier; 'fetchTables' joins @pg_attribute.attrelid@ against it
  , relname :: Column f Text
    -- ^ relation name; 'fetchTables' orders by it and uses it as the table name
  , relkind :: Column f Relkind
    -- ^ relation kind; 'fetchTables' keeps only 'RTable'
  , relnamespace :: Column f Oid
  }
  deriving stock (Generic)
  deriving anyclass (Rel8able)

deriving stock instance Show (PGClass Result)

-- | Table schema for @pg_catalog.pg_class@. Each column is named after the
-- last component of its field label ('namesFromLabelsWith' 'NonEmpty.last').
pgclass :: TableSchema (PGClass Name)
pgclass = TableSchema
  { name = QualifiedName "pg_class" (Just "pg_catalog")
  , columns = namesFromLabelsWith NonEmpty.last
  }
-- | The subset of @pg_catalog.pg_attribute@ columns this module reads.
--
-- NOTE(review): 'attnum' is declared as 'Int64' here. Confirm it decodes
-- correctly against the catalog column's actual SQL type.
data PGAttribute f = PGAttribute
  { attrelid :: Column f Oid
    -- ^ owning relation; 'fetchTables' matches it against @pg_class.oid@
  , attname :: Column f Text
    -- ^ column name; used as the key in 'attrsToMap'
  , atttypid :: Column f Oid
    -- ^ column type; 'fetchTables' joins it against @pg_type.oid@
  , attnum :: Column f Int64
    -- ^ column number; 'fetchTables' keeps only rows where it is > 0
  , attnotnull :: Column f Bool
    -- ^ NOT NULL flag; consulted by 'checkNullity' and 'checkTypes'
  }
  deriving stock (Generic)
  deriving anyclass (Rel8able)

deriving stock instance Show (PGAttribute Result)

-- | Table schema for @pg_catalog.pg_attribute@. Each column is named after
-- the last component of its field label.
pgattribute :: TableSchema (PGAttribute Name)
pgattribute = TableSchema
  { name = QualifiedName "pg_attribute" (Just "pg_catalog")
  , columns = namesFromLabelsWith NonEmpty.last
  }
-- | The subset of @pg_catalog.pg_type@ columns this module reads.
data PGType f = PGType
  { oid :: Column f Oid
    -- ^ type identifier; joined against @atttypid@, @castsource@ and @casttarget@
  , typname :: Column f Text
    -- ^ type name; compared against the Haskell side's SQL type name
  }
  deriving stock (Generic)
  deriving anyclass (Rel8able)

deriving stock instance Show (PGType Result)

-- | Table schema for @pg_catalog.pg_type@. Each column is named after the
-- last component of its field label.
pgtype :: TableSchema (PGType Name)
pgtype = TableSchema
  { name = QualifiedName "pg_type" (Just "pg_catalog")
  , columns = namesFromLabelsWith NonEmpty.last
  }
-- | The subset of @pg_catalog.pg_namespace@ columns (object id and name).
data PGNamespace f = PGNamespace
  { oid :: Column f Oid
  , nspname :: Column f Text
  }
  deriving stock (Generic)
  deriving anyclass (Rel8able)

deriving stock instance Show (PGNamespace Result)

-- | Table schema for @pg_catalog.pg_namespace@. Each column is named after
-- the last component of its field label.
pgnamespace :: TableSchema (PGNamespace Name)
pgnamespace = TableSchema
  { name = QualifiedName "pg_namespace" (Just "pg_catalog")
  , columns = namesFromLabelsWith NonEmpty.last
  }
-- | The subset of @pg_catalog.pg_cast@ columns this module reads.
data PGCast f = PGCast
  { oid :: Column f Oid
  , castsource :: Column f Oid
    -- ^ source type; 'fetchCasts' resolves it to a @pg_type@ row
  , casttarget :: Column f Oid
    -- ^ target type; 'fetchCasts' resolves it to a @pg_type@ row
  , castfunc :: Column f Oid
  , castcontext :: Column f Char
    -- ^ cast context code; 'fetchCheckEnv' keeps only rows with @'i'@
  , castmethod :: Column f Char
  }
  deriving stock (Generic)
  deriving anyclass (Rel8able)

deriving stock instance Show (PGCast Result)

-- | Table schema for @pg_catalog.pg_cast@. Each column is named after the
-- last component of its field label.
pgcast :: TableSchema (PGCast Name)
pgcast = TableSchema
  { name = QualifiedName "pg_cast" (Just "pg_catalog")
  , columns = namesFromLabelsWith NonEmpty.last
  }
-- | A table as assembled by 'fetchTables': the table name plus the list of
-- its columns (each an 'Attribute').
data PGTable f = PGTable
  { name :: Column f Text
  , columns :: HList f (Attribute f)
  }
  deriving stock (Generic)
  deriving anyclass (Rel8able)

deriving stock instance Show (PGTable Result)
-- | One table column: its @pg_attribute@ row together with the @pg_type@ row
-- for its @atttypid@ (as joined in 'fetchTables').
data Attribute f = Attribute
  { attribute :: PGAttribute f
  , typ :: PGType f
  }
  deriving stock (Generic)
  deriving anyclass (Rel8able)

deriving stock instance Show (Attribute Result)
-- | A @pg_cast@ row with its source and target resolved to @pg_type@ rows,
-- plus the raw @castcontext@ code (as built by 'fetchCasts').
data Cast f = Cast
  { source :: PGType f
  , target :: PGType f
  , context :: Column f Char
  }
  deriving stock (Generic)
  deriving anyclass (Rel8able)

deriving stock instance Show (Cast Result)
-- | Load every relation with @relkind@ 'RTable' from @pg_catalog.pg_class@,
-- ordered by name, together with its columns.
--
-- Only attributes with @attnum > 0@ are included. Each one is inner-joined to
-- the @pg_type@ row matching its @atttypid@, so attributes without such a row
-- are dropped.
--
-- NOTE(review): relations are neither filtered nor tagged by namespace, so
-- same-named tables from different schemas are all returned.
fetchTables :: Connection -> IO (Either QueryError [PGTable Result])
fetchTables conn =
  flip run conn . statement () . Rel8.run . select $ do
    PGClass{ oid = tableOid, relname } <-
      orderBy (relname >$< asc) $
        each pgclass >>= filter (\cls -> lit RTable ==. relkind cls)
    attrs <- many (columnsOf tableOid)
    pure PGTable{ name = relname, columns = attrs }
  where
    -- Columns of the table with the given oid, each paired with its type row.
    columnsOf tableOid = do
      attribute@PGAttribute{ atttypid } <-
        each pgattribute
          >>= filter (\attr -> tableOid ==. attrelid attr)
          >>= filter (\attr -> attnum attr >. 0)
      typ <- each pgtype >>= filter (\PGType{ oid = typoid } -> atttypid ==. typoid)
      pure Attribute{ attribute, typ }
-- | Load every row of @pg_catalog.pg_cast@ with its source and target types
-- resolved to their @pg_type@ rows. Casts whose source or target has no
-- matching @pg_type@ row are dropped by the inner joins.
fetchCasts :: Connection -> IO (Either QueryError [Cast Result])
fetchCasts conn =
  flip run conn . statement () . Rel8.run . select $ do
    PGCast{ castsource, casttarget, castcontext } <- each pgcast
    source <- typeWithOid castsource
    target <- typeWithOid casttarget
    pure Cast{ source, target, context = castcontext }
  where
    -- The pg_type row(s) whose oid equals the given expression.
    typeWithOid wanted =
      each pgtype >>= filter (\PGType{ oid = typoid } -> typoid ==. wanted)
-- | Read-only environment shared by the schema checks.
data CheckEnv = CheckEnv
  { ctx :: [String] -- ^ context path (outermost first) attached to every warning/error; extended by 'withCtx'
  , schemaMap :: M.Map String [Attribute Result] -- ^ table name -> that table's columns (built from 'fetchTables')
  , casts :: [(String, String)] -- ^ (source, target) type-name pairs of implicit casts (castcontext 'i')
  } deriving (Show)
-- | A non-fatal finding, tagged with the context path active when it was
-- recorded (see 'addWarning').
data Warning = Warning
  { ctx :: [String]
  , warning :: String
  } deriving (Show)
-- | A fatal finding, tagged with the context path active when it was
-- recorded (see 'addError').
--
-- NOTE(review): the field name @error@ clashes with 'Prelude.error' (only
-- @filter@ is hidden from Prelude), so any unqualified use of @error@ in this
-- module would be ambiguous.
data Error = Error
  { ctx :: [String]
  , error :: String
  } deriving (Show)
-- | Accumulated output of the checks: all warnings and all errors, in the
-- order they were recorded.
data CheckResults = CheckResults
  { warnings :: [Warning]
  , errors :: [Error]
  } deriving (Show)
-- | Combine field-wise, keeping left-to-right order within both the warnings
-- and the errors. Neither argument is forced until a field is demanded.
instance Semigroup CheckResults where
  left <> right =
    CheckResults (warnings left <> warnings right) (errors left <> errors right)

-- | The empty result: no warnings and no errors.
instance Monoid CheckResults where
  mempty = CheckResults mempty mempty
-- | Record a warning, tagged with the current context path from 'CheckEnv'.
addWarning :: (MonadAccum CheckResults m, MonadReader CheckEnv m) => String -> m ()
addWarning warn = do
  path <- asks (\env -> env.ctx)
  add CheckResults { warnings = [Warning path warn], errors = [] }
-- | Record an error, tagged with the current context path from 'CheckEnv'.
addError :: (MonadAccum CheckResults m, MonadReader CheckEnv m) => String -> m ()
addError err = do
  path <- asks (\env -> env.ctx)
  add CheckResults { warnings = [], errors = [Error path err] }
-- | Run an action with @str@ appended to the end of the context path, so
-- warnings and errors recorded inside it carry that extra entry.
withCtx :: MonadReader CheckEnv m => String -> m a -> m a
withCtx str = local extend
  where
    extend :: CheckEnv -> CheckEnv
    extend (CheckEnv path tables implicitCasts) =
      CheckEnv (path <> [str]) tables implicitCasts
-- | 'True' when Rel8's 'Nullable' witness for the type @t@ is 'Null', and
-- 'False' when it is 'NotNull'. Use with a type application: @nulled \@t@.
nulled :: forall t. Nullable t => Bool
nulled = case Null.nullable @t of
  Null -> True
  NotNull -> False
-- | Collapse a Rel8 'Nullity' witness into a flag: 'True' for 'Null',
-- 'False' for 'NotNull'.
nullableToBool :: Nullity a -> Bool
nullableToBool nullity = case nullity of
  NotNull -> False
  Null -> True
-- | Index attributes by column name (@attname@). If two attributes share a
-- name, the one that appears first in the list wins.
attrsToMap :: [Attribute Result] -> M.Map String (Attribute Result)
attrsToMap = foldr insertAttr M.empty
  where
    -- Folding from the right means earlier elements are inserted last and
    -- therefore overwrite later duplicates (first occurrence wins).
    insertAttr :: Attribute Result -> M.Map String (Attribute Result) -> M.Map String (Attribute Result)
    insertAttr attr = M.insert (T.unpack attr.attribute.attname) attr
-- | Haskell-side description of one table column, produced by
-- 'schemaToTypeMap'.
data TypeInfo = TypeInfo
  { label :: String -- ^ a label taken from the column's Rel8 spec
  , isNull :: Bool -- ^ whether the Haskell column type is nullable
  , typeName :: QualifiedName -- ^ SQL type name from the column's Rel8 type information; its @name@ is compared with @pg_type.typname@
  } deriving (Show, Eq)
-- | Describe each column of a Rel8 table schema, keyed by its SQL column
-- name (the 'Name' stored in @cols@).
--
-- Each 'TypeInfo' records the column's first Rel8 label, whether the Haskell
-- column type is nullable, and the SQL type name taken from the column
-- spec's @info@. If two columns share a name, 'M.fromList' keeps the last.
--
-- NOTE(review): only @info.typeName.name@ is kept, so any array or modifier
-- details of Rel8's type name are discarded. Confirm this is intended when
-- comparing against @pg_type.typname@.
schemaToTypeMap :: forall k. Rel8able k => k Name -> M.Map String TypeInfo
schemaToTypeMap cols = M.fromList . uncurry zip . getConst $
  htabulateA @(Columns (k Name)) $ \field ->
    case (hfield hspecs field, hfield (toColumns cols) field) of
      (Spec {..}, Name name) ->
        let -- 'labels' is a plain list and may be empty; avoid the partial
            -- 'head' and fall back to the column name instead of crashing.
            firstLabel = case labels of
              l : _ -> l
              [] -> name
        in Const
             ( [name]
             , [ TypeInfo
                   { label = firstLabel
                   , isNull = nullableToBool nullity
                   , typeName = info.typeName.name
                   }
               ]
             )
-- | Compare a db column's type name (@pg_type.typname@) with the Haskell
-- column's SQL type name.
--
-- * Identical names: nothing is recorded.
-- * Different names: the implicit casts in 'CheckEnv' are consulted.
--   Casts in both directions are accepted silently. A cast in only one
--   direction records a warning naming the missing direction. No cast in
--   either direction records an error.
checkTypeEquality :: (MonadAccum CheckResults m, MonadReader CheckEnv m) => Attribute Result -> TypeInfo -> m ()
checkTypeEquality attr ty = do
  implicitCasts <- asks (\env -> env.casts)
  let dbType = T.unpack attr.typ.typname
      hsType = ty.typeName.name
      dbToHs = (dbType, hsType) `elem` implicitCasts
      hsToDb = (hsType, dbType) `elem` implicitCasts
  unless (dbType == hsType) $
    case (dbToHs, hsToDb) of
      (True, True) -> pure ()
      (True, False) -> addWarning $ "Can't convert from hs type " ++ hsType ++ " to db type " ++ dbType
      (False, True) -> addWarning $ "Can't convert from db type " ++ dbType ++ " to hs type " ++ hsType
      (False, False) -> addError $ "No conversions between db type " ++ dbType ++ " and hs type " ++ hsType
-- | Compare the db column's NOT NULL flag with the Haskell type's nullability.
--
-- * db NOT NULL but Haskell nullable: warning (safe, merely looser).
-- * db nullable but Haskell not nullable: error.
-- * Otherwise nothing is recorded.
checkNullity :: (MonadAccum CheckResults m, MonadReader CheckEnv m) => Attribute Result -> TypeInfo -> m ()
checkNullity attr ty =
  case (attr.attribute.attnotnull, ty.isNull) of
    (True, True) -> addWarning $ "db column " ++ column ++ " not null but hs type is nullable"
    (False, False) -> addError $ "db column " ++ column ++ " nullable but hs type not nullable"
    _ -> pure ()
  where
    column = T.unpack attr.attribute.attname
-- | Compare a table's db columns with its Haskell columns (both keyed by
-- column name) and record the findings:
--
-- * Every Haskell column must exist in the db. Columns that do exist get
--   'checkTypeEquality' and 'checkNullity', run under a context entry
--   named after the column.
-- * Every NOT NULL db column must appear in the Haskell type.
--
-- Haskell-side findings come first, then missing NOT NULL columns, each in
-- ascending key order.
checkTypes :: (MonadAccum CheckResults m, MonadReader CheckEnv m) => M.Map String (Attribute Result) -> M.Map String TypeInfo -> m ()
checkTypes attrMap typeMap = do
  mapM_ checkHsColumn (M.toList typeMap)
  mapM_ reportMissing (M.keys (requiredColumns `M.difference` typeMap))
  where
    requiredColumns = M.filter (\attr -> attr.attribute.attnotnull) attrMap

    checkHsColumn (key, ty) =
      case M.lookup key attrMap of
        Nothing -> addError $ "Entry " ++ key ++ " not present in db"
        Just attr -> withCtx key $ do
          checkTypeEquality attr ty
          checkNullity attr ty

    reportMissing key = addError $ "Entry " ++ key ++ " not null but not present in hs ty"
-- | Check a Rel8 table schema against the database tables in 'CheckEnv'.
--
-- The table is looked up by the unqualified name of the schema's
-- 'QualifiedName'. If it is missing, an error is recorded. Otherwise the
-- columns are compared under a context entry named after the table. The
-- schema is valid if:
--
-- 1. for every column present on both sides, the types match;
-- 2. every NOT NULL db column is present in the Haskell type;
-- 3. every column in the Haskell type exists in the db table;
-- 4. the nullity of each column matches.
verifySchema :: (Rel8able k, MonadAccum CheckResults m, MonadReader CheckEnv m) => TableSchema (k Name) -> m ()
verifySchema schema = do
  known <- asks (\env -> env.schemaMap)
  case M.lookup tableName known of
    Nothing -> addError $ "Table " ++ tableName ++ " not found"
    Just attrs ->
      withCtx tableName $
        checkTypes (attrsToMap attrs) (schemaToTypeMap schema.columns)
  where
    tableName = schema.name.name
-- | Build the initial 'CheckEnv' from the database: an empty context path,
-- the tables from 'fetchTables' keyed by name, and the implicit casts
-- (castcontext @'i'@) from 'fetchCasts' as (source, target) type-name pairs.
--
-- Fails in 'IO' (via 'fail' on the shown error) if either query fails. When
-- several tables share a name (they are not distinguished by namespace), the
-- first one returned is kept.
fetchCheckEnv :: Connection -> IO CheckEnv
fetchCheckEnv conn = do
  tables <- orFail =<< fetchTables conn
  castRows <- orFail =<< fetchCasts conn
  pure CheckEnv
    { ctx = []
      -- fromListWith passes (new, old); keeping the old value means the
      -- first occurrence of a table name wins.
    , schemaMap = M.fromListWith (\_later earlier -> earlier)
        [ (T.unpack tblName, tblCols)
        | PGTable{ name = tblName, columns = tblCols } <- tables
        ]
    , casts =
        [ (T.unpack src.typname, T.unpack tgt.typname)
        | Cast{ source = src, target = tgt, context = 'i' } <- castRows
        ]
    }
  where
    orFail :: Show e => Either e a -> IO a
    orFail = either (fail . show) pure
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment