From b80f3cbf28abc2c71fe3bdbbb70a9368593dbe34 Mon Sep 17 00:00:00 2001
From: Dave Laing
Date: Mon, 13 Nov 2023 16:36:48 +1000
Subject: [PATCH] Add vector column type

---
 CHANGELOG.md                              |  1 +
 src/Database/Beam/AutoMigrate.hs          |  3 ++
 src/Database/Beam/AutoMigrate/Postgres.hs | 57 +++++++++++++++++------
 src/Database/Beam/AutoMigrate/Types.hs    | 12 +++++
 4 files changed, 59 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5d544ce..46d4640 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,7 @@
 ## Unreleased
 
 * Add ltree column type
+* Add vector column type
 
 ## 0.1.4.0
 
diff --git a/src/Database/Beam/AutoMigrate.hs b/src/Database/Beam/AutoMigrate.hs
index dbcf999..920087c 100644
--- a/src/Database/Beam/AutoMigrate.hs
+++ b/src/Database/Beam/AutoMigrate.hs
@@ -571,6 +571,9 @@ renderDataType = \case
   PgSpecificType PgOid -> "oid"
   -- ltree
   PgSpecificType PgLTree -> "ltree"
+  -- vector
+  PgSpecificType (PgVector Nothing) -> "vector"
+  PgSpecificType (PgVector (Just n)) -> mconcat ["vector(", T.pack . show $ n, ")"]
   -- Arrays
   SqlArrayType (SqlArrayType _ _) _ -> error "beam-automigrate: invalid nested array."
   SqlArrayType _ 0 -> error "beam-automigrate: array with zero dimensions"
diff --git a/src/Database/Beam/AutoMigrate/Postgres.hs b/src/Database/Beam/AutoMigrate/Postgres.hs
index ffe52f3..2653fc2 100644
--- a/src/Database/Beam/AutoMigrate/Postgres.hs
+++ b/src/Database/Beam/AutoMigrate/Postgres.hs
@@ -202,17 +202,40 @@ referenceActionsQ =
         "WHERE sch_child.nspname = current_schema() ORDER BY c.conname "
       ]
 
+-- | Return the names and OIDs of all user-defined types in the @public@ namespace.
+--
+-- This lets us work with types that come from extensions, regardless of when the extension is added.
+-- Without this, the OIDs of these types could shift underneath us.
+extensionTypeNamesQ :: Pg.Query
+extensionTypeNamesQ =
+  fromString $
+    unlines
+      [ "SELECT ty.oid, ty.typname ",
+        "FROM pg_type ty ",
+        "INNER JOIN pg_namespace ns ON ty.typnamespace = ns.oid ",
+        "WHERE ns.nspname = 'public' AND ty.typcategory = 'U' "
+      ]
+
 -- | Connects to a running PostgreSQL database and extract the relevant 'Schema' out of it.
 getSchema :: Pg.Connection -> IO Schema
 getSchema conn = do
   allTableConstraints <- getAllConstraints conn
   allDefaults <- getAllDefaults conn
+  extensionTypeData <- Pg.fold_ conn extensionTypeNamesQ mempty getExtension
   enumerationData <- Pg.fold_ conn enumerationsQ mempty getEnumeration
   sequences <- Pg.fold_ conn sequencesQ mempty getSequence
   tables <-
-    Pg.fold_ conn userTablesQ mempty (getTable allDefaults enumerationData allTableConstraints)
+    Pg.fold_ conn userTablesQ mempty (getTable allDefaults extensionTypeData enumerationData allTableConstraints)
   pure $ Schema tables (M.fromList $ M.elems enumerationData) sequences
   where
+    -- Fold step: records one @(oid, typname)@ row in the OID-keyed map of extension types.
+    getExtension ::
+      Map Pg.Oid ExtensionTypeName ->
+      (Pg.Oid, Text) ->
+      IO (Map Pg.Oid ExtensionTypeName)
+    getExtension allExtensions (oid, name) =
+      pure $ M.insert oid (ExtensionTypeName name) allExtensions
+
     getEnumeration ::
       Map Pg.Oid (EnumerationName, Enumeration) ->
       (Text, Pg.Oid, V.Vector Text) ->
@@ -232,26 +255,28 @@ getSchema conn = do
 
     getTable ::
       AllDefaults ->
+      Map Pg.Oid ExtensionTypeName ->
       Map Pg.Oid (EnumerationName, Enumeration) ->
       AllTableConstraints ->
       Tables ->
       (Pg.Oid, Text) ->
       IO Tables
-    getTable allDefaults enumData allTableConstraints allTables (oid, TableName -> tName) = do
+    getTable allDefaults extensionTypeData enumData allTableConstraints allTables (oid, TableName -> tName) = do
       pgColumns <- Pg.query conn tableColumnsQ (Pg.Only oid)
       newTable <-
         Table (fromMaybe noTableConstraints (M.lookup tName allTableConstraints))
-          <$> foldlM (getColumns tName enumData allDefaults) mempty pgColumns
+          <$> foldlM (getColumns tName extensionTypeData enumData allDefaults) mempty pgColumns
       pure $ M.insert tName newTable allTables
 
     getColumns ::
       TableName ->
+      Map Pg.Oid ExtensionTypeName ->
       Map Pg.Oid (EnumerationName, Enumeration) ->
       AllDefaults ->
       Columns ->
       (ByteString, Pg.Oid, Int, Int, Bool, ByteString) ->
       IO Columns
