Skip to content

Commit

Permalink
feat: Add schema support to pgvector document store. (#1095)
Browse files Browse the repository at this point in the history
* Add schema support for the pgvector document store.

Using the public schema of a PostgreSQL database is an anti-pattern. This change adds support for using a schema other than the public schema to create tables.

* Fix long lines.

* Fix long lines. Remove trailing spaces.

* Fix trailing spaces.

* Fix last trailing space.

* Fix ruff issues.

* Fix trailing space.

* small fixes

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
  • Loading branch information
rblst and anakin87 authored Nov 14, 2024
1 parent 1ef03c0 commit 3b33958
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
logger = logging.getLogger(__name__)

CREATE_TABLE_STATEMENT = """
CREATE TABLE IF NOT EXISTS {table_name} (
CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (
id VARCHAR(128) PRIMARY KEY,
embedding VECTOR({embedding_dimension}),
content TEXT,
Expand All @@ -36,7 +36,7 @@
"""

INSERT_STATEMENT = """
INSERT INTO {table_name}
INSERT INTO {schema_name}.{table_name}
(id, embedding, content, dataframe, blob_data, blob_meta, blob_mime_type, meta)
VALUES (%(id)s, %(embedding)s, %(content)s, %(dataframe)s, %(blob_data)s, %(blob_meta)s, %(blob_mime_type)s, %(meta)s)
"""
Expand All @@ -54,7 +54,7 @@

KEYWORD_QUERY = """
SELECT {table_name}.*, ts_rank_cd(to_tsvector({language}, content), query) AS score
FROM {table_name}, plainto_tsquery({language}, %s) query
FROM {schema_name}.{table_name}, plainto_tsquery({language}, %s) query
WHERE to_tsvector({language}, content) @@ query
"""

Expand All @@ -78,6 +78,7 @@ def __init__(
self,
*,
connection_string: Secret = Secret.from_env_var("PG_CONN_STR"),
schema_name: str = "public",
table_name: str = "haystack_documents",
language: str = "english",
embedding_dimension: int = 768,
Expand All @@ -101,6 +102,7 @@ def __init__(
e.g.: `PG_CONN_STR="host=HOST port=PORT dbname=DBNAME user=USER password=PASSWORD"`
See [PostgreSQL Documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING)
for more details.
:param schema_name: The name of the schema the table is created in. The schema must already exist.
:param table_name: The name of the table to use to store Haystack documents.
:param language: The language to be used to parse query and document content in keyword retrieval.
To see the list of available languages, you can run the following SQL query in your PostgreSQL database:
Expand Down Expand Up @@ -137,6 +139,7 @@ def __init__(

self.connection_string = connection_string
self.table_name = table_name
self.schema_name = schema_name
self.embedding_dimension = embedding_dimension
if vector_function not in VALID_VECTOR_FUNCTIONS:
msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}"
Expand Down Expand Up @@ -207,6 +210,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
connection_string=self.connection_string.to_dict(),
schema_name=self.schema_name,
table_name=self.table_name,
embedding_dimension=self.embedding_dimension,
vector_function=self.vector_function,
Expand Down Expand Up @@ -266,36 +270,45 @@ def _create_table_if_not_exists(self):
"""

create_sql = SQL(CREATE_TABLE_STATEMENT).format(
table_name=Identifier(self.table_name), embedding_dimension=SQLLiteral(self.embedding_dimension)
schema_name=Identifier(self.schema_name),
table_name=Identifier(self.table_name),
embedding_dimension=SQLLiteral(self.embedding_dimension),
)

self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore")

def delete_table(self):
"""
Deletes the table used to store Haystack documents.
The name of the table (`table_name`) is defined when initializing the `PgvectorDocumentStore`.
The name of the schema (`schema_name`) and the name of the table (`table_name`)
are defined when initializing the `PgvectorDocumentStore`.
"""
delete_sql = SQL("DROP TABLE IF EXISTS {schema_name}.{table_name}").format(
schema_name=Identifier(self.schema_name),
table_name=Identifier(self.table_name),
)

delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(self.table_name))

self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore")
self._execute_sql(
delete_sql,
error_msg=f"Could not delete table {self.schema_name}.{self.table_name} in PgvectorDocumentStore",
)

def _create_keyword_index_if_not_exists(self):
"""
Internal method to create the keyword index if not exists.
"""
index_exists = bool(
self._execute_sql(
"SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s",
(self.table_name, self.keyword_index_name),
"SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s",
(self.schema_name, self.table_name, self.keyword_index_name),
"Could not check if keyword index exists",
).fetchone()
)

sql_create_index = SQL(
"CREATE INDEX {index_name} ON {table_name} USING GIN (to_tsvector({language}, content))"
"CREATE INDEX {index_name} ON {schema_name}.{table_name} USING GIN (to_tsvector({language}, content))"
).format(
schema_name=Identifier(self.schema_name),
index_name=Identifier(self.keyword_index_name),
table_name=Identifier(self.table_name),
language=SQLLiteral(self.language),
Expand All @@ -318,8 +331,8 @@ def _handle_hnsw(self):

index_exists = bool(
self._execute_sql(
"SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s",
(self.table_name, self.hnsw_index_name),
"SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s",
(self.schema_name, self.table_name, self.hnsw_index_name),
"Could not check if HNSW index exists",
).fetchone()
)
Expand Down Expand Up @@ -349,8 +362,13 @@ def _create_hnsw_index(self):
if key in HNSW_INDEX_CREATION_VALID_KWARGS
}

sql_create_index = SQL("CREATE INDEX {index_name} ON {table_name} USING hnsw (embedding {ops}) ").format(
index_name=Identifier(self.hnsw_index_name), table_name=Identifier(self.table_name), ops=SQL(pg_ops)
sql_create_index = SQL(
"CREATE INDEX {index_name} ON {schema_name}.{table_name} USING hnsw (embedding {ops}) "
).format(
schema_name=Identifier(self.schema_name),
index_name=Identifier(self.hnsw_index_name),
table_name=Identifier(self.table_name),
ops=SQL(pg_ops),
)

if actual_hnsw_index_creation_kwargs:
Expand All @@ -369,7 +387,9 @@ def count_documents(self) -> int:
Returns how many documents are present in the document store.
"""

sql_count = SQL("SELECT COUNT(*) FROM {table_name}").format(table_name=Identifier(self.table_name))
sql_count = SQL("SELECT COUNT(*) FROM {schema_name}.{table_name}").format(
schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name)
)

count = self._execute_sql(sql_count, error_msg="Could not count documents in PgvectorDocumentStore").fetchone()[
0
Expand All @@ -395,7 +415,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
raise ValueError(msg)

sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name))
sql_filter = SQL("SELECT * FROM {schema_name}.{table_name}").format(
schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name)
)

