Skip to content

Commit

Permalink
Add MonadBaseControl and MonadBase to beam-sqlite and beam-postgres
Browse files Browse the repository at this point in the history
# Conflicts:
#	beam-sqlite/beam-sqlite.cabal
  • Loading branch information
ryantrinkle committed Feb 22, 2024
1 parent a9c906d commit 9b6bf4f
Show file tree
Hide file tree
Showing 5 changed files with 419 additions and 6 deletions.
394 changes: 394 additions & 0 deletions beam-postgres/Database/Beam/Postgres/#Connection.hs#
Original file line number Diff line number Diff line change
@@ -0,0 +1,394 @@
{-# OPTIONS_GHC -fno-warn-orphans -fno-warn-partial-type-signatures #-}

{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}

module Database.Beam.Postgres.Connection
( Pg(..), PgF(..)

, liftIOWithHandle

, runBeamPostgres, runBeamPostgresDebug

, pgRenderSyntax, runPgRowReader, getFields

, withPgDebug

, postgresUriSyntax ) where

import Control.Exception (SomeException(..), throwIO)
import Control.Monad.Base (MonadBase(..))
import Control.Monad.Free.Church
import Control.Monad.IO.Class
import Control.Monad.Trans.Control (MonadBaseControl(..))

import Database.Beam hiding (runDelete, runUpdate, runInsert, insert)
import Database.Beam.Backend.SQL.BeamExtensions
import Database.Beam.Backend.SQL.Row ( FromBackendRowF(..), FromBackendRowM(..)
, BeamRowReadError(..), ColumnParseError(..) )
import Database.Beam.Backend.URI
import Database.Beam.Schema.Tables

import Database.Beam.Postgres.Syntax
import Database.Beam.Postgres.Full
import Database.Beam.Postgres.Types

import qualified Database.PostgreSQL.LibPQ as Pg hiding
(Connection, escapeStringConn, escapeIdentifier, escapeByteaConn, exec)
import qualified Database.PostgreSQL.Simple as Pg
import qualified Database.PostgreSQL.Simple.FromField as Pg
import qualified Database.PostgreSQL.Simple.Internal as Pg
( Field(..), RowParser(..)
, escapeStringConn, escapeIdentifier, escapeByteaConn
, exec, throwResultError )
import qualified Database.PostgreSQL.Simple.Internal as PgI
import qualified Database.PostgreSQL.Simple.Ok as Pg
import qualified Database.PostgreSQL.Simple.Types as Pg (Query(..))

import Control.Monad.Reader
import Control.Monad.State
import qualified Control.Monad.Fail as Fail

import Data.ByteString (ByteString)
import Data.ByteString.Builder (toLazyByteString, byteString)
import qualified Data.ByteString.Lazy as BL
import Data.Maybe (listToMaybe, fromMaybe)
import Data.Proxy
import Data.String
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8)
import Data.Typeable (cast)
#if !MIN_VERSION_base(4, 11, 0)
import Data.Semigroup
#endif

import Foreign.C.Types

import Network.URI (uriToString)

data PgStream a = PgStreamDone (Either BeamRowReadError a)
| PgStreamContinue (Maybe PgI.Row -> IO (PgStream a))

-- | 'BeamURIOpeners' for the standard @postgresql:@ URI scheme. See the
-- postgres documentation for more details on the formatting. See documentation
-- for 'BeamURIOpeners' for more information on how to use this with beam
postgresUriSyntax :: c Postgres Pg.Connection Pg
-> BeamURIOpeners c
postgresUriSyntax =
mkUriOpener runBeamPostgres "postgresql:"
(\uri -> do
let pgConnStr = fromString (uriToString id uri "")
hdl <- Pg.connectPostgreSQL pgConnStr
pure (hdl, Pg.close hdl))

-- * Syntax rendering

pgRenderSyntax ::
Pg.Connection -> PgSyntax -> IO ByteString
pgRenderSyntax conn (PgSyntax mkQuery) =
renderBuilder <$> runF mkQuery finish step mempty
where
renderBuilder = BL.toStrict . toLazyByteString

step (EmitBuilder b next) a = next (a <> b)
step (EmitByteString b next) a = next (a <> byteString b)
step (EscapeString b next) a = do
res <- wrapError "EscapeString" (Pg.escapeStringConn conn b)
next (a <> byteString res)
step (EscapeBytea b next) a = do
res <- wrapError "EscapeBytea" (Pg.escapeByteaConn conn b)
next (a <> byteString res)
step (EscapeIdentifier b next) a = do
res <- wrapError "EscapeIdentifier" (Pg.escapeIdentifier conn b)
next (a <> byteString res)

finish _ = pure

