Skip to content

Commit

Permalink
Question marks in string literals cause substitution errors (#54)
Browse files Browse the repository at this point in the history
* Failing test

* Extract logic, more focused test

* Fixed

* more tests, fixes

* warns, imporst

* format

* ok
  • Loading branch information
parsonsmatt authored Mar 19, 2022
1 parent 41b947b commit 1c8b693
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 8 deletions.
41 changes: 36 additions & 5 deletions Database/MySQL/Simple.hs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ module Database.MySQL.Simple
-- * Helper functions
, formatMany
, formatQuery
, splitQuery
) where

import Blaze.ByteString.Builder (Builder, fromByteString, toByteString)
Expand All @@ -88,6 +89,7 @@ import Control.Applicative ((<$>), pure)
import Control.Exception (Exception, bracket, onException, throw, throwIO)
import Control.Monad.Fix (fix)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BS
import Data.Int (Int64)
import Data.List (intersperse)
import Data.Monoid (mappend, mconcat)
Expand Down Expand Up @@ -169,17 +171,46 @@ formatMany conn q@(Query template) qs = do
[caseless]

buildQuery :: Connection -> Query -> ByteString -> [Action] -> IO Builder
buildQuery conn q template xs = zipParams (split template) <$> mapM sub xs
buildQuery conn q template xs = zipParams queryFragments <$> mapM sub xs
where sub (Plain b) = pure b
sub (Escape s) = (inQuotes . fromByteString) <$> Base.escape conn s
sub (Many ys) = mconcat <$> mapM sub ys
split s = fromByteString h : if B.null t then [] else split (B.tail t)
where (h,t) = B.break (=='?') s
zipParams (t:ts) (p:ps) = t `mappend` p `mappend` zipParams ts ps
zipParams [t] [] = t
zipParams _ _ = fmtError (show (B.count '?' template) ++
zipParams _ _ = fmtError (show fragmentCount ++
" '?' characters, but " ++
show (length xs) ++ " parameters") q xs
fragmentCount = length queryFragments - 1
queryFragments = splitQuery template

-- | Split a query into fragments separated by @?@ characters. Does not
-- break a fragment if the question mark is in a string literal.
splitQuery :: ByteString -> [Builder]
splitQuery s =
reverse $ fmap (fromByteString . BS.pack . reverse) $
begin [] (BS.unpack s)
where
begin = normal []

normal ret acc [] =
acc : ret
normal ret acc (c : cs) =
case c of
'?' ->
normal (acc : ret) [] cs
'\'' ->
quotes ret (c : acc) cs
_ ->
normal ret (c : acc) cs

quotes ret acc [] =
acc : ret
quotes ret acc (c : cs) =
case c of
'\'' ->
normal ret (c : acc) cs
_ ->
quotes ret (c : acc) cs

-- | Execute an @INSERT@, @UPDATE@, or other SQL query that is not
-- expected to return results.
Expand Down Expand Up @@ -373,7 +404,7 @@ fmtError msg q xs = throw FormatError {
-- facility to address both ease of use and security.

-- $querytype
--
--
-- A 'Query' is a @newtype@-wrapped 'ByteString'. It intentionally
-- exposes a tiny API that is not compatible with the 'ByteString'
-- API; this makes it difficult to construct queries from fragments of
Expand Down
1 change: 1 addition & 0 deletions mysql-simple.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ test-suite test
ghc-options: -Wall
default-language: Haskell2010
build-depends: base >= 4 && < 5
, bytestring
, blaze-builder
, hspec
, mysql-simple
Expand Down
52 changes: 49 additions & 3 deletions test/main.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
{-# LANGUAGE CPP, OverloadedStrings #-}


{-# options_ghc -fno-warn-orphans #-}

import Data.ByteString.Builder as BS
import Control.Applicative ((<|>))
import Control.Exception (bracket)
import Data.Text (Text)
Expand Down Expand Up @@ -31,8 +35,8 @@ main = do
ci <- isCI
bracket (connect $ testConn ci) close $ \conn ->
hspec $ do
unitSpec
integrationSpec conn
describe "Database.MySQL.Simple.unitSpec" unitSpec
describe "Database.MySQL.Simple.integrationSpec" $ integrationSpec conn

unitSpec :: Spec
unitSpec = do
Expand All @@ -53,9 +57,51 @@ unitSpec = do
Many [Plain _, Escape "foo", Plain _, Escape "bar", Plain _] -> pure ()
_ -> expectationFailure "expected a Many with specific contents"

describe "splitQuery" $ do
it "works for a single question mark" $ do
splitQuery "select * from foo where name = ?"
`shouldBe`
["select * from foo where name = ", ""]
it "works with a question mark in a string literal" $ do
splitQuery "select 'hello?'"
`shouldBe`
["select 'hello?'"]
it "works with many question marks" $ do
splitQuery "select ? + ? + what from foo where bar = ?"
`shouldBe`
["select ", " + ", " + what from foo where bar = ", ""]

instance Show BS.Builder where
show = show . BS.toLazyByteString

instance Eq BS.Builder where
a == b = BS.toLazyByteString a == BS.toLazyByteString b

integrationSpec :: Connection -> Spec
integrationSpec conn = do
describe "the library" $ do
describe "query_" $ do
it "can connect to a database" $ do
result <- query_ conn "select 1 + 1"
result `shouldBe` [Only (2::Int)]
it "can have question marks in string literals" $ do
result <- query_ conn "select 'hello?'"
result `shouldBe` [Only ("hello?" :: Text)]
describe "query" $ do
it "can have question marks in string literals" $ do
result <- query conn "select 'hello?'" ()
result `shouldBe` [Only ("hello?" :: Text)]
describe "with too many query params" $ do
it "should have the right message" $ do
(query conn "select 'hello?'" (Only ['a']) :: IO [Only Text])
`shouldThrow`
(\e -> fmtMessage e == "0 '?' characters, but 1 parameters")
describe "with too few query params" $ do
it "should have the right message" $ do
(query conn "select 'hello?' = ?" () :: IO [Only Text])
`shouldThrow`
(\e -> fmtMessage e == "1 '?' characters, but 0 parameters")
describe "formatQuery" $ do
it "should not blow up on a question mark in string literal" $ do
formatQuery conn "select 'hello?'" ()
`shouldReturn`
"select 'hello?'"

0 comments on commit 1c8b693

Please sign in to comment.