Skip to content

Commit

Permalink
refact!: change import paths (#277)
Browse files Browse the repository at this point in the history
* change import paths

* linting

* fix protocol interface

* fix coverage

* moar linting
  • Loading branch information
masci authored Jan 29, 2024
1 parent 37507de commit df86747
Show file tree
Hide file tree
Showing 13 changed files with 39 additions and 48 deletions.
6 changes: 3 additions & 3 deletions integrations/astra/examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
from haystack.components.routers import FileTypeRouter
from haystack.components.writers import DocumentWriter
from haystack.document_stores import DuplicatePolicy
from haystack.document_stores.types import DuplicatePolicy

from astra_haystack.document_store import AstraDocumentStore
from astra_haystack.retriever import AstraRetriever
from haystack_integrations.components.retrievers.astra import AstraRetriever
from haystack_integrations.document_stores.astra import AstraDocumentStore

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down
6 changes: 3 additions & 3 deletions integrations/astra/examples/pipeline_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.generators import OpenAIGenerator
from haystack.components.writers import DocumentWriter
from haystack.document_stores import DuplicatePolicy
from haystack.document_stores.types import DuplicatePolicy

from astra_haystack.document_store import AstraDocumentStore
from astra_haystack.retriever import AstraRetriever
from haystack_integrations.components.retrievers.astra import AstraRetriever
from haystack_integrations.document_stores.astra import AstraDocumentStore

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down
15 changes: 9 additions & 6 deletions integrations/astra/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m
Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues"
Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/astra"

[tool.hatch.build.targets.wheel]
packages = ["src/haystack_integrations"]

[tool.hatch.version]
source = "vcs"
tag-pattern = 'integrations\/astra-v(?P<version>.*)'
Expand Down Expand Up @@ -71,7 +74,7 @@ dependencies = [
"ruff>=0.0.243",
]
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive {args:src/astra_haystack tests}"
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
style = [
"ruff {args:.}",
"black --check --diff {args:.}",
Expand Down Expand Up @@ -141,25 +144,25 @@ unfixable = [
exclude = ["example"]

[tool.ruff.isort]
known-first-party = ["astra_haystack"]
known-first-party = ["haystack_integrations"]

[tool.ruff.flake8-tidy-imports]
ban-relative-imports = "all"
ban-relative-imports = "parents"

[tool.ruff.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]

[tool.coverage.run]
source_pkgs = ["astra_haystack", "tests"]
source_pkgs = ["haystack_integrations", "tests"]
branch = true
parallel = true
omit = [
"example"
]

[tool.coverage.paths]
astra_haystack = ["src/astra_haystack", "*/astra-store/src/astra_haystack"]
astra_haystack = ["src"]
tests = ["tests"]

[tool.coverage.report]
Expand All @@ -178,10 +181,10 @@ markers = [

[[tool.mypy.overrides]]
module = [
"astra_haystack.*",
"astra_client.*",
"pydantic.*",
"haystack.*",
"haystack_integrations.*",
"pytest.*"
]
ignore_missing_imports = true
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: 2023-present Anant Corporation <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from .retriever import AstraRetriever

__all__ = ["AstraRetriever"]
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from haystack import Document, component, default_from_dict, default_to_dict

from astra_haystack.document_store import AstraDocumentStore
from haystack_integrations.document_stores.astra import AstraDocumentStore


@component
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: 2023-present Anant Corporation <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from astra_haystack.document_store import AstraDocumentStore
from .document_store import AstraDocumentStore

__all__ = ["AstraDocumentStore"]
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from haystack.document_stores.errors import DuplicateDocumentError, MissingDocumentError
from haystack.document_stores.types import DuplicatePolicy

from astra_haystack.astra_client import AstraClient
from astra_haystack.errors import AstraDocumentStoreFilterError
from astra_haystack.filters import _convert_filters
from .astra_client import AstraClient
from .errors import AstraDocumentStoreFilterError
from .filters import _convert_filters

logger = logging.getLogger(__name__)

Expand All @@ -40,7 +40,7 @@ def __init__(
astra_application_token: str,
astra_keyspace: str,
astra_collection: str,
embedding_dim: Optional[int] = 768,
embedding_dim: int = 768,
duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE,
similarity: str = "cosine",
):
Expand Down Expand Up @@ -104,17 +104,12 @@ def to_dict(self) -> Dict[str, Any]:
def write_documents(
self,
documents: List[Document],
index: Optional[str] = None,
batch_size: int = 20,
policy: DuplicatePolicy = DuplicatePolicy.NONE,
):
"""
Indexes documents for later queries.
:param documents: a list of Haystack Document objects.
:param index: Optional name of index where the documents shall be written to.
If None, the DocumentStore's default index (self.index) will be used.
:param batch_size: Number of documents that are passed to bulk function at a time.
:param policy: Handle duplicate documents based on DuplicatePolicy parameter options.
Parameter options : (SKIP, OVERWRITE, FAIL, NONE)
- `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists,
Expand All @@ -125,26 +120,13 @@ def write_documents(
- `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised.
:return: int
"""

if index is None and self.index is None:
msg = "No Astra client provided"
raise ValueError(msg)

if index is None:
index = self.index

if policy is None or policy == DuplicatePolicy.NONE:
if self.duplicates_policy is not None and self.duplicates_policy != DuplicatePolicy.NONE:
policy = self.duplicates_policy
else:
policy = DuplicatePolicy.SKIP

if batch_size > MAX_BATCH_SIZE:
logger.warning(
f"batch_size set to {batch_size}, "
f"but maximum batch_size for Astra when using the JSON API is 20. batch_size set to 20."
)
batch_size = MAX_BATCH_SIZE
batch_size = MAX_BATCH_SIZE

def _convert_input_document(document: Union[dict, Document]):
if isinstance(document, Document):
Expand Down Expand Up @@ -196,7 +178,7 @@ def _convert_input_document(document: Union[dict, Document]):
if policy == DuplicatePolicy.SKIP:
if len(new_documents) > 0:
for batch in _batches(new_documents, batch_size):
inserted_ids = index.insert(batch) # type: ignore
inserted_ids = self.index.insert(batch) # type: ignore
insertion_counter += len(inserted_ids)
logger.info(f"write_documents inserted documents with id {inserted_ids}")
else:
Expand All @@ -205,7 +187,7 @@ def _convert_input_document(document: Union[dict, Document]):
elif policy == DuplicatePolicy.OVERWRITE:
if len(new_documents) > 0:
for batch in _batches(new_documents, batch_size):
inserted_ids = index.insert(batch) # type: ignore
inserted_ids = self.index.insert(batch) # type: ignore
insertion_counter += len(inserted_ids)
logger.info(f"write_documents inserted documents with id {inserted_ids}")
else:
Expand All @@ -214,7 +196,7 @@ def _convert_input_document(document: Union[dict, Document]):
if len(duplicate_documents) > 0:
updated_ids = []
for duplicate_doc in duplicate_documents:
updated = index.update_document(duplicate_doc, "_id") # type: ignore
updated = self.index.update_document(duplicate_doc, "_id") # type: ignore
if updated:
updated_ids.append(duplicate_doc["_id"])
insertion_counter = insertion_counter + len(updated_ids)
Expand All @@ -225,7 +207,7 @@ def _convert_input_document(document: Union[dict, Document]):
elif policy == DuplicatePolicy.FAIL:
if len(new_documents) > 0:
for batch in _batches(new_documents, batch_size):
inserted_ids = index.insert(batch) # type: ignore
inserted_ids = self.index.insert(batch) # type: ignore
insertion_counter = insertion_counter + len(inserted_ids)
logger.info(f"write_documents inserted documents with id {inserted_ids}")
else:
Expand Down
2 changes: 1 addition & 1 deletion integrations/astra/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from haystack.document_stores.types import DuplicatePolicy

from astra_haystack.document_store import AstraDocumentStore
from haystack_integrations.document_stores.astra import AstraDocumentStore


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion integrations/astra/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from haystack.document_stores.types import DuplicatePolicy
from haystack.testing.document_store import DocumentStoreBaseTests

from astra_haystack.document_store import AstraDocumentStore
from haystack_integrations.document_stores.astra import AstraDocumentStore


@pytest.mark.skipif(
Expand Down
10 changes: 5 additions & 5 deletions integrations/astra/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from astra_haystack.retriever import AstraRetriever
from haystack_integrations.components.retrievers.astra import AstraRetriever


@pytest.mark.skipif(
Expand All @@ -16,7 +16,7 @@
def test_retriever_to_json(document_store):
retriever = AstraRetriever(document_store, filters={"foo": "bar"}, top_k=99)
assert retriever.to_dict() == {
"type": "astra_haystack.retriever.AstraRetriever",
"type": "haystack_integrations.components.retrievers.astra.retriever.AstraRetriever",
"init_parameters": {
"filters": {"foo": "bar"},
"top_k": 99,
Expand All @@ -30,7 +30,7 @@ def test_retriever_to_json(document_store):
"embedding_dim": 768,
"similarity": "cosine",
},
"type": "astra_haystack.document_store.AstraDocumentStore",
"type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore",
},
},
}
Expand All @@ -43,7 +43,7 @@ def test_retriever_to_json(document_store):
@pytest.mark.integration
def test_retriever_from_json():
data = {
"type": "astra_haystack.retriever.AstraRetriever",
"type": "haystack_integrations.components.retrievers.astra.retriever.AstraRetriever",
"init_parameters": {
"filters": {"bar": "baz"},
"top_k": 42,
Expand All @@ -58,7 +58,7 @@ def test_retriever_from_json():
"embedding_dim": 768,
"similarity": "cosine",
},
"type": "astra_haystack.document_store.AstraDocumentStore",
"type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore",
},
},
}
Expand Down

0 comments on commit df86747

Please sign in to comment.