Skip to content

Commit

Permalink
Merge pull request #56 from dalaing/develop
Browse files Browse the repository at this point in the history
Add vector column type
  • Loading branch information
ali-abrar authored Jan 16, 2024
2 parents bfed6c7 + b80f3cb commit 63064b3
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Unreleased

* Add ltree column type
* Add vector column type

## 0.1.4.0

Expand Down
3 changes: 3 additions & 0 deletions src/Database/Beam/AutoMigrate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,9 @@ renderDataType = \case
PgSpecificType PgOid -> "oid"
-- ltree
PgSpecificType PgLTree -> "ltree"
-- vector
PgSpecificType (PgVector Nothing) -> "vector"
PgSpecificType (PgVector (Just n)) -> mconcat ["vector(", T.pack . show $ n, ")"]
-- Arrays
SqlArrayType (SqlArrayType _ _) _ -> error "beam-automigrate: invalid nested array."
SqlArrayType _ 0 -> error "beam-automigrate: array with zero dimensions"
Expand Down
53 changes: 39 additions & 14 deletions src/Database/Beam/AutoMigrate/Postgres.hs
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,39 @@ referenceActionsQ =
"WHERE sch_child.nspname = current_schema() ORDER BY c.conname "
]

-- | Return the names and OIDs of all user defined types in the public namespace
--
-- This lets us work with types that come from extensions, regardless of when the extension is added.
-- Without this, the OIDs of these types could shift underneath us.
extensionTypeNamesQ :: Pg.Query
extensionTypeNamesQ =
fromString $
unlines
[ "SELECT ty.oid, ty.typname ",
"FROM pg_type ty ",
"INNER JOIN pg_namespace ns ON ty.typnamespace = ns.oid ",
"WHERE ns.nspname = 'public' AND ty.typcategory = 'U' "
]

-- | Connects to a running PostgreSQL database and extract the relevant 'Schema' out of it.
getSchema :: Pg.Connection -> IO Schema
getSchema conn = do
allTableConstraints <- getAllConstraints conn
allDefaults <- getAllDefaults conn
extensionTypeData <- Pg.fold_ conn extensionTypeNamesQ mempty getExtension
enumerationData <- Pg.fold_ conn enumerationsQ mempty getEnumeration
sequences <- Pg.fold_ conn sequencesQ mempty getSequence
tables <-
Pg.fold_ conn userTablesQ mempty (getTable allDefaults enumerationData allTableConstraints)
Pg.fold_ conn userTablesQ mempty (getTable allDefaults extensionTypeData enumerationData allTableConstraints)
pure $ Schema tables (M.fromList $ M.elems enumerationData) sequences
where
getExtension ::
Map Pg.Oid ExtensionTypeName ->
(Pg.Oid, Text) ->
IO (Map Pg.Oid ExtensionTypeName)
getExtension allExtensions (oid, name) =
pure $ M.insert oid (ExtensionTypeName name) allExtensions

getEnumeration ::
Map Pg.Oid (EnumerationName, Enumeration) ->
(Text, Pg.Oid, V.Vector Text) ->
Expand All @@ -232,26 +254,28 @@ getSchema conn = do

getTable ::
AllDefaults ->
Map Pg.Oid ExtensionTypeName ->
Map Pg.Oid (EnumerationName, Enumeration) ->
AllTableConstraints ->
Tables ->
(Pg.Oid, Text) ->
IO Tables
getTable allDefaults enumData allTableConstraints allTables (oid, TableName -> tName) = do
getTable allDefaults extensionTypeData enumData allTableConstraints allTables (oid, TableName -> tName) = do
pgColumns <- Pg.query conn tableColumnsQ (Pg.Only oid)
newTable <-
Table (fromMaybe noTableConstraints (M.lookup tName allTableConstraints))
<$> foldlM (getColumns tName enumData allDefaults) mempty pgColumns
<$> foldlM (getColumns tName extensionTypeData enumData allDefaults) mempty pgColumns
pure $ M.insert tName newTable allTables

