diff --git a/Database/MySQL/Simple.hs b/Database/MySQL/Simple.hs index d2e7ad4..ceb3c30 100644 --- a/Database/MySQL/Simple.hs +++ b/Database/MySQL/Simple.hs @@ -80,6 +80,7 @@ module Database.MySQL.Simple -- * Helper functions , formatMany , formatQuery + , splitQuery ) where import Blaze.ByteString.Builder (Builder, fromByteString, toByteString) @@ -88,6 +89,7 @@ import Control.Applicative ((<$>), pure) import Control.Exception (Exception, bracket, onException, throw, throwIO) import Control.Monad.Fix (fix) import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BS import Data.Int (Int64) import Data.List (intersperse) import Data.Monoid (mappend, mconcat) @@ -169,17 +171,46 @@ formatMany conn q@(Query template) qs = do [caseless] buildQuery :: Connection -> Query -> ByteString -> [Action] -> IO Builder -buildQuery conn q template xs = zipParams (split template) <$> mapM sub xs +buildQuery conn q template xs = zipParams queryFragments <$> mapM sub xs where sub (Plain b) = pure b sub (Escape s) = (inQuotes . fromByteString) <$> Base.escape conn s sub (Many ys) = mconcat <$> mapM sub ys - split s = fromByteString h : if B.null t then [] else split (B.tail t) - where (h,t) = B.break (=='?') s zipParams (t:ts) (p:ps) = t `mappend` p `mappend` zipParams ts ps zipParams [t] [] = t - zipParams _ _ = fmtError (show (B.count '?' template) ++ + zipParams _ _ = fmtError (show fragmentCount ++ " '?' characters, but " ++ show (length xs) ++ " parameters") q xs + fragmentCount = length queryFragments - 1 + queryFragments = splitQuery template + +-- | Split a query into fragments separated by @?@ characters. Does not +-- break a fragment if the question mark is in a string literal. +splitQuery :: ByteString -> [Builder] +splitQuery s = + reverse $ fmap (fromByteString . BS.pack . reverse) $ + begin [] (BS.unpack s) + where + begin = normal [] + + normal ret acc [] = + acc : ret + normal ret acc (c : cs) = + case c of + '?' -> + normal (acc : ret) [] cs + '\'' -> + quotes ret (c : acc) cs + _ -> + normal ret (c : acc) cs + + quotes ret acc [] = + acc : ret + quotes ret acc (c : cs) = + case c of + '\'' -> + normal ret (c : acc) cs + _ -> + quotes ret (c : acc) cs -- | Execute an @INSERT@, @UPDATE@, or other SQL query that is not -- expected to return results. @@ -373,7 +404,7 @@ fmtError msg q xs = throw FormatError { -- facility to address both ease of use and security. -- $querytype --- +-- -- A 'Query' is a @newtype@-wrapped 'ByteString'. It intentionally -- exposes a tiny API that is not compatible with the 'ByteString' -- API; this makes it difficult to construct queries from fragments of diff --git a/mysql-simple.cabal b/mysql-simple.cabal index cbd99bf..40946b1 100644 --- a/mysql-simple.cabal +++ b/mysql-simple.cabal @@ -84,6 +84,7 @@ test-suite test ghc-options: -Wall default-language: Haskell2010 build-depends: base >= 4 && < 5 + , bytestring , blaze-builder , hspec , mysql-simple diff --git a/test/main.hs b/test/main.hs index c198ceb..2bfbfd2 100644 --- a/test/main.hs +++ b/test/main.hs @@ -1,5 +1,9 @@ {-# LANGUAGE CPP, OverloadedStrings #-} + +{-# options_ghc -fno-warn-orphans #-} + +import Data.ByteString.Builder as BS import Control.Applicative ((<|>)) import Control.Exception (bracket) import Data.Text (Text) @@ -31,8 +35,8 @@ main = do ci <- isCI bracket (connect $ testConn ci) close $ \conn -> hspec $ do - unitSpec - integrationSpec conn + describe "Database.MySQL.Simple.unitSpec" unitSpec + describe "Database.MySQL.Simple.integrationSpec" $ integrationSpec conn unitSpec :: Spec unitSpec = do @@ -53,9 +57,51 @@ unitSpec = do Many [Plain _, Escape "foo", Plain _, Escape "bar", Plain _] -> pure () _ -> expectationFailure "expected a Many with specific contents" + describe "splitQuery" $ do + it "works for a single question mark" $ do + splitQuery "select * from foo where name = ?" + `shouldBe` + ["select * from foo where name = ", ""] + it "works with a question mark in a string literal" $ do + splitQuery "select 'hello?'" + `shouldBe` + ["select 'hello?'"] + it "works with many question marks" $ do + splitQuery "select ? + ? + what from foo where bar = ?" + `shouldBe` + ["select ", " + ", " + what from foo where bar = ", ""] + +instance Show BS.Builder where + show = show . BS.toLazyByteString + +instance Eq BS.Builder where + a == b = BS.toLazyByteString a == BS.toLazyByteString b + integrationSpec :: Connection -> Spec integrationSpec conn = do - describe "the library" $ do + describe "query_" $ do it "can connect to a database" $ do result <- query_ conn "select 1 + 1" result `shouldBe` [Only (2::Int)] + it "can have question marks in string literals" $ do + result <- query_ conn "select 'hello?'" + result `shouldBe` [Only ("hello?" :: Text)] + describe "query" $ do + it "can have question marks in string literals" $ do + result <- query conn "select 'hello?'" () + result `shouldBe` [Only ("hello?" :: Text)] + describe "with too many query params" $ do + it "should have the right message" $ do + (query conn "select 'hello?'" (Only ['a']) :: IO [Only Text]) + `shouldThrow` + (\e -> fmtMessage e == "0 '?' characters, but 1 parameters") + describe "with too few query params" $ do + it "should have the right message" $ do + (query conn "select 'hello?' = ?" () :: IO [Only Text]) + `shouldThrow` + (\e -> fmtMessage e == "1 '?' characters, but 0 parameters") + describe "formatQuery" $ do + it "should not blow up on a question mark in string literal" $ do + formatQuery conn "select 'hello?'" () + `shouldReturn` + "select 'hello?'"