diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 1b1333f5c..8e9c0f2fc 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) CREATE_TABLE_STATEMENT = """ -CREATE TABLE IF NOT EXISTS {table_name} ( +CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ( id VARCHAR(128) PRIMARY KEY, embedding VECTOR({embedding_dimension}), content TEXT, @@ -36,7 +36,7 @@ """ INSERT_STATEMENT = """ -INSERT INTO {table_name} +INSERT INTO {schema_name}.{table_name} (id, embedding, content, dataframe, blob_data, blob_meta, blob_mime_type, meta) VALUES (%(id)s, %(embedding)s, %(content)s, %(dataframe)s, %(blob_data)s, %(blob_meta)s, %(blob_mime_type)s, %(meta)s) """ @@ -54,7 +54,7 @@ KEYWORD_QUERY = """ SELECT {table_name}.*, ts_rank_cd(to_tsvector({language}, content), query) AS score -FROM {table_name}, plainto_tsquery({language}, %s) query +FROM {schema_name}.{table_name}, plainto_tsquery({language}, %s) query WHERE to_tsvector({language}, content) @@ query """ @@ -78,6 +78,7 @@ def __init__( self, *, connection_string: Secret = Secret.from_env_var("PG_CONN_STR"), + schema_name: str = "public", table_name: str = "haystack_documents", language: str = "english", embedding_dimension: int = 768, @@ -101,6 +102,7 @@ def __init__( e.g.: `PG_CONN_STR="host=HOST port=PORT dbname=DBNAME user=USER password=PASSWORD"` See [PostgreSQL Documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) for more details. + :param schema_name: The name of the schema the table is created in. The schema must already exist. :param table_name: The name of the table to use to store Haystack documents. :param language: The language to be used to parse query and document content in keyword retrieval. To see the list of available languages, you can run the following SQL query in your PostgreSQL database: @@ -137,6 +139,7 @@ def __init__( self.connection_string = connection_string self.table_name = table_name + self.schema_name = schema_name self.embedding_dimension = embedding_dimension if vector_function not in VALID_VECTOR_FUNCTIONS: msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" @@ -207,6 +210,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, connection_string=self.connection_string.to_dict(), + schema_name=self.schema_name, table_name=self.table_name, embedding_dimension=self.embedding_dimension, vector_function=self.vector_function, @@ -266,7 +270,9 @@ def _create_table_if_not_exists(self): """ create_sql = SQL(CREATE_TABLE_STATEMENT).format( - table_name=Identifier(self.table_name), embedding_dimension=SQLLiteral(self.embedding_dimension) + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + embedding_dimension=SQLLiteral(self.embedding_dimension), ) self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore") @@ -274,12 +280,18 @@ def _create_table_if_not_exists(self): def delete_table(self): """ Deletes the table used to store Haystack documents. - The name of the table (`table_name`) is defined when initializing the `PgvectorDocumentStore`. + The name of the schema (`schema_name`) and the name of the table (`table_name`) + are defined when initializing the `PgvectorDocumentStore`. """ + delete_sql = SQL("DROP TABLE IF EXISTS {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + ) - delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(self.table_name)) - - self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore") + self._execute_sql( + delete_sql, + error_msg=f"Could not delete table {self.schema_name}.{self.table_name} in PgvectorDocumentStore", + ) def _create_keyword_index_if_not_exists(self): """ @@ -287,15 +299,16 @@ def _create_keyword_index_if_not_exists(self): """ index_exists = bool( self._execute_sql( - "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, self.keyword_index_name), + "SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s", + (self.schema_name, self.table_name, self.keyword_index_name), "Could not check if keyword index exists", ).fetchone() ) sql_create_index = SQL( - "CREATE INDEX {index_name} ON {table_name} USING GIN (to_tsvector({language}, content))" + "CREATE INDEX {index_name} ON {schema_name}.{table_name} USING GIN (to_tsvector({language}, content))" ).format( + schema_name=Identifier(self.schema_name), index_name=Identifier(self.keyword_index_name), table_name=Identifier(self.table_name), language=SQLLiteral(self.language), @@ -318,8 +331,8 @@ def _handle_hnsw(self): index_exists = bool( self._execute_sql( - "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, self.hnsw_index_name), + "SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s", + (self.schema_name, self.table_name, self.hnsw_index_name), "Could not check if HNSW index exists", ).fetchone() ) @@ -349,8 +362,13 @@ def _create_hnsw_index(self): if key in HNSW_INDEX_CREATION_VALID_KWARGS } - sql_create_index = SQL("CREATE INDEX {index_name} ON {table_name} USING hnsw (embedding {ops}) ").format( - index_name=Identifier(self.hnsw_index_name), table_name=Identifier(self.table_name), ops=SQL(pg_ops) + sql_create_index = SQL( + "CREATE INDEX {index_name} ON {schema_name}.{table_name} USING hnsw (embedding {ops}) " + ).format( + schema_name=Identifier(self.schema_name), + index_name=Identifier(self.hnsw_index_name), + table_name=Identifier(self.table_name), + ops=SQL(pg_ops), ) if actual_hnsw_index_creation_kwargs: @@ -369,7 +387,9 @@ def count_documents(self) -> int: Returns how many documents are present in the document store. """ - sql_count = SQL("SELECT COUNT(*) FROM {table_name}").format(table_name=Identifier(self.table_name)) + sql_count = SQL("SELECT COUNT(*) FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) count = self._execute_sql(sql_count, error_msg="Could not count documents in PgvectorDocumentStore").fetchone()[ 0 @@ -395,7 +415,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." raise ValueError(msg) - sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + sql_filter = SQL("SELECT * FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) params = () if filters: @@ -434,7 +456,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D db_documents = self._from_haystack_to_pg_documents(documents) - sql_insert = SQL(INSERT_STATEMENT).format(table_name=Identifier(self.table_name)) + sql_insert = SQL(INSERT_STATEMENT).format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) if policy == DuplicatePolicy.OVERWRITE: sql_insert += SQL(UPDATE_STATEMENT) @@ -543,8 +567,10 @@ def delete_documents(self, document_ids: List[str]) -> None: document_ids_str = ", ".join(f"'{document_id}'" for document_id in document_ids) - delete_sql = SQL("DELETE FROM {table_name} WHERE id IN ({document_ids_str})").format( - table_name=Identifier(self.table_name), document_ids_str=SQL(document_ids_str) + delete_sql = SQL("DELETE FROM {schema_name}.{table_name} WHERE id IN ({document_ids_str})").format( + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + document_ids_str=SQL(document_ids_str), ) self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore") @@ -570,6 +596,7 @@ def _keyword_retrieval( raise ValueError(msg) sql_select = SQL(KEYWORD_QUERY).format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name), language=SQLLiteral(self.language), query=SQLLiteral(query), @@ -643,7 +670,8 @@ def _embedding_retrieval( elif vector_function == "l2_distance": score_definition = f"embedding <-> {query_embedding_for_postgres} AS score" - sql_select = SQL("SELECT *, {score} FROM {table_name}").format( + sql_select = SQL("SELECT *, {score} FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name), score=SQL(score_definition), ) diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index 93514b71c..4af4fc8de 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -47,6 +47,7 @@ def test_init(monkeypatch): monkeypatch.setenv("PG_CONN_STR", "some_connection_string") document_store = PgvectorDocumentStore( + schema_name="my_schema", table_name="my_table", embedding_dimension=512, vector_function="l2_distance", @@ -59,6 +60,7 @@ def test_init(monkeypatch): keyword_index_name="my_keyword_index", ) + assert document_store.schema_name == "my_schema" assert document_store.table_name == "my_table" assert document_store.embedding_dimension == 512 assert document_store.vector_function == "l2_distance" @@ -93,6 +95,7 @@ def test_to_dict(monkeypatch): "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, "table_name": "my_table", + "schema_name": "public", "embedding_dimension": 512, "vector_function": "l2_distance", "recreate_table": True, diff --git a/integrations/pgvector/tests/test_retrievers.py b/integrations/pgvector/tests/test_retrievers.py index 290891307..4125c3e3a 100644 --- a/integrations/pgvector/tests/test_retrievers.py +++ b/integrations/pgvector/tests/test_retrievers.py @@ -50,6 +50,7 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -175,6 +176,7 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity",