Skip to content

Commit

Permalink
fix: Lints
Browse files Browse the repository at this point in the history
  • Loading branch information
shadeMe committed Nov 20, 2024
1 parent 8b11072 commit 22a5dfb
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever":

def _prepare_bm25_args(
self,
*,
query: str,
filters: Optional[Dict[str, Any]],
all_terms_must_match: Optional[bool],
Expand Down Expand Up @@ -176,7 +177,7 @@ def _prepare_bm25_args(
}

@component.output_types(documents=List[Document])
def run(
def run( # pylint: disable=too-many-positional-arguments
self,
query: str,
filters: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -209,30 +210,29 @@ def run(
"""
docs: List[Document] = []
bm25_args = self._prepare_bm25_args(
query,
filters,
all_terms_must_match,
top_k,
fuzziness,
scale_score,
custom_query,
query=query,
filters=filters,
all_terms_must_match=all_terms_must_match,
top_k=top_k,
fuzziness=fuzziness,
scale_score=scale_score,
custom_query=custom_query,
)
try:
docs = self._document_store._bm25_retrieval(**bm25_args)
except Exception as e:
if self._raise_on_failure:
raise e
else:
logger.warning(
"An error during BM25 retrieval occurred and will be ignored by returning empty results: {error}",
error=str(e),
exc_info=True,
)
logger.warning(
"An error during BM25 retrieval occurred and will be ignored by returning empty results: {error}",
error=str(e),
exc_info=True,
)

return {"documents": docs}

@component.output_types(documents=List[Document])
async def run_async(
async def run_async( # pylint: disable=too-many-positional-arguments
self,
query: str,
filters: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -265,24 +265,23 @@ async def run_async(
"""
docs: List[Document] = []
bm25_args = self._prepare_bm25_args(
query,
filters,
all_terms_must_match,
top_k,
fuzziness,
scale_score,
custom_query,
query=query,
filters=filters,
all_terms_must_match=all_terms_must_match,
top_k=top_k,
fuzziness=fuzziness,
scale_score=scale_score,
custom_query=custom_query,
)
try:
docs = await self._document_store._bm25_retrieval_async(**bm25_args)
except Exception as e:
if self._raise_on_failure:
raise e
else:
logger.warning(
"An error during BM25 retrieval occurred and will be ignored by returning empty results: {error}",
error=str(e),
exc_info=True,
)
logger.warning(
"An error during BM25 retrieval occurred and will be ignored by returning empty results: {error}",
error=str(e),
exc_info=True,
)

return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,12 @@ def run(
except Exception as e:
if self._raise_on_failure:
raise e
else:
logger.warning(
"An error during embedding retrieval occurred and will be "
"ignored by returning empty results: {error}",
error=str(e),
exc_info=True,
)
logger.warning(
"An error during embedding retrieval occurred and will be "
"ignored by returning empty results: {error}",
error=str(e),
exc_info=True,
)

return {"documents": docs}

Expand Down Expand Up @@ -193,12 +192,11 @@ async def run_async(
except Exception as e:
if self._raise_on_failure:
raise e
else:
logger.warning(
"An error during embedding retrieval occurred and will be "
"ignored by returning empty results: {error}",
error=str(e),
exc_info=True,
)
logger.warning(
"An error during embedding retrieval occurred and will be "
"ignored by returning empty results: {error}",
error=str(e),
exc_info=True,
)

return {"documents": docs}
4 changes: 2 additions & 2 deletions haystack_experimental/core/pipeline/async_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ async def _run_component(

return res

async def _run_subgraph( # noqa: PLR0915, PLR0912
async def _run_subgraph( # noqa: PLR0915, PLR0912 # pylint: disable=too-many-locals, too-many-branches, too-many-statements
self,
cycle: List[str],
component_name: str,
Expand Down Expand Up @@ -288,7 +288,7 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912

yield subgraph_outputs, True

async def run( # noqa: PLR0915
async def run( # noqa: PLR0915, PLR0912 # pylint: disable=too-many-locals, too-many-branches, too-many-statements
self,
data: Dict[str, Any],
) -> AsyncIterator[Dict[str, Any]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class InMemoryDocumentStore(InMemoryDocumentStoreBase):
Asynchronous version of the in-memory document store.
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
Expand Down Expand Up @@ -126,7 +126,7 @@ async def bm25_retrieval_async(
lambda: self.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score),
)

async def embedding_retrieval_async(
async def embedding_retrieval_async( # pylint: disable=too-many-positional-arguments
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
Expand Down
17 changes: 14 additions & 3 deletions haystack_experimental/document_stores/opensearch/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


class OpenSearchDocumentStore:
def __init__(
def __init__( # pylint: disable=dangerous-default-value
self,
*,
hosts: Optional[Hosts] = None,
Expand Down Expand Up @@ -403,6 +403,7 @@ def _render_custom_query(self, custom_query: Any, substitutions: Dict[str, Any])

def _prepare_bm25_search_request(
self,
*,
query: str,
filters: Optional[Dict[str, Any]],
fuzziness: str,
Expand Down Expand Up @@ -479,7 +480,12 @@ def _bm25_retrieval(
self._ensure_initialized()

search_params = self._prepare_bm25_search_request(
query, filters, fuzziness, top_k, all_terms_must_match, custom_query
query=query,
filters=filters,
fuzziness=fuzziness,
top_k=top_k,
all_terms_must_match=all_terms_must_match,
custom_query=custom_query,
)
documents = self._search_documents(search_params)
self._postprocess_bm25_search_results(documents, scale_score)
Expand All @@ -499,7 +505,12 @@ async def _bm25_retrieval_async(
self._ensure_initialized()

search_params = self._prepare_bm25_search_request(
query, filters, fuzziness, top_k, all_terms_must_match, custom_query
query=query,
filters=filters,
fuzziness=fuzziness,
top_k=top_k,
all_terms_must_match=all_terms_must_match,
custom_query=custom_query,
)
documents = await self._search_documents_async(search_params)
self._postprocess_bm25_search_results(documents, scale_score)
Expand Down
20 changes: 10 additions & 10 deletions haystack_experimental/document_stores/types/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,23 @@ def to_dict(self) -> Dict[str, Any]:
"""
Serializes this store to a dictionary.
"""
...
pass

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DocumentStore":
"""
Deserializes the store from a dictionary.
"""
...
pass

def count_documents(self) -> int:
"""
Returns the number of documents stored.
"""
...
pass

async def count_documents_async(self) -> int: # noqa: D102
...
pass

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Expand Down Expand Up @@ -107,12 +107,12 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
:param filters: the filters to apply to the document list.
:returns: a list of Documents that match the given filters.
"""
...
pass

async def filter_documents_async( # noqa: D102
self, filters: Optional[Dict[str, Any]] = None
) -> List[Document]:
...
pass

def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
"""
Expand All @@ -130,12 +130,12 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
If `DuplicatePolicy.OVERWRITE` is used, this number is always equal to the number of documents in input.
If `DuplicatePolicy.SKIP` is used, this number can be lower than the number of documents in the input list.
"""
...
pass

async def write_documents_async( # noqa: D102
self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE
) -> int:
...
pass

def delete_documents(self, document_ids: List[str]) -> None:
"""
Expand All @@ -145,9 +145,9 @@ def delete_documents(self, document_ids: List[str]) -> None:
:param document_ids: the object_ids to delete
"""
...
pass

async def delete_documents_async( # noqa: D102
self, document_ids: List[str]
) -> None:
...
pass

0 comments on commit 22a5dfb

Please sign in to comment.