Skip to content

Commit 8566160

Browse files
Refactor Driver directory
1 parent 1b89885 commit 8566160

File tree

5 files changed

+71
-49
lines changed

5 files changed

+71
-49
lines changed

postgres-wire.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ library
5858
BangPatterns
5959
OverloadedStrings
6060
GeneralizedNewtypeDeriving
61+
LambdaCase
6162
cc-options: -O2 -Wall
6263

6364
test-suite postgres-wire-test-connection

src/Database/PostgreSQL/Driver/Connection.hs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ import Database.PostgreSQL.Driver.RawConnection
5656
-- | Public
5757
-- Connection parametrized by message type in chan.
5858
data AbsConnection mt = AbsConnection
59-
{ connRawConnection :: RawConnection
60-
, connReceiverThread :: Weak ThreadId
61-
, connStatementStorage :: StatementStorage
62-
, connParameters :: ConnectionParameters
63-
, connOutChan :: TQueue (Either ReceiverException mt)
59+
{ connRawConnection :: !RawConnection
60+
, connReceiverThread :: !(Weak ThreadId)
61+
, connStatementStorage :: !StatementStorage
62+
, connParameters :: !ConnectionParameters
63+
, connOutChan :: !(TQueue (Either ReceiverException mt))
6464
}
6565

6666
type Connection = AbsConnection DataMessage
@@ -122,15 +122,18 @@ connectCommon' settings msgFilter = connectWith settings $ \rawConn params ->
122122

123123
-- Low-level sending functions
124124

125+
{-# INLINE sendStartMessage #-}
125126
sendStartMessage :: RawConnection -> StartMessage -> IO ()
126127
sendStartMessage rawConn msg = void $
127128
rSend rawConn . runEncode $ encodeStartMessage msg
128129

129130
-- Only for testings and simple queries
131+
{-# INLINE sendMessage #-}
130132
sendMessage :: RawConnection -> ClientMessage -> IO ()
131133
sendMessage rawConn msg = void $
132134
rSend rawConn . runEncode $ encodeClientMessage msg
133135

136+
{-# INLINE sendEncode #-}
134137
sendEncode :: AbsConnection c -> Encode -> IO ()
135138
sendEncode conn = void . rSend (connRawConnection conn) . runEncode
136139

@@ -290,6 +293,11 @@ receiverThreadCommon rawConn chan msgFilter ntfHandler = go ""
290293
dispatchIfNotification (NotificationResponse ntf) handler = handler ntf
291294
dispatchIfNotification _ _ = pure ()
292295

296+
-- | Helper to read from queue.
297+
{-# INLINE writeChan #-}
298+
writeChan :: TQueue a -> a -> IO ()
299+
writeChan q = atomically . writeTQueue q
300+
293301
defaultNotificationHandler :: NotificationHandler
294302
defaultNotificationHandler = const $ pure ()
295303

@@ -332,7 +340,3 @@ defaultFilter msg = case msg of
332340
-- as result for `describe` message
333341
RowDescription{} -> True
334342

335-
-- | Helper to read from queue.
336-
writeChan :: TQueue a -> a -> IO ()
337-
writeChan q = atomically . writeTQueue q
338-

src/Database/PostgreSQL/Driver/Query.hs

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@ module Database.PostgreSQL.Driver.Query
1212
, collectUntilReadyForQuery
1313
) where
1414

15-
import Data.Foldable
16-
import Data.Monoid
17-
import Data.Bifunctor
18-
import qualified Data.Vector as V
19-
import qualified Data.ByteString as B
2015
import Control.Concurrent.STM.TQueue (TQueue, readTQueue )
21-
import Control.Concurrent.STM (atomically)
16+
import Control.Concurrent.STM (atomically)
17+
import Data.Foldable (fold)
18+
import Data.Monoid ((<>))
19+
import Data.ByteString (ByteString)
20+
import Data.Vector (Vector)
2221

2322
import Database.PostgreSQL.Protocol.Encoders
2423
import Database.PostgreSQL.Protocol.Store.Encode
@@ -31,26 +30,30 @@ import Database.PostgreSQL.Driver.StatementStorage
3130

3231
-- Public
3332
data Query = Query
34-
{ qStatement :: B.ByteString
35-
, qValues :: [(Oid, Maybe Encode)]
36-
, qParamsFormat :: Format
37-
, qResultFormat :: Format
38-
, qCachePolicy :: CachePolicy
33+
{ qStatement :: !ByteString
34+
, qValues :: ![(Oid, Maybe Encode)]
35+
, qParamsFormat :: !Format
36+
, qResultFormat :: !Format
37+
, qCachePolicy :: !CachePolicy
3938
} deriving (Show)
4039

4140
-- | Public
41+
{- INLINE sendBatchAndFlush #-}
4242
sendBatchAndFlush :: Connection -> [Query] -> IO ()
4343
sendBatchAndFlush = sendBatchEndBy Flush
4444

4545
-- | Public
46+
{-# INLINE sendBatchAndSync #-}
4647
sendBatchAndSync :: Connection -> [Query] -> IO ()
4748
sendBatchAndSync = sendBatchEndBy Sync
4849

4950
-- | Public
51+
{-# INLINE sendSync #-}
5052
sendSync :: Connection -> IO ()
5153
sendSync conn = sendEncode conn $ encodeClientMessage Sync
5254

5355
-- | Public
56+
{-# INLINABLE readNextData #-}
5457
readNextData :: Connection -> IO (Either Error DataRows)
5558
readNextData conn =
5659
readChan (connOutChan conn) >>=
@@ -62,6 +65,7 @@ readNextData conn =
6265
DataReady -> throwIncorrectUsage
6366
"Expected DataRow message, but got ReadyForQuery"
6467

68+
{-# INLINABLE waitReadyForQuery #-}
6569
waitReadyForQuery :: Connection -> IO (Either Error ())
6670
waitReadyForQuery conn =
6771
readChan (connOutChan conn) >>=
@@ -77,6 +81,7 @@ waitReadyForQuery conn =
7781
DataReady -> pure $ Right ()
7882

7983
-- Helper
84+
{-# INLINE sendBatchEndBy #-}
8085
sendBatchEndBy :: ClientMessage -> Connection -> [Query] -> IO ()
8186
sendBatchEndBy msg conn qs = do
8287
batch <- constructBatch conn qs
@@ -90,28 +95,27 @@ constructBatch conn = fmap fold . traverse constructSingle
9095
pname = PortalName ""
9196
constructSingle q = do
9297
let stmtSQL = StatementSQL $ qStatement q
93-
(sname, parseMessage) <- case qCachePolicy q of
94-
AlwaysCache -> do
95-
mName <- lookupStatement storage stmtSQL
96-
case mName of
97-
Nothing -> do
98-
newName <- storeStatement storage stmtSQL
99-
pure (newName, encodeClientMessage $
100-
Parse newName stmtSQL (fst <$> qValues q))
101-
Just name -> pure (name, mempty)
102-
NeverCache -> do
103-
let newName = defaultStatementName
104-
pure (newName, encodeClientMessage $
105-
Parse newName stmtSQL (fst <$> qValues q))
106-
let bindMessage = encodeClientMessage $
107-
Bind pname sname (qParamsFormat q) (snd <$> qValues q)
98+
(stmtName, needParse) <- case qCachePolicy q of
99+
AlwaysCache -> lookupStatement storage stmtSQL >>= \case
100+
Nothing -> do
101+
newName <- storeStatement storage stmtSQL
102+
pure (newName, True)
103+
Just name ->
104+
pure (name, False)
105+
NeverCache -> pure (defaultStatementName, True)
106+
let parseMessage = if needParse
107+
then encodeClientMessage $
108+
Parse stmtName stmtSQL (fst <$> qValues q)
109+
else mempty
110+
bindMessage = encodeClientMessage $
111+
Bind pname stmtName (qParamsFormat q) (snd <$> qValues q)
108112
(qResultFormat q)
109113
executeMessage = encodeClientMessage $
110114
Execute pname noLimitToReceive
111115
pure $ parseMessage <> bindMessage <> executeMessage
112116

113117
-- | Public
114-
sendSimpleQuery :: ConnectionCommon -> B.ByteString -> IO (Either Error ())
118+
sendSimpleQuery :: ConnectionCommon -> ByteString -> IO (Either Error ())
115119
sendSimpleQuery conn q = do
116120
sendMessage (connRawConnection conn) $ SimpleQuery (StatementSQL q)
117121
(checkErrors =<<) <$> collectUntilReadyForQuery conn
@@ -122,8 +126,8 @@ sendSimpleQuery conn q = do
122126
-- | Public
123127
describeStatement
124128
:: ConnectionCommon
125-
-> B.ByteString
126-
-> IO (Either Error (V.Vector Oid, V.Vector FieldDescription))
129+
-> ByteString
130+
-> IO (Either Error (Vector Oid, Vector FieldDescription))
127131
describeStatement conn stmt = do
128132
sendEncode conn $
129133
encodeClientMessage (Parse sname (StatementSQL stmt) [])
@@ -135,7 +139,7 @@ describeStatement conn stmt = do
135139
sname = StatementName ""
136140
parseMessages msgs = case msgs of
137141
[ParameterDescription params, NoData]
138-
-> pure $ Right (params, V.empty)
142+
-> pure $ Right (params, mempty)
139143
[ParameterDescription params, RowDescription fields]
140144
-> pure $ Right (params, fields)
141145
xs -> maybe
@@ -160,5 +164,6 @@ findFirstError [] = Nothing
160164
findFirstError (ErrorResponse desc : _) = Just desc
161165
findFirstError (_ : xs) = findFirstError xs
162166

167+
{-# INLINE readChan #-}
163168
readChan :: TQueue a -> IO a
164169
readChan = atomically . readTQueue

src/Database/PostgreSQL/Driver/StatementStorage.hs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1-
module Database.PostgreSQL.Driver.StatementStorage where
2-
3-
import qualified Data.HashTable.IO as H
4-
import qualified Data.ByteString as B
1+
module Database.PostgreSQL.Driver.StatementStorage
2+
( StatementStorage
3+
, CachePolicy(..)
4+
, newStatementStorage
5+
, lookupStatement
6+
, storeStatement
7+
, getCacheSize
8+
, defaultStatementName
9+
) where
10+
11+
import Data.Monoid ((<>))
12+
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
13+
import Data.Word (Word)
14+
15+
import Data.ByteString (ByteString)
516
import Data.ByteString.Char8 (pack)
6-
import Data.Word (Word)
7-
import Data.IORef
17+
import qualified Data.HashTable.IO as H
818

919
import Database.PostgreSQL.Protocol.Types
1020

@@ -21,16 +31,17 @@ data CachePolicy
2131
newStatementStorage :: IO StatementStorage
2232
newStatementStorage = StatementStorage <$> H.new <*> newIORef 0
2333

34+
{-# INLINE lookupStatement #-}
2435
lookupStatement :: StatementStorage -> StatementSQL -> IO (Maybe StatementName)
2536
lookupStatement (StatementStorage table _) = H.lookup table
2637

27-
-- TODO place right name
2838
-- TODO info about exceptions and mask
39+
{-# INLINE storeStatement #-}
2940
storeStatement :: StatementStorage -> StatementSQL -> IO StatementName
3041
storeStatement (StatementStorage table counter) stmt = do
3142
n <- readIORef counter
3243
writeIORef counter $ n + 1
33-
let name = StatementName . pack $ show n
44+
let name = StatementName . (statementPrefix <>) . pack $ show n
3445
H.insert table stmt name
3546
pure name
3647

@@ -40,3 +51,6 @@ getCacheSize (StatementStorage _ counter) = readIORef counter
4051
defaultStatementName :: StatementName
4152
defaultStatementName = StatementName ""
4253

54+
statementPrefix :: ByteString
55+
statementPrefix = "_pw_statement_"
56+

src/Database/PostgreSQL/Protocol/Codecs/Numeric.hs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
{-# language LambdaCase #-}
2-
31
module Database.PostgreSQL.Protocol.Codecs.Numeric
42
( scientificToNumeric
53
, numericToScientific

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy