Skip to content

Commit

Permalink
Linter fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
mounaTay committed Jan 9, 2024
1 parent 3ed35f6 commit 7c6672a
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 80 deletions.
5 changes: 5 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ integration:amazon-bedrock:
- any-glob-to-any-file: "integrations/amazon_bedrock/**/*"
- any-glob-to-any-file: ".github/workflows/amazon_bedrock.yml"

integration:astra:
- changed-files:
- any-glob-to-any-file: "integrations/astra/**/*"
- any-glob-to-any-file: ".github/workflows/astra.yml"

integration:chroma:
- changed-files:
- any-glob-to-any-file: "integrations/chroma/**/*"
Expand Down
51 changes: 32 additions & 19 deletions integrations/astra/examples/example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
from pathlib import Path

Expand All @@ -12,9 +13,13 @@
from astra_haystack.document_store import AstraDocumentStore
from astra_haystack.retriever import AstraRetriever

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


HERE = Path(__file__).resolve().parent
file_paths = [HERE / "data" / Path(name) for name in os.listdir("integrations/astra/examples/data")]
print(file_paths)
logger.info(file_paths)

astra_id = os.getenv("ASTRA_DB_ID", "")
astra_region = os.getenv("ASTRA_DB_REGION", "us-east1")
Expand Down Expand Up @@ -67,15 +72,18 @@

question = "This chapter introduces the manuals available with Vim"
result = q.run({"embedder": {"text": question}, "retriever": {"top_k": 1}})
print(result)

print("count:")
print(document_store.count_documents())
assert document_store.count_documents() == 9

