Skip to content

Commit

Permalink
feat: add create_extension parameter to control vector extension creation (#1213)
Browse files Browse the repository at this point in the history

* Add the create_extension feature and update the documentation.

* small refinements

* update docker image

---------

Co-authored-by: anakin87 <[email protected]>
  • Loading branch information
wuqunfei and anakin87 authored Nov 22, 2024
1 parent de32fa3 commit f286bdf
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pgvector.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
python-version: ["3.9", "3.10", "3.11"]
services:
pgvector:
image: ankane/pgvector:latest
image: pgvector/pgvector:pg17
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
self,
*,
connection_string: Secret = Secret.from_env_var("PG_CONN_STR"),
create_extension: bool = True,
schema_name: str = "public",
table_name: str = "haystack_documents",
language: str = "english",
Expand All @@ -102,6 +103,10 @@ def __init__(
e.g.: `PG_CONN_STR="host=HOST port=PORT dbname=DBNAME user=USER password=PASSWORD"`
See [PostgreSQL Documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING)
for more details.
:param create_extension: Whether to create the pgvector extension if it doesn't exist.
Set this to `True` (default) to automatically create the extension if it is missing.
Creating the extension may require superuser privileges.
If set to `False`, ensure the extension is already installed; otherwise, an error will be raised.
:param schema_name: The name of the schema the table is created in. The schema must already exist.
:param table_name: The name of the table to use to store Haystack documents.
:param language: The language to be used to parse query and document content in keyword retrieval.
Expand Down Expand Up @@ -138,6 +143,7 @@ def __init__(
"""

self.connection_string = connection_string
self.create_extension = create_extension
self.table_name = table_name
self.schema_name = schema_name
self.embedding_dimension = embedding_dimension
Expand Down Expand Up @@ -194,7 +200,8 @@ def _create_connection(self):
conn_str = self.connection_string.resolve_value() or ""
connection = connect(conn_str)
connection.autocommit = True
connection.execute("CREATE EXTENSION IF NOT EXISTS vector")
if self.create_extension:
connection.execute("CREATE EXTENSION IF NOT EXISTS vector")
register_vector(connection) # Note: this must be called before creating the cursors.

self._connection = connection
Expand Down Expand Up @@ -246,6 +253,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
connection_string=self.connection_string.to_dict(),
create_extension=self.create_extension,
schema_name=self.schema_name,
table_name=self.table_name,
embedding_dimension=self.embedding_dimension,
Expand Down
4 changes: 4 additions & 0 deletions integrations/pgvector/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test_init(monkeypatch):
monkeypatch.setenv("PG_CONN_STR", "some_connection_string")

document_store = PgvectorDocumentStore(
create_extension=True,
schema_name="my_schema",
table_name="my_table",
embedding_dimension=512,
Expand All @@ -79,6 +80,7 @@ def test_init(monkeypatch):
keyword_index_name="my_keyword_index",
)

assert document_store.create_extension
assert document_store.schema_name == "my_schema"
assert document_store.table_name == "my_table"
assert document_store.embedding_dimension == 512
Expand All @@ -97,6 +99,7 @@ def test_to_dict(monkeypatch):
monkeypatch.setenv("PG_CONN_STR", "some_connection_string")

document_store = PgvectorDocumentStore(
create_extension=False,
table_name="my_table",
embedding_dimension=512,
vector_function="l2_distance",
Expand All @@ -113,6 +116,7 @@ def test_to_dict(monkeypatch):
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
"create_extension": False,
"table_name": "my_table",
"schema_name": "public",
"embedding_dimension": 512,
Expand Down
6 changes: 6 additions & 0 deletions integrations/pgvector/tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def test_to_dict(self, mock_store):
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
"create_extension": True,
"schema_name": "public",
"table_name": "haystack",
"embedding_dimension": 768,
Expand Down Expand Up @@ -82,6 +83,7 @@ def test_from_dict(self, monkeypatch):
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
"create_extension": False,
"table_name": "haystack_test_to_dict",
"embedding_dimension": 768,
"vector_function": "cosine_similarity",
Expand All @@ -106,6 +108,7 @@ def test_from_dict(self, monkeypatch):

assert isinstance(document_store, PgvectorDocumentStore)
assert isinstance(document_store.connection_string, EnvVarSecret)
assert not document_store.create_extension
assert document_store.table_name == "haystack_test_to_dict"
assert document_store.embedding_dimension == 768
assert document_store.vector_function == "cosine_similarity"
Expand Down Expand Up @@ -176,6 +179,7 @@ def test_to_dict(self, mock_store):
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
"create_extension": True,
"schema_name": "public",
"table_name": "haystack",
"embedding_dimension": 768,
Expand Down Expand Up @@ -207,6 +211,7 @@ def test_from_dict(self, monkeypatch):
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
"create_extension": False,
"table_name": "haystack_test_to_dict",
"embedding_dimension": 768,
"vector_function": "cosine_similarity",
Expand All @@ -230,6 +235,7 @@ def test_from_dict(self, monkeypatch):

assert isinstance(document_store, PgvectorDocumentStore)
assert isinstance(document_store.connection_string, EnvVarSecret)
assert not document_store.create_extension
assert document_store.table_name == "haystack_test_to_dict"
assert document_store.embedding_dimension == 768
assert document_store.vector_function == "cosine_similarity"
Expand Down

0 comments on commit f286bdf

Please sign in to comment.