wrapError step' go = do
res <- go
case res of
Right res' -> pure res'
Left res' -> fail (step' <> ": " <> show res')

-- * Run row readers

getFields :: Pg.Result -> IO [Pg.Field]
getFields res = do
Pg.Col colCount <- Pg.nfields res

let getField col =
Pg.Field res (Pg.Col col) <$> Pg.ftype res (Pg.Col col)

mapM getField [0..colCount - 1]

runPgRowReader ::
Pg.Connection -> Pg.Row -> Pg.Result -> [Pg.Field] -> FromBackendRowM Postgres a -> IO (Either BeamRowReadError a)
runPgRowReader conn rowIdx res fields (FromBackendRowM readRow) =
Pg.nfields res >>= \(Pg.Col colCount) ->
runF readRow finish step 0 colCount fields
where

step :: forall x. FromBackendRowF Postgres (CInt -> CInt -> [PgI.Field] -> IO (Either BeamRowReadError x))
-> CInt -> CInt -> [PgI.Field] -> IO (Either BeamRowReadError x)
step (ParseOneField _) curCol colCount [] = pure (Left (BeamRowReadError (Just (fromIntegral curCol)) (ColumnNotEnoughColumns (fromIntegral colCount))))
step (ParseOneField _) curCol colCount _
| curCol >= colCount = pure (Left (BeamRowReadError (Just (fromIntegral curCol)) (ColumnNotEnoughColumns (fromIntegral colCount))))
step (ParseOneField (next' :: next -> _)) curCol colCount (field:remainingFields) =
do fieldValue <- Pg.getvalue res rowIdx (Pg.Col curCol)
res' <- Pg.runConversion (Pg.fromField field fieldValue) conn
case res' of
Pg.Errors errs ->
let err = fromMaybe (ColumnErrorInternal "Column parse failed with unknown exception") $
listToMaybe $
do SomeException e <- errs
Just pgErr <- pure (cast e)
case pgErr of
Pg.ConversionFailed { Pg.errSQLType = sql
, Pg.errHaskellType = hs
, Pg.errMessage = msg } ->
pure (ColumnTypeMismatch hs sql msg)
Pg.Incompatible { Pg.errSQLType = sql
, Pg.errHaskellType = hs
, Pg.errMessage = msg } ->
pure (ColumnTypeMismatch hs sql msg)
Pg.UnexpectedNull {} ->
pure ColumnUnexpectedNull
in pure (Left (BeamRowReadError (Just (fromIntegral curCol)) err))
Pg.Ok x -> next' x (curCol + 1) colCount remainingFields

step (Alt (FromBackendRowM a) (FromBackendRowM b) next) curCol colCount cols =
do aRes <- runF a (\x curCol' colCount' cols' -> pure (Right (next x curCol' colCount' cols'))) step curCol colCount cols
case aRes of
Right next' -> next'
Left aErr -> do
bRes <- runF b (\x curCol' colCount' cols' -> pure (Right (next x curCol' colCount' cols'))) step curCol colCount cols
case bRes of
Right next' -> next'
Left {} -> pure (Left aErr)

step (FailParseWith err) _ _ _ =
pure (Left err)

finish x _ _ _ = pure (Right x)

withPgDebug :: (String -> IO ()) -> Pg.Connection -> Pg a -> IO (Either BeamRowReadError a)
withPgDebug dbg conn (Pg action) =
let finish x = pure (Right x)
step (PgLiftIO io next) = io >>= next
step (PgLiftWithHandle withConn next) = withConn dbg conn >>= next
step (PgFetchNext next) = next Nothing
step (PgRunReturning CursorBatching
(PgCommandSyntax PgCommandTypeQuery syntax)
(mkProcess :: Pg (Maybe x) -> Pg a')
next) =
do query <- pgRenderSyntax conn syntax
let Pg process = mkProcess (Pg (liftF (PgFetchNext id)))
dbg (T.unpack (decodeUtf8 query))
action' <- runF process finishProcess stepProcess Nothing
case action' of
PgStreamDone (Right x) -> Pg.execute_ conn (Pg.Query query) >> next x
PgStreamDone (Left err) -> pure (Left err)
PgStreamContinue nextStream ->
let finishUp (PgStreamDone (Right x)) = next x
finishUp (PgStreamDone (Left err)) = pure (Left err)
finishUp (PgStreamContinue next') = next' Nothing >>= finishUp

columnCount = fromIntegral $ valuesNeeded (Proxy @Postgres) (Proxy @x)
in Pg.foldWith_ (Pg.RP (put columnCount >> ask)) conn (Pg.Query query) (PgStreamContinue nextStream) runConsumer >>= finishUp
step (PgRunReturning AtOnce
(PgCommandSyntax PgCommandTypeQuery syntax)
(mkProcess :: Pg (Maybe x) -> Pg a')
next) =
renderExecReturningList "No tuples returned to Postgres query" syntax mkProcess next
step (PgRunReturning _ (PgCommandSyntax PgCommandTypeDataUpdateReturning syntax) mkProcess next) =
renderExecReturningList "No tuples returned to Postgres update/insert returning" syntax mkProcess next
step (PgRunReturning _ (PgCommandSyntax _ syntax) mkProcess next) =
do query <- pgRenderSyntax conn syntax
dbg (T.unpack (decodeUtf8 query))
_ <- Pg.execute_ conn (Pg.Query query)

let Pg process = mkProcess (Pg (liftF (PgFetchNext id)))
runF process next stepReturningNone

renderExecReturningList :: (FromBackendRow Postgres x) => _ -> PgSyntax -> (Pg (Maybe x) -> Pg a') -> _ -> _
renderExecReturningList errMsg syntax mkProcess next =
do query <- pgRenderSyntax conn syntax
dbg (T.unpack (decodeUtf8 query))

res <- Pg.exec conn query
sts <- Pg.resultStatus res
case sts of
Pg.TuplesOk -> do
let Pg process = mkProcess (Pg (liftF (PgFetchNext id)))
runF process (\x _ -> Pg.unsafeFreeResult res >> next x) (stepReturningList res) 0
_ -> Pg.throwResultError errMsg res sts

stepReturningNone :: forall a. PgF (IO (Either BeamRowReadError a)) -> IO (Either BeamRowReadError a)
stepReturningNone (PgLiftIO action' next) = action' >>= next
stepReturningNone (PgLiftWithHandle withConn next) = withConn dbg conn >>= next
stepReturningNone (PgFetchNext next) = next Nothing
stepReturningNone (PgRunReturning {}) = pure (Left (BeamRowReadError Nothing (ColumnErrorInternal "Nested queries not allowed")))

stepReturningList :: forall a. Pg.Result -> PgF (CInt -> IO (Either BeamRowReadError a)) -> CInt -> IO (Either BeamRowReadError a)
stepReturningList _ (PgLiftIO action' next) rowIdx = action' >>= \x -> next x rowIdx
stepReturningList res (PgFetchNext next) rowIdx =
do fields <- getFields res
Pg.Row rowCount <- Pg.ntuples res
if rowIdx >= rowCount
then next Nothing rowIdx
else runPgRowReader conn (Pg.Row rowIdx) res fields fromBackendRow >>= \case
Left err -> pure (Left err)
Right r -> next (Just r) (rowIdx + 1)
stepReturningList _ (PgRunReturning {}) _ = pure (Left (BeamRowReadError Nothing (ColumnErrorInternal "Nested queries not allowed")))
stepReturningList _ (PgLiftWithHandle {}) _ = pure (Left (BeamRowReadError Nothing (ColumnErrorInternal "Nested queries not allowed")))

finishProcess :: forall a. a -> Maybe PgI.Row -> IO (PgStream a)
finishProcess x _ = pure (PgStreamDone (Right x))

stepProcess :: forall a. PgF (Maybe PgI.Row -> IO (PgStream a)) -> Maybe PgI.Row -> IO (PgStream a)
stepProcess (PgLiftIO action' next) row = action' >>= flip next row
stepProcess (PgFetchNext next) Nothing =
pure . PgStreamContinue $ \res ->
case res of
Nothing -> next Nothing Nothing
Just (PgI.Row rowIdx res') ->
getFields res' >>= \fields ->
runPgRowReader conn rowIdx res' fields fromBackendRow >>= \case
Left err -> pure (PgStreamDone (Left err))
Right r -> next (Just r) Nothing
stepProcess (PgFetchNext next) (Just (PgI.Row rowIdx res)) =
getFields res >>= \fields ->
runPgRowReader conn rowIdx res fields fromBackendRow >>= \case
Left err -> pure (PgStreamDone (Left err))
Right r -> pure (PgStreamContinue (next (Just r)))
stepProcess (PgRunReturning {}) _ = pure (PgStreamDone (Left (BeamRowReadError Nothing (ColumnErrorInternal "Nested queries not allowed"))))
stepProcess (PgLiftWithHandle _ _) _ = pure (PgStreamDone (Left (BeamRowReadError Nothing (ColumnErrorInternal "Nested queries not allowed"))))

runConsumer :: forall a. PgStream a -> PgI.Row -> IO (PgStream a)
runConsumer s@(PgStreamDone {}) _ = pure s
runConsumer (PgStreamContinue next) row = next (Just row)
in runF action finish step

-- * Beam Monad class

data PgF next where
PgLiftIO :: IO a -> (a -> next) -> PgF next
PgRunReturning ::
FromBackendRow Postgres x =>
FetchMode -> PgCommandSyntax -> (Pg (Maybe x) -> Pg a) -> (a -> next) -> PgF next
PgFetchNext ::
FromBackendRow Postgres x =>
(Maybe x -> next) -> PgF next
PgLiftWithHandle :: ((String -> IO ()) -> Pg.Connection -> IO a) -> (a -> next) -> PgF next
deriving instance Functor PgF

-- | How to fetch results.
data FetchMode
= CursorBatching -- ^ Fetch in batches of ~256 rows via cursor for SELECT.
| AtOnce -- ^ Fetch all rows at once.

-- | 'MonadBeam' in which we can run Postgres commands. See the documentation
-- for 'MonadBeam' on examples of how to use.
--
-- @beam-postgres@ also provides functions that let you run queries without
-- 'MonadBeam'. These functions may be more efficient and offer a conduit
-- API. See "Database.Beam.Postgres.Conduit" for more information.
newtype Pg a = Pg { runPg :: F PgF a }
deriving (Monad, Applicative, Functor, MonadFree PgF)

instance Fail.MonadFail Pg where
fail e = liftIO (Fail.fail $ "Internal Error with: " <> show e)

instance MonadIO Pg where
liftIO x = liftF (PgLiftIO x id)

instance MonadBase IO Pg where
liftBase = liftIO

instance MonadBaseControl IO Pg where
type StM Pg a = a

liftBaseWith action =
liftF (PgLiftWithHandle (\dbg conn -> action (runBeamPostgresDebug dbg conn)) id)

restoreM = pure

liftIOWithHandle :: (Pg.Connection -> IO a) -> Pg a
liftIOWithHandle f = liftF (PgLiftWithHandle (\_ -> f) id)

runBeamPostgresDebug :: (String -> IO ()) -> Pg.Connection -> Pg a -> IO a
runBeamPostgresDebug dbg conn action =
withPgDebug dbg conn action >>= either throwIO pure

runBeamPostgres :: Pg.Connection -> Pg a -> IO a
runBeamPostgres = runBeamPostgresDebug (\_ -> pure ())

instance MonadBeam Postgres Pg where
runReturningMany cmd consume =
liftF (PgRunReturning CursorBatching cmd consume id)

runReturningOne cmd =
liftF (PgRunReturning AtOnce cmd consume id)
where
consume next = do
a <- next
case a of
Nothing -> pure Nothing
Just x -> do
a' <- next
case a' of
Nothing -> pure (Just x)
Just _ -> pure Nothing

runReturningList cmd =
liftF (PgRunReturning AtOnce cmd consume id)
where
consume next =
let collectM acc = do
a <- next
case a of
Nothing -> pure (acc [])
Just x -> collectM (acc . (x:))
in collectM id

instance MonadBeamInsertReturning Postgres Pg where
runInsertReturningList i = do
let insertReturningCmd' = i `returning`
changeBeamRep (\(Columnar' (QExpr s) :: Columnar' (QExpr Postgres PostgresInaccessible) ty) ->
Columnar' (QExpr s) :: Columnar' (QExpr Postgres ()) ty)

-- Make savepoint
case insertReturningCmd' of
PgInsertReturningEmpty ->
pure []
PgInsertReturning insertReturningCmd ->
runReturningList (PgCommandSyntax PgCommandTypeDataUpdateReturning insertReturningCmd)

instance MonadBeamUpdateReturning Postgres Pg where
runUpdateReturningList u = do
let updateReturningCmd' = u `returning`
changeBeamRep (\(Columnar' (QExpr s) :: Columnar' (QExpr Postgres PostgresInaccessible) ty) ->
Columnar' (QExpr s) :: Columnar' (QExpr Postgres ()) ty)

case updateReturningCmd' of
PgUpdateReturningEmpty ->
pure []
PgUpdateReturning updateReturningCmd ->
runReturningList (PgCommandSyntax PgCommandTypeDataUpdateReturning updateReturningCmd)

instance MonadBeamDeleteReturning Postgres Pg where
runDeleteReturningList d = do
let PgDeleteReturning deleteReturningCmd = d `returning`
changeBeamRep (\(Columnar' (QExpr s) :: Columnar' (QExpr Postgres PostgresInaccessible) ty) ->
Columnar' (QExpr s) :: Columnar' (QExpr Postgres ()) ty)

runReturningList (PgCommandSyntax PgCommandTypeDataUpdateReturning deleteReturningCmd)
Loading

0 comments on commit 9b6bf4f

Please sign in to comment.