getColumns ::
TableName ->
Map Pg.Oid ExtensionTypeName ->
Map Pg.Oid (EnumerationName, Enumeration) ->
AllDefaults ->
Columns ->
(ByteString, Pg.Oid, Int, Int, Bool, ByteString) ->
IO Columns
getColumns tName enumData defaultData c (attname, atttypid, atttypmod, attndims, attnotnull, format_type) = do
getColumns tName extensionTypeData enumData defaultData c (attname, atttypid, atttypmod, attndims, attnotnull, format_type) = do
-- /NOTA BENE(adn)/: The atttypmod - 4 was originally taken from 'beam-migrate'
-- (see: https://github.com/tathougies/beam/blob/d87120b58373df53f075d92ce12037a98ca709ab/beam-postgres/Database/Beam/Postgres/Migrate.hs#L343)
-- but there are cases where this is not correct, for example in the case of bitstrings.
Expand All @@ -271,9 +295,9 @@ getSchema conn = do

case asum
[ pgSerialTyColumnType atttypid mbDefault,
pgTypeToColumnType atttypid mbPrecision,
pgTypeToColumnType extensionTypeData atttypid mbPrecision,
pgEnumTypeToColumnType enumData atttypid,
pgArrayTypeToColumnType atttypid mbPrecision attndims
pgArrayTypeToColumnType extensionTypeData atttypid mbPrecision attndims
] of
Just cType -> do
let nullConstraint = if attnotnull then S.fromList [NotNull] else mempty
Expand Down Expand Up @@ -310,8 +334,8 @@ pgSerialTyColumnType _ _ = Nothing

-- | Tries to convert from a Postgres' 'Oid' into 'ColumnType'.
-- Mostly taken from [beam-migrate](Database.Beam.Postgres.Migrate).
pgTypeToColumnType :: Pg.Oid -> Maybe Int -> Maybe ColumnType
pgTypeToColumnType oid width
pgTypeToColumnType :: Map Pg.Oid ExtensionTypeName -> Pg.Oid -> Maybe Int -> Maybe ColumnType
pgTypeToColumnType extensionTypeData oid width
| Pg.typoid Pg.int2 == oid =
Just (SqlStdType smallIntType)
| Pg.typoid Pg.int4 == oid =
Expand Down Expand Up @@ -376,14 +400,15 @@ pgTypeToColumnType oid width
Just (PgSpecificType PgUuid)
| Pg.typoid Pg.oid == oid =
Just (PgSpecificType PgOid)
| Pg.Oid 16385 == oid =
| M.lookup oid extensionTypeData == Just "ltree" =
Just (PgSpecificType PgLTree)
| otherwise =
Nothing
| M.lookup oid extensionTypeData == Just "vector" =
Just (PgSpecificType . PgVector $ (+ 4) . fromIntegral <$> width)
| otherwise = Nothing

pgArrayTypeToColumnType :: Pg.Oid -> Maybe Int -> Int -> Maybe ColumnType
pgArrayTypeToColumnType oid width dims = case Pg.staticTypeInfo oid of
Just (Pg.Array _ _ _ _ subTypeInfo) -> case pgTypeToColumnType (Pg.typoid subTypeInfo) width of
pgArrayTypeToColumnType :: Map Pg.Oid ExtensionTypeName -> Pg.Oid -> Maybe Int -> Int -> Maybe ColumnType
pgArrayTypeToColumnType extensionTypeData oid width dims = case Pg.staticTypeInfo oid of
Just (Pg.Array _ _ _ _ subTypeInfo) -> case pgTypeToColumnType extensionTypeData (Pg.typoid subTypeInfo) width of
Just columnType -> Just $ SqlArrayType columnType (fromIntegral dims)
_ -> Nothing
_ -> Nothing
Expand Down
10 changes: 10 additions & 0 deletions src/Database/Beam/AutoMigrate/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import qualified Database.Beam.Backend.SQL.AST as AST
import Database.Beam.Postgres (Pg, Postgres)
import qualified Database.Beam.Postgres.Syntax as Syntax
import GHC.Generics hiding (to)
import Numeric.Natural (Natural)
import Lens.Micro (Lens', lens, to, _Right)
import Lens.Micro.Extras (preview)

Expand Down Expand Up @@ -158,13 +159,22 @@ data PgDataType
| PgEnumeration EnumerationName
| PgOid
| PgLTree
| PgVector (Maybe Natural)

deriving instance Show PgDataType

deriving instance Eq PgDataType

deriving instance Generic PgDataType

newtype ExtensionTypeName = ExtensionTypeName
{ extensionTypeName :: Text
}
deriving (Show, Eq, Ord, NFData, Generic)

instance IsString ExtensionTypeName where
fromString = ExtensionTypeName . T.pack

-- Newtype wrapper to be able to derive appropriate 'HasDefaultSqlDataType' for /Postgres/ enum types.
newtype PgEnum a
= PgEnum a
Expand Down

0 comments on commit 63064b3

Please sign in to comment.