print("filter:")
print(
document_store.filter_documents(
logger.info(result)

# Number of documents this example expects document_store.count_documents()
# to report at this point in the script.
ALL_DOCUMENTS_COUNT = 9
documents_count = document_store.count_documents()
logger.info("count:")
logger.info(documents_count)
if documents_count != ALL_DOCUMENTS_COUNT:
    # Raise explicitly (rather than `assert`) so the check is not stripped
    # when Python runs with -O. The expected value is taken from the constant
    # so the message cannot drift from the comparison above.
    msg = f"count mismatch, expected {ALL_DOCUMENTS_COUNT} documents, got {documents_count}"
    raise ValueError(msg)

logger.info(
f"""filter results: {document_store.filter_documents(
{
"field": "meta",
"operator": "==",
Expand All @@ -85,22 +93,27 @@
},
}
)
}"""
)

print("get_document_by_id")
print(document_store.get_document_by_id("92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10"))
print("get_documents_by_ids")
print(
document_store.get_documents_by_id(
logger.info(
f"""get_document_by_id {document_store.get_document_by_id(
"92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10")}"""
)

logger.info(
f"""get_documents_by_ids {document_store.get_documents_by_id(
[
"92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10",
"6f2450a51eaa3eeb9239d875402bcfe24b2d3534ff27f26c1f3fc8133b04e756",
]
)
)}"""
)

document_store.delete_documents(["92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10"])

print("count:")
print(document_store.count_documents())
assert document_store.count_documents() == 8
# A single document was deleted just above, so the store should now report
# exactly one fewer document than ALL_DOCUMENTS_COUNT.
expected_count = ALL_DOCUMENTS_COUNT - 1
# Query the store once and reuse the value: the logged count and the checked
# count are then guaranteed to be the same number.
documents_count = document_store.count_documents()
logger.info("count: %s", documents_count)
if documents_count != expected_count:
    # Report the value actually compared against (previously hard-coded as 9).
    msg = f"count mismatch, expected {expected_count} documents, got {documents_count}"
    raise ValueError(msg)
13 changes: 10 additions & 3 deletions integrations/astra/examples/pipeline_example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os

from haystack import Document, Pipeline
Expand All @@ -11,6 +12,9 @@
from astra_haystack.document_store import AstraDocumentStore
from astra_haystack.retriever import AstraRetriever

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Create a RAG query pipeline
prompt_template = """
Given these documents, answer the question.
Expand Down Expand Up @@ -48,10 +52,12 @@
documents = [
Document(content="There are over 7,000 languages spoken around the world today."),
Document(
content="Elephants have been observed to behave in a way that indicates a high level of self-awareness, such as recognizing themselves in mirrors."
content="Elephants have been observed to behave in a way that indicates"
" a high level of self-awareness, such as recognizing themselves in mirrors."
),
Document(
content="In certain parts of the world, like the Maldives, Puerto Rico, and San Diego, you can witness the phenomenon of bioluminescent waves."
content="In certain parts of the world, like the Maldives, Puerto Rico, "
"and San Diego, you can witness the phenomenon of bioluminescent waves."
),
]
p = Pipeline()
Expand Down Expand Up @@ -97,4 +103,5 @@
"answer_builder": {"query": question},
}
)
print(result)

logger.info(result)
9 changes: 4 additions & 5 deletions integrations/astra/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,6 @@ target-version = ["py37"]
line-length = 120
skip-string-normalization = true

[tool.isort]
profile = 'black'
line_length = 79
skip_gitignore = true

[tool.ruff]
target-version = "py37"
line-length = 120
Expand Down Expand Up @@ -142,6 +137,7 @@ unfixable = [
# Don't touch unused imports
"F401",
]
exclude = ["example"]

[tool.ruff.isort]
known-first-party = ["astra_haystack"]
Expand All @@ -157,6 +153,9 @@ ban-relative-imports = "all"
source_pkgs = ["astra_haystack", "tests"]
branch = true
parallel = true
omit = [
"example"
]

[tool.coverage.paths]
astra_haystack = ["src/astra_haystack", "*/astra-store/src/astra_haystack"]
Expand Down
29 changes: 15 additions & 14 deletions integrations/astra/src/astra_haystack/astra_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@dataclass
class Response:
id: str
document_id: str
text: Optional[str]
values: Optional[list]
metadata: Optional[dict]
Expand Down Expand Up @@ -80,13 +80,15 @@ def find_index(self):

collection_embedding_dim = collection_name_matches[0]["options"]["vector"]["dimension"]
if collection_embedding_dim != self.embedding_dim:
raise Exception(
msg = (
f"Collection vector dimension is not valid, expected {self.embedding_dim}, "
f"found {collection_embedding_dim}"
)
raise Exception(msg)

else:
raise Exception(f"status not in response: {response.text}")
msg = f"status not in response: {response.text}"
raise Exception(msg)

return True

Expand All @@ -107,9 +109,8 @@ def create_index(self):
def query(
self,
vector: Optional[List[float]] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
query_filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
top_k: Optional[int] = None,
namespace: Optional[str] = None,
include_metadata: Optional[bool] = None,
include_values: Optional[bool] = None,
) -> QueryResponse:
Expand All @@ -122,7 +123,7 @@ def query(
being queried. Each `query()` request can contain only one of the parameters
`queries`, `id` or `vector`... [optional]
top_k (int): The number of results to return for each query. Must be an integer greater than 1.
filter (Dict[str, Union[str, float, int, bool, List, dict]):
query_filter (Dict[str, Union[str, float, int, bool, List, dict]):
The filter to apply. You can use vector metadata to limit your search. [optional]
include_metadata (bool): Indicates whether metadata is included in the response as well as the ids.
If omitted the server will use the default value of False [optional]
Expand All @@ -134,9 +135,9 @@ def query(
"""
# get vector data and scores
if vector is None:
responses = self._query_without_vector(top_k, filter)
responses = self._query_without_vector(top_k, query_filter)
else:
responses = self._query(vector, top_k, filter)
responses = self._query(vector, top_k, query_filter)

# include_metadata means return all columns in the table (including text that got embedded)
# include_values means return the vector of the embedding for the searched items
Expand All @@ -158,7 +159,7 @@ def _format_query_response(responses, include_metadata, include_values):
score = response.pop("$similarity", None)
text = response.pop("content", None)
values = response.pop("$vector", None) if include_values else []
metadata = response if include_metadata else dict() # Add all remaining fields to the metadata
metadata = response if include_metadata else {} # Add all remaining fields to the metadata
rsp = Response(_id, text, values, metadata, score)
final_res.append(rsp)
return QueryResponse(final_res)
Expand All @@ -185,7 +186,7 @@ def find_documents(self, find_query):
if "data" in response_dict and "documents" in response_dict["data"]:
return response_dict["data"]["documents"]
else:
logger.warning("No documents found", response_dict)
logger.warning(f"No documents found: {response_dict}")

def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse:
document_batch = []
Expand Down Expand Up @@ -253,14 +254,14 @@ def delete(
self,
ids: Optional[List[str]] = None,
delete_all: Optional[bool] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
filters: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
) -> int:
if delete_all:
query = {"deleteMany": {}}
query = {"deleteMany": {}} # type: dict
if ids is not None:
query = {"deleteMany": {"filter": {"_id": {"$in": ids}}}}
if filter is not None:
query = {"deleteMany": {"filter": filter}}
if filters is not None:
query = {"deleteMany": {"filter": filters}}

deletion_counter = 0
moredata = True
Expand Down
Loading

0 comments on commit 7c6672a

Please sign in to comment.