
Commit

fix linting
masci committed Feb 1, 2024
1 parent a999960 commit 2acc0d6
Showing 7 changed files with 56 additions and 96 deletions.
2 changes: 1 addition & 1 deletion integrations/supabase/pyproject.toml
@@ -147,7 +147,7 @@ unfixable = [
known-first-party = ["supabase_haystack"]

[tool.ruff.flake8-tidy-imports]
ban-relative-imports = "all"
ban-relative-imports = "parents"

[tool.ruff.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
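
For context, a hedged sketch of what relaxing ban-relative-imports from "all" to "parents" means for Ruff's flake8-tidy-imports rule (TID252); the module names below are hypothetical:

# somewhere inside supabase_haystack (hypothetical module)
from .filters import _normalize_filters   # now allowed: same-package (sibling) relative import
from ..common import helper               # still banned: relative import reaching a parent package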
@@ -17,11 +17,8 @@ class SupabaseEmbeddingRetriever:
"""

def __init__(
self,
*,
document_store: SupabaseDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10):
self, *, document_store: SupabaseDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10
):
"""
Create a SupabaseEmbeddingRetriever component.
@@ -35,7 +32,6 @@ def __init__(
msg = "document_store must be an instance of SupabaseDocumentStore"
raise ValueError(msg)


self.document_store = document_store
self.filters = filters or {}
self.top_k = top_k
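
A hedged usage sketch of the reformatted constructor; the host and password are placeholders, and the retriever's import path is an assumption based on the package layout:

from haystack_integrations.document_stores.supabase import SupabaseDocumentStore
from haystack_integrations.components.retrievers.supabase import SupabaseEmbeddingRetriever  # assumed path

store = SupabaseDocumentStore(host="db.example.supabase.co", password="placeholder")
retriever = SupabaseEmbeddingRetriever(document_store=store, top_k=5)  # filters defaults to {}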
@@ -17,20 +17,21 @@

class SupabaseDocumentStore:
def __init__(
self,
host:str,
password:str,
user:str = "postgres",
port:str = "5432",
db_name:str = "postgres",
collection_name:str = "documents",
dimension:int = 768,
**collection_creation_kwargs,
self,
host: str,
password: str,
user: str = "postgres",
port: str = "5432",
db_name: str = "postgres",
collection_name: str = "documents",
dimension: int = 768,
**collection_creation_kwargs,
):
"""
Creates a new SupabaseDocumentStore instance.
For more information on connection parameters, see the official supabase vector documentation: https://supabase.github.io/vecs/0.4/
For more information on connection parameters, see the official supabase vector documentation:
https://supabase.github.io/vecs/0.4/
:param user: The username for connecting to the Supabase database.
:param password: The password for connecting to the Supabase database.
@@ -43,16 +44,17 @@ def __init__(
"""
self.dimension = dimension
self._collection_name = collection_name
self._dummy_vector = [0.0]*dimension
self._dummy_vector = [0.0] * dimension
self.collection_creation_kwargs = collection_creation_kwargs
db_connection = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
self._pgvector_client = vecs.create_client(db_connection)
self._collection = self._pgvector_client.get_or_create_collection(name=collection_name, dimension=dimension, **collection_creation_kwargs)
self._collection = self._pgvector_client.get_or_create_collection(
name=collection_name, dimension=dimension, **collection_creation_kwargs
)
self._adapter = None
if collection_creation_kwargs.get("adapter") is not None:
self._adapter = collection_creation_kwargs["adapter"]


def count_documents(self) -> int:
"""
Returns how many documents are present in the document store.
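
The constructor in the hunk above reduces to a small vecs workflow; a minimal sketch under the same defaults, with placeholder credentials:

import vecs

# equivalent of the DSN the constructor assembles from its arguments
dsn = "postgresql://postgres:placeholder@db.example.supabase.co:5432/postgres"
client = vecs.create_client(dsn)
collection = client.get_or_create_collection(name="documents", dimension=768)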
@@ -125,11 +127,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
"""
# pgvector store performs vector similarity search
# here we are querying with a dummy vector and the max compatible top_k
documents = self._embedding_retrieval(
query_embedding=self._dummy_vector,
filters=filters,
top_k=10
)
documents = self._embedding_retrieval(query_embedding=self._dummy_vector, filters=filters, top_k=10)

return documents

@@ -141,7 +139,8 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
:param policy: The duplicate policy to use when writing documents.
SupabaseDocumentStore only supports `DuplicatePolicy.OVERWRITE`.
:return: This figure may lack accuracy as it fails to distinguish between documents that were genuinely written and those that were not (overwritten).
:return: This figure may lack accuracy as it fails to distinguish between documents that were genuinely
written and those that were not (overwritten).
"""
if policy not in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]:
logger.warning(
@@ -161,16 +160,16 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
"'array', 'dataframe' and 'blob' will be dropped."
)
if self._adapter is not None:
documents_for_supabase.append((doc.id, doc.content, {"content":doc.content, **doc.meta}))
documents_for_supabase.append((doc.id, doc.content, {"content": doc.content, **doc.meta}))
else:
embedding = doc.embedding
if doc.embedding is None:
logger.warning(
f"Document {doc.id} has no embedding. pgvector is a purely vector database. "
"A dummy embedding will be used, but this can affect the search results. "
)
f"Document {doc.id} has no embedding. pgvector is a purely vector database. "
"A dummy embedding will be used, but this can affect the search results. "
)
embedding = self._dummy_vector
documents_for_supabase.append((doc.id, embedding, {"content":doc.content, **doc.meta}))
documents_for_supabase.append((doc.id, embedding, {"content": doc.content, **doc.meta}))

self._collection.upsert(records=documents_for_supabase)
self._collection.create_index()
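
Continuing the vecs sketch above, the records handed to upsert() are (id, vector, metadata) triples; the values here are illustrative, with a dummy vector standing in for a missing embedding as the warning describes:

records = [
    ("doc-1", [0.1] * 768, {"content": "first doc", "topic": "news"}),
    ("doc-2", [0.0] * 768, {"content": "second doc, no embedding"}),  # dummy vector
]
collection.upsert(records=records)
collection.create_index()  # build the vector index used for similarity queries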
@@ -191,7 +190,7 @@ def _convert_query_result_to_documents(self, result) -> List[Document]:
documents = []
for i in result:
supabase_data = self._collection.__getitem__(i)
document_dict = {"id":supabase_data[0]}
document_dict = {"id": supabase_data[0]}
document_dict["embedding"] = np.array(supabase_data[1])
metadata = supabase_data[2]
document_dict["content"] = metadata["content"]
@@ -202,11 +201,7 @@ def _embedding_retrieval(
return documents

def _embedding_retrieval(
self,
query_embedding: List[float],
*,
filters: Optional[Dict[str, Any]],
top_k: int = 10
self, query_embedding: List[float], *, filters: Optional[Dict[str, Any]], top_k: int = 10
) -> List[Document]:
"""
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
@@ -225,10 +220,5 @@ def _embedding_retrieval(
filters = convert(filters)
filters = _normalize_filters(filters) if filters else None

results = self._collection.query(
data=query_embedding,
limit=top_k,
filters=filters
)
results = self._collection.query(data=query_embedding, limit=top_k, filters=filters)
return self._convert_query_result_to_documents(result=results)
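
A hedged call sketch for the compacted method; the filter follows Haystack's comparison-dict convention, and the field and values are invented:

docs = store._embedding_retrieval(
    query_embedding=[0.1] * 768,
    filters={"field": "meta.topic", "operator": "==", "value": "news"},
    top_k=3,
)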

@@ -163,7 +163,7 @@ def _in(field: str, value: Any) -> Dict[str, Any]:
">=": _greater_than_equal,
"<": _less_than,
"<=": _less_than_equal,
"in": _in
"in": _in,
}

LOGICAL_OPERATORS = {"AND": "$and", "OR": "$or"}
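
Taken together with the comparison mapping above, a normalization sketch; the exact output shape of _normalize_filters is an assumption, not verified against the implementation:

haystack_filter = {
    "operator": "AND",
    "conditions": [
        {"field": "meta.genre", "operator": "in", "value": ["news", "blog"]},
        {"field": "meta.year", "operator": ">=", "value": 2023},
    ],
}
# assumed vecs/pgvector-style result:
# {"$and": [{"genre": {"$in": ["news", "blog"]}}, {"year": {"$gte": 2023}}]}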
10 changes: 3 additions & 7 deletions integrations/supabase/tests/conftest.py
@@ -22,28 +22,24 @@ def document_store():
password = "SmMt8z7naxDhibsC"
host = "db.ynylwqvxjhidyomfvwou.supabase.co"

store = SupabaseDocumentStore(
password=password,
host=host
)
store = SupabaseDocumentStore(password=password, host=host)

# Override some methods to wait for the documents to be available
original_write_documents = store.write_documents

def write_documents_and_wait(documents, policy=DuplicatePolicy.NONE):
written_docs = original_write_documents(documents, policy)
#time.sleep(SLEEP_TIME)
# time.sleep(SLEEP_TIME)
return written_docs

original_delete_documents = store.delete_documents

def delete_documents_and_wait(filters):
original_delete_documents(filters)
#time.sleep(SLEEP_TIME)
# time.sleep(SLEEP_TIME)

store.write_documents = write_documents_and_wait
store.delete_documents = delete_documents_and_wait

yield store
store._collection.delete(ids=store._collection.query(data=store._dummy_vector, limit=store.count_documents()))

20 changes: 7 additions & 13 deletions integrations/supabase/tests/test_document_store.py
@@ -81,15 +81,11 @@ def test_embedding_retrieval(self, document_store: SupabaseDocumentStore):
@patch("src.haystack_integrations.document_stores.supabase.document_store.vecs")
def test_init(self, mock_supabase):

document_store = SupabaseDocumentStore(
host="fake-host",
password="password",
dimension=30
)
document_store = SupabaseDocumentStore(host="fake-host", password="password", dimension=30)

user:str = "postgres"
port:str = "5432"
db_name:str = "postgres"
user: str = "postgres"
port: str = "5432"
db_name: str = "postgres"

db_connection = f"postgresql://{user}:password@fake-host:{port}/{db_name}"
mock_supabase.create_client.assert_called_with(db_connection)
@@ -98,20 +94,18 @@ def test_init(self, mock_supabase):
assert document_store.dimension == 30

@pytest.mark.skip(reason="Supabase only supports UPSERT operations")
def test_write_documents_duplicate_fail(self, document_store: SupabaseDocumentStore):
...
def test_write_documents_duplicate_fail(self, document_store: SupabaseDocumentStore): ...

@pytest.mark.skip(reason="Supabase only supports UPSERT operations")
def test_write_documents_duplicate_skip(self, document_store: SupabaseDocumentStore):
...
def test_write_documents_duplicate_skip(self, document_store: SupabaseDocumentStore): ...

def test_write_documents_duplicate_overwrite(self, document_store: SupabaseDocumentStore):
"""
Test write_documents() overwrites stored Document when trying to write one with same id
using DuplicatePolicy.OVERWRITE.
"""
embedding = [0.0] * 768
doc1 = Document(id="1", content="test doc 1", embedding=[0.1]*768)
doc1 = Document(id="1", content="test doc 1", embedding=[0.1] * 768)
doc2 = Document(id="1", content="test doc 2", embedding=embedding)

assert document_store.write_documents([doc2], policy=DuplicatePolicy.OVERWRITE) == 1
48 changes: 16 additions & 32 deletions integrations/supabase/tests/test_filters.py
@@ -34,68 +34,52 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do
assert received_doc.embedding == pytest.approx(expected_doc.embedding)

@pytest.mark.skip(reason="Supabase does not support not_in comparison")
def test_comparison_not_in(self, document_store, filterable_docs):
...
def test_comparison_not_in(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support not_in comparison")
def test_comparison_not_in_with_with_non_list_iterable(self, document_store, filterable_docs):
...
def test_comparison_not_in_with_with_non_list_iterable(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support not_in comparison")
def test_comparison_not_in_with_with_non_list(self, document_store, filterable_docs):
...
def test_comparison_not_in_with_with_non_list(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support comparison with null values")
def test_comparison_equal_with_none(self, document_store, filterable_docs):
...
def test_comparison_equal_with_none(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support not_in comparison")
def test_comparison_not_equal_with_none(self, document_store, filterable_docs):
...
def test_comparison_not_equal_with_none(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support comparison with null values")
def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs):
...
def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support comparison with null values")
def test_comparison_greater_than_with_none(self, document_store, filterable_docs):
...
def test_comparison_greater_than_with_none(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support comparison with null values")
def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs):
...
def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support comparison with null values")
def test_comparison_less_than_with_none(self, document_store, filterable_docs):
...
def test_comparison_less_than_with_none(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support the 'not' operator")
def test_not_operator(self, document_store, filterable_docs):
...
def test_not_operator(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support comparison with dates")
def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs):
...
def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support comparison with dates")
def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs):
...
def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support comparison with dates")
def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs):
...
def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase does not support comparison with dates")
def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs):
...
def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase doesn't support comparision with dataframe")
def test_comparison_equal_with_dataframe(self, document_store, filterable_docs):
...
def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Supabase doesn't support comparision with dataframe")
def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs):
...
def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): ...

def test_or_operator(self, document_store, filterable_docs):
document_store.write_documents(filterable_docs)