-    getColumns tName enumData defaultData c (attname, atttypid, atttypmod, attndims, attnotnull, format_type) = do
+    getColumns tName extensionTypeData enumData defaultData c (attname, atttypid, atttypmod, attndims, attnotnull, format_type) = do
       -- /NOTA BENE(adn)/: The atttypmod - 4 was originally taken from 'beam-migrate'
       -- (see: https://github.com/tathougies/beam/blob/d87120b58373df53f075d92ce12037a98ca709ab/beam-postgres/Database/Beam/Postgres/Migrate.hs#L343)
       -- but there are cases where this is not correct, for example in the case of bitstrings.
@@ -271,9 +296,9 @@ getSchema conn = do
 
       case asum
         [ pgSerialTyColumnType atttypid mbDefault,
-          pgTypeToColumnType atttypid mbPrecision,
+          pgTypeToColumnType extensionTypeData atttypid mbPrecision,
           pgEnumTypeToColumnType enumData atttypid,
-          pgArrayTypeToColumnType atttypid mbPrecision attndims
+          pgArrayTypeToColumnType extensionTypeData atttypid mbPrecision attndims
         ] of
         Just cType -> do
           let nullConstraint = if attnotnull then S.fromList [NotNull] else mempty
@@ -310,8 +335,8 @@ pgSerialTyColumnType _ _ = Nothing
 
 -- | Tries to convert from a Postgres' 'Oid' into 'ColumnType'.
 -- Mostly taken from [beam-migrate](Database.Beam.Postgres.Migrate).
-pgTypeToColumnType :: Pg.Oid -> Maybe Int -> Maybe ColumnType
-pgTypeToColumnType oid width
+pgTypeToColumnType :: Map Pg.Oid ExtensionTypeName -> Pg.Oid -> Maybe Int -> Maybe ColumnType
+pgTypeToColumnType extensionTypeData oid width
   | Pg.typoid Pg.int2 == oid =
     Just (SqlStdType smallIntType)
   | Pg.typoid Pg.int4 == oid =
@@ -376,14 +401,18 @@ pgTypeToColumnType oid width
     Just (PgSpecificType PgUuid)
   | Pg.typoid Pg.oid == oid =
     Just (PgSpecificType PgOid)
-  | Pg.Oid 16385 == oid =
+  | M.lookup oid extensionTypeData == Just "ltree" =
     Just (PgSpecificType PgLTree)
-  | otherwise =
-    Nothing
+  | M.lookup oid extensionTypeData == Just "vector" =
+    -- 'width' arrives as @atttypmod - 4@ (see the NOTA BENE in 'getColumns'), so it is
+    -- negative for vectors with fewer than 4 dimensions. Undo the offset /before/ the
+    -- conversion: 'Natural' cannot represent the negative intermediate and would throw.
+    Just (PgSpecificType . PgVector $ fromIntegral . (+ 4) <$> width)
+  | otherwise = Nothing
 
-pgArrayTypeToColumnType :: Pg.Oid -> Maybe Int -> Int -> Maybe ColumnType
-pgArrayTypeToColumnType oid width dims = case Pg.staticTypeInfo oid of
-  Just (Pg.Array _ _ _ _ subTypeInfo) -> case pgTypeToColumnType (Pg.typoid subTypeInfo) width of
+pgArrayTypeToColumnType :: Map Pg.Oid ExtensionTypeName -> Pg.Oid -> Maybe Int -> Int -> Maybe ColumnType
+pgArrayTypeToColumnType extensionTypeData oid width dims = case Pg.staticTypeInfo oid of
+  Just (Pg.Array _ _ _ _ subTypeInfo) -> case pgTypeToColumnType extensionTypeData (Pg.typoid subTypeInfo) width of
     Just columnType -> Just $ SqlArrayType columnType (fromIntegral dims)
     _ -> Nothing
   _ -> Nothing
diff --git a/src/Database/Beam/AutoMigrate/Types.hs b/src/Database/Beam/AutoMigrate/Types.hs
index 54dc170..a4ce267 100644
--- a/src/Database/Beam/AutoMigrate/Types.hs
+++ b/src/Database/Beam/AutoMigrate/Types.hs
@@ -23,6 +23,7 @@ import qualified Database.Beam.Backend.SQL.AST as AST
 import Database.Beam.Postgres (Pg, Postgres)
 import qualified Database.Beam.Postgres.Syntax as Syntax
 import GHC.Generics hiding (to)
+import Numeric.Natural (Natural)
 import Lens.Micro (Lens', lens, to, _Right)
 import Lens.Micro.Extras (preview)
 
@@ -158,6 +159,8 @@ data PgDataType
   | PgEnumeration EnumerationName
   | PgOid
   | PgLTree
+  | PgVector (Maybe Natural)
+  -- ^ The @vector@ extension type, rendered as @vector@ or, with a size, as @vector(n)@.
 
 deriving instance Show PgDataType
 
@@ -165,6 +168,15 @@ deriving instance Eq PgDataType
 
 deriving instance Generic PgDataType
 
+-- | The @pg_type.typname@ of a user-defined type, as discovered by querying the database.
+newtype ExtensionTypeName = ExtensionTypeName
+  { extensionTypeName :: Text
+  }
+  deriving (Show, Eq, Ord, NFData, Generic)
+
+instance IsString ExtensionTypeName where
+  fromString = ExtensionTypeName . T.pack
+
 -- Newtype wrapper to be able to derive appropriate 'HasDefaultSqlDataType' for /Postgres/ enum types.
 newtype PgEnum a
   = PgEnum a