diff --git a/integrations/astra/examples/example.py b/integrations/astra/examples/example.py index ac93f43ed..35963868c 100644 --- a/integrations/astra/examples/example.py +++ b/integrations/astra/examples/example.py @@ -8,10 +8,10 @@ from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter from haystack.components.routers import FileTypeRouter from haystack.components.writers import DocumentWriter -from haystack.document_stores import DuplicatePolicy +from haystack.document_stores.types import DuplicatePolicy -from astra_haystack.document_store import AstraDocumentStore -from astra_haystack.retriever import AstraRetriever +from haystack_integrations.components.retrievers.astra import AstraRetriever +from haystack_integrations.document_stores.astra import AstraDocumentStore logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/integrations/astra/examples/pipeline_example.py b/integrations/astra/examples/pipeline_example.py index fb13c3d93..cacb1eb9f 100644 --- a/integrations/astra/examples/pipeline_example.py +++ b/integrations/astra/examples/pipeline_example.py @@ -7,10 +7,10 @@ from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.generators import OpenAIGenerator from haystack.components.writers import DocumentWriter -from haystack.document_stores import DuplicatePolicy +from haystack.document_stores.types import DuplicatePolicy -from astra_haystack.document_store import AstraDocumentStore -from astra_haystack.retriever import AstraRetriever +from haystack_integrations.components.retrievers.astra import AstraRetriever +from haystack_integrations.document_stores.astra import AstraDocumentStore logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index b99449e03..6b4e2565d 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -35,6 +35,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/astra" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/astra-v(?P.*)' @@ -71,7 +74,7 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/astra_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -141,17 +144,17 @@ unfixable = [ exclude = ["example"] [tool.ruff.isort] -known-first-party = ["astra_haystack"] +known-first-party = ["haystack_integrations"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["astra_haystack", "tests"] +source_pkgs = ["haystack_integrations", "tests"] branch = true parallel = true omit = [ @@ -159,7 +162,7 @@ omit = [ ] [tool.coverage.paths] -astra_haystack = ["src/astra_haystack", "*/astra-store/src/astra_haystack"] +astra_haystack = ["src"] tests = ["tests"] [tool.coverage.report] @@ -178,10 +181,10 @@ markers = [ [[tool.mypy.overrides]] module = [ - "astra_haystack.*", "astra_client.*", "pydantic.*", "haystack.*", + "haystack_integrations.*", "pytest.*" ] ignore_missing_imports = true diff --git a/integrations/astra/src/haystack_integrations/components/retrievers/astra/__init__.py b/integrations/astra/src/haystack_integrations/components/retrievers/astra/__init__.py new file mode 100644 index 000000000..33ef6d15e --- /dev/null +++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 +from .retriever import AstraRetriever + +__all__ = ["AstraRetriever"] diff --git a/integrations/astra/src/astra_haystack/retriever.py b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py similarity index 96% rename from integrations/astra/src/astra_haystack/retriever.py rename to integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py index 47304df2c..fdf9b0722 100644 --- a/integrations/astra/src/astra_haystack/retriever.py +++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py @@ -6,7 +6,7 @@ from haystack import Document, component, default_from_dict, default_to_dict -from astra_haystack.document_store import AstraDocumentStore +from haystack_integrations.document_stores.astra import AstraDocumentStore @component diff --git a/integrations/astra/src/astra_haystack/__init__.py b/integrations/astra/src/haystack_integrations/document_stores/astra/__init__.py similarity index 71% rename from integrations/astra/src/astra_haystack/__init__.py rename to integrations/astra/src/haystack_integrations/document_stores/astra/__init__.py index 5c99dedf6..4618beb08 100644 --- a/integrations/astra/src/astra_haystack/__init__.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2023-present Anant Corporation # # SPDX-License-Identifier: Apache-2.0 -from astra_haystack.document_store import AstraDocumentStore +from .document_store import AstraDocumentStore __all__ = ["AstraDocumentStore"] diff --git a/integrations/astra/src/astra_haystack/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py similarity index 100% rename from integrations/astra/src/astra_haystack/astra_client.py rename to integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py diff --git a/integrations/astra/src/astra_haystack/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py similarity index 91% rename from integrations/astra/src/astra_haystack/document_store.py rename to integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 6e630bef5..8e03de4a6 100644 --- a/integrations/astra/src/astra_haystack/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -12,9 +12,9 @@ from haystack.document_stores.errors import DuplicateDocumentError, MissingDocumentError from haystack.document_stores.types import DuplicatePolicy -from astra_haystack.astra_client import AstraClient -from astra_haystack.errors import AstraDocumentStoreFilterError -from astra_haystack.filters import _convert_filters +from .astra_client import AstraClient +from .errors import AstraDocumentStoreFilterError +from .filters import _convert_filters logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def __init__( astra_application_token: str, astra_keyspace: str, astra_collection: str, - embedding_dim: Optional[int] = 768, + embedding_dim: int = 768, duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE, similarity: str = "cosine", ): @@ -104,17 +104,12 @@ def to_dict(self) -> Dict[str, Any]: def write_documents( self, documents: List[Document], - index: Optional[str] = None, - batch_size: int = 20, policy: DuplicatePolicy = DuplicatePolicy.NONE, ): """ Indexes documents for later queries. :param documents: a list of Haystack Document objects. - :param index: Optional name of index where the documents shall be written to. - If None, the DocumentStore's default index (self.index) will be used. - :param batch_size: Number of documents that are passed to bulk function at a time. :param policy: Handle duplicate documents based on DuplicatePolicy parameter options. Parameter options : (SKIP, OVERWRITE, FAIL, NONE) - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, @@ -125,26 +120,13 @@ def write_documents( - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. :return: int """ - - if index is None and self.index is None: - msg = "No Astra client provided" - raise ValueError(msg) - - if index is None: - index = self.index - if policy is None or policy == DuplicatePolicy.NONE: if self.duplicates_policy is not None and self.duplicates_policy != DuplicatePolicy.NONE: policy = self.duplicates_policy else: policy = DuplicatePolicy.SKIP - if batch_size > MAX_BATCH_SIZE: - logger.warning( - f"batch_size set to {batch_size}, " - f"but maximum batch_size for Astra when using the JSON API is 20. batch_size set to 20." - ) - batch_size = MAX_BATCH_SIZE + batch_size = MAX_BATCH_SIZE def _convert_input_document(document: Union[dict, Document]): if isinstance(document, Document): @@ -196,7 +178,7 @@ def _convert_input_document(document: Union[dict, Document]): if policy == DuplicatePolicy.SKIP: if len(new_documents) > 0: for batch in _batches(new_documents, batch_size): - inserted_ids = index.insert(batch) # type: ignore + inserted_ids = self.index.insert(batch) # type: ignore insertion_counter += len(inserted_ids) logger.info(f"write_documents inserted documents with id {inserted_ids}") else: @@ -205,7 +187,7 @@ def _convert_input_document(document: Union[dict, Document]): elif policy == DuplicatePolicy.OVERWRITE: if len(new_documents) > 0: for batch in _batches(new_documents, batch_size): - inserted_ids = index.insert(batch) # type: ignore + inserted_ids = self.index.insert(batch) # type: ignore insertion_counter += len(inserted_ids) logger.info(f"write_documents inserted documents with id {inserted_ids}") else: @@ -214,7 +196,7 @@ def _convert_input_document(document: Union[dict, Document]): if len(duplicate_documents) > 0: updated_ids = [] for duplicate_doc in duplicate_documents: - updated = index.update_document(duplicate_doc, "_id") # type: ignore + updated = self.index.update_document(duplicate_doc, "_id") # type: ignore if updated: updated_ids.append(duplicate_doc["_id"]) insertion_counter = insertion_counter + len(updated_ids) @@ -225,7 +207,7 @@ def _convert_input_document(document: Union[dict, Document]): elif policy == DuplicatePolicy.FAIL: if len(new_documents) > 0: for batch in _batches(new_documents, batch_size): - inserted_ids = index.insert(batch) # type: ignore + inserted_ids = self.index.insert(batch) # type: ignore insertion_counter = insertion_counter + len(inserted_ids) logger.info(f"write_documents inserted documents with id {inserted_ids}") else: diff --git a/integrations/astra/src/astra_haystack/errors.py b/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py similarity index 100% rename from integrations/astra/src/astra_haystack/errors.py rename to integrations/astra/src/haystack_integrations/document_stores/astra/errors.py diff --git a/integrations/astra/src/astra_haystack/filters.py b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py similarity index 100% rename from integrations/astra/src/astra_haystack/filters.py rename to integrations/astra/src/haystack_integrations/document_stores/astra/filters.py diff --git a/integrations/astra/tests/conftest.py b/integrations/astra/tests/conftest.py index 02f5d7cad..274b38352 100644 --- a/integrations/astra/tests/conftest.py +++ b/integrations/astra/tests/conftest.py @@ -3,7 +3,7 @@ import pytest from haystack.document_stores.types import DuplicatePolicy -from astra_haystack.document_store import AstraDocumentStore +from haystack_integrations.document_stores.astra import AstraDocumentStore @pytest.fixture diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index f203ab721..019a66398 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -10,7 +10,7 @@ from haystack.document_stores.types import DuplicatePolicy from haystack.testing.document_store import DocumentStoreBaseTests -from astra_haystack.document_store import AstraDocumentStore +from haystack_integrations.document_stores.astra import AstraDocumentStore @pytest.mark.skipif( diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index 2212d44fd..eb9260590 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -5,7 +5,7 @@ import pytest -from astra_haystack.retriever import AstraRetriever +from haystack_integrations.components.retrievers.astra import AstraRetriever @pytest.mark.skipif( @@ -16,7 +16,7 @@ def test_retriever_to_json(document_store): retriever = AstraRetriever(document_store, filters={"foo": "bar"}, top_k=99) assert retriever.to_dict() == { - "type": "astra_haystack.retriever.AstraRetriever", + "type": "haystack_integrations.components.retrievers.astra.retriever.AstraRetriever", "init_parameters": { "filters": {"foo": "bar"}, "top_k": 99, @@ -30,7 +30,7 @@ def test_retriever_to_json(document_store): "embedding_dim": 768, "similarity": "cosine", }, - "type": "astra_haystack.document_store.AstraDocumentStore", + "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", }, }, } @@ -43,7 +43,7 @@ def test_retriever_to_json(document_store): @pytest.mark.integration def test_retriever_from_json(): data = { - "type": "astra_haystack.retriever.AstraRetriever", + "type": "haystack_integrations.components.retrievers.astra.retriever.AstraRetriever", "init_parameters": { "filters": {"bar": "baz"}, "top_k": 42, @@ -58,7 +58,7 @@ def test_retriever_from_json(): "embedding_dim": 768, "similarity": "cosine", }, - "type": "astra_haystack.document_store.AstraDocumentStore", + "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", }, }, }