params = ()
if filters:
Expand Down Expand Up @@ -434,7 +456,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

db_documents = self._from_haystack_to_pg_documents(documents)

sql_insert = SQL(INSERT_STATEMENT).format(table_name=Identifier(self.table_name))
sql_insert = SQL(INSERT_STATEMENT).format(
schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name)
)

if policy == DuplicatePolicy.OVERWRITE:
sql_insert += SQL(UPDATE_STATEMENT)
Expand Down Expand Up @@ -543,8 +567,10 @@ def delete_documents(self, document_ids: List[str]) -> None:

document_ids_str = ", ".join(f"'{document_id}'" for document_id in document_ids)

delete_sql = SQL("DELETE FROM {table_name} WHERE id IN ({document_ids_str})").format(
table_name=Identifier(self.table_name), document_ids_str=SQL(document_ids_str)
delete_sql = SQL("DELETE FROM {schema_name}.{table_name} WHERE id IN ({document_ids_str})").format(
schema_name=Identifier(self.schema_name),
table_name=Identifier(self.table_name),
document_ids_str=SQL(document_ids_str),
)

self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore")
Expand All @@ -570,6 +596,7 @@ def _keyword_retrieval(
raise ValueError(msg)

sql_select = SQL(KEYWORD_QUERY).format(
schema_name=Identifier(self.schema_name),
table_name=Identifier(self.table_name),
language=SQLLiteral(self.language),
query=SQLLiteral(query),
Expand Down Expand Up @@ -643,7 +670,8 @@ def _embedding_retrieval(
elif vector_function == "l2_distance":
score_definition = f"embedding <-> {query_embedding_for_postgres} AS score"

sql_select = SQL("SELECT *, {score} FROM {table_name}").format(
sql_select = SQL("SELECT *, {score} FROM {schema_name}.{table_name}").format(
schema_name=Identifier(self.schema_name),
table_name=Identifier(self.table_name),
score=SQL(score_definition),
)
Expand Down
3 changes: 3 additions & 0 deletions integrations/pgvector/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_init(monkeypatch):
monkeypatch.setenv("PG_CONN_STR", "some_connection_string")

document_store = PgvectorDocumentStore(
schema_name="my_schema",
table_name="my_table",
embedding_dimension=512,
vector_function="l2_distance",
Expand All @@ -59,6 +60,7 @@ def test_init(monkeypatch):
keyword_index_name="my_keyword_index",
)

assert document_store.schema_name == "my_schema"
assert document_store.table_name == "my_table"
assert document_store.embedding_dimension == 512
assert document_store.vector_function == "l2_distance"
Expand Down Expand Up @@ -93,6 +95,7 @@ def test_to_dict(monkeypatch):
"init_parameters": {
"connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
"table_name": "my_table",
"schema_name": "public",
"embedding_dimension": 512,
"vector_function": "l2_distance",
"recreate_table": True,
Expand Down
2 changes: 2 additions & 0 deletions integrations/pgvector/tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def test_to_dict(self, mock_store):
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
"schema_name": "public",
"table_name": "haystack",
"embedding_dimension": 768,
"vector_function": "cosine_similarity",
Expand Down Expand Up @@ -175,6 +176,7 @@ def test_to_dict(self, mock_store):
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
"schema_name": "public",
"table_name": "haystack",
"embedding_dimension": 768,
"vector_function": "cosine_similarity",
Expand Down

0 comments on commit 3b33958

Please sign in to comment.