Skip to content

Instantly share code, notes, and snippets.

@nurpax
Created August 16, 2012 06:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nurpax/3367303 to your computer and use it in GitHub Desktop.
Save nurpax/3367303 to your computer and use it in GitHub Desktop.
commit 9d18e815c739ff6065349d5f10adf458a844406b
Author: Janne Hellsten <jjhellst@gmail.com>
Date: Thu Aug 16 09:17:58 2012 +0300
Throw an error if # of bind params doesn't match the query (#2)
Fail with an error rather than let sqlite convert missing
parameters to NULLs.
diff --git a/Database/SQLite3.hsc b/Database/SQLite3.hsc
index 2026392..4220438 100644
--- a/Database/SQLite3.hsc
+++ b/Database/SQLite3.hsc
@@ -30,6 +30,7 @@ module Database.SQLite3 (
)
where
+import Control.Monad
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import qualified Data.Text as T
@@ -393,6 +394,10 @@ bindText statement parameterIndex text = do
bind :: Statement -> [SQLData] -> IO ()
bind statement sqlData = do
+ nParams <- bindParameterCount statement
+ when (nParams /= length sqlData) $
+ fail ("mismatched parameter count for bind. Prepared statement "++
+ "needs "++ show nParams ++ ", " ++ show (length sqlData) ++" given")
mapM (\(parameterIndex, datum) -> do
case datum of
SQLInteger int64 -> bindInt64 statement parameterIndex int64
diff --git a/test/Main.hs b/test/Main.hs
index 83384da..d956bb6 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,4 +1,5 @@
+import Prelude
import Control.Exception (bracket)
import Control.Monad (when)
import System.Exit (exitFailure)
@@ -18,10 +19,17 @@ data TestEnv =
+-- | The test cases in this suite; each one is built from the shared 'TestEnv'.
tests :: [TestEnv -> Test]
tests =
[ TestLabel "Simple" . testSimplest
+ , TestLabel "Params" . testBind
, TestLabel "Params" . testBindParamCounts
, TestLabel "Params" . testBindParamName
+ , TestLabel "Params" . testBindErrorValidation
]
+-- | Assert that running the given action throws an exception.
+--
+-- Used to check that 'bind' rejects a parameter list whose length does not
+-- match the prepared statement; 'bind' reports that mismatch via 'fail',
+-- which throws an 'IOError' in 'IO'.
+--
+-- Any exception handled by the 'catch' in scope makes the assertion pass,
+-- not only the parameter-count error, so keep the wrapped action minimal.
+--
+-- NOTE(review): the untyped @\\_ ->@ handler relies on the 'catch' in scope
+-- being monomorphic (presumably Prelude's IOError-only 'catch', given the
+-- explicit @import Prelude@).  With "Control.Exception".'catch' the handler
+-- would need an exception type annotation — confirm against the base version.
+assertBindErrorCaught :: IO a -> Assertion
+assertBindErrorCaught action =
+  catch (action >> return False) (\_ -> return True) >>=
+    assertBool "expected an exception, but the action completed normally"
+
-- Simplest SELECT
testSimplest :: TestEnv -> Test
testSimplest TestEnv{..} = TestCase $ do
@@ -32,6 +40,27 @@ testSimplest TestEnv{..} = TestCase $ do
finalize stmt
assertEqual "1+1" (SQLInteger 2) res
+-- | Bind concrete parameters and check that they reach the result row:
+-- a single bound value is echoed back unchanged, and two bound values
+-- are added together by the query.
+testBind :: TestEnv -> Test
+testBind TestEnv{..} = TestCase $ do
+  bracket (prepare conn "SELECT ?") finalize testBind1
+  bracket (prepare conn "SELECT ?+?") finalize testBind2
+  where
+    -- One placeholder: the bound value should come back as the only column.
+    testBind1 stmt = do
+      let params = [SQLInteger 3]
+      bind stmt params
+      Row <- step stmt
+      res <- columns stmt
+      -- The statement must be finished after one row; any other result
+      -- fails this pattern match and aborts the test.
+      Done <- step stmt
+      assertEqual "single param" params res
+
+    -- Two placeholders: binding 1 and 1 should yield a single column of 2.
+    testBind2 stmt = do
+      let params = [SQLInteger 1, SQLInteger 1]
+      bind stmt params
+      Row <- step stmt
+      res <- columns stmt
+      Done <- step stmt
+      assertEqual "two params" [SQLInteger 2] res
+
-- Test bindParameterCount
testBindParamCounts :: TestEnv -> Test
testBindParamCounts TestEnv{..} = TestCase $ do
@@ -58,6 +87,16 @@ testBindParamName TestEnv{..} = TestCase $ do
name <- bindParameterName stmt ndx
assertEqual "name match" expecting name) $ zip [1..] names
+-- | 'bind' must refuse a parameter list whose length differs from the number
+-- of placeholders in the prepared statement — both too few and too many.
+testBindErrorValidation :: TestEnv -> Test
+testBindErrorValidation TestEnv{..} = TestCase $
+  mapM_ expectBindFailure
+    [ []                           -- one placeholder in the query, none given
+    , [SQLInteger 1, SQLInteger 2] -- one placeholder in the query, two given
+    ]
+  where
+    -- Prepare a fresh one-placeholder statement for each case so a failed
+    -- bind cannot leak state into the next one.
+    expectBindFailure badParams =
+      bracket (prepare conn "SELECT ?") finalize $ \stmt ->
+        assertBindErrorCaught (bind stmt badParams)
+
-- | Action for connecting to the database that will be used for
-- testing.
--
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment