From 9f4e1ecbb7d0864e9ec6842e61d32949bcb64a16 Mon Sep 17 00:00:00 2001
From: Sebastian Weisshaar <76220851+sebastian-weisshaar@users.noreply.github.com>
Date: Mon, 19 Feb 2024 12:54:16 +0100
Subject: [PATCH] Feat: Add filters to run function in retrievers of elasticsearch (#440)

* feat: add filters to run function of bm_25 retriever in elastic search

* feat: add filters to run function of embedding retriever in elastic search

* docs: add docstring for filters in run function of retrievers in elasticsearch
---
 .../components/retrievers/elasticsearch/bm25_retriever.py | 5 +++--
 .../retrievers/elasticsearch/embedding_retriever.py       | 5 +++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py
index bd96a5fd8..0416389a2 100644
--- a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py
+++ b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py
@@ -89,17 +89,18 @@ def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchBM25Retriever":
         return default_from_dict(cls, data)
 
     @component.output_types(documents=List[Document])
-    def run(self, query: str, top_k: Optional[int] = None):
+    def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
         """
         Retrieve documents using the BM25 keyword-based algorithm.
 
         :param query: String to search in Documents' text.
+        :param filters: Filters applied to the retrieved Documents.
         :param top_k: Maximum number of Documents to return.
         :return: List of Documents that match the query.
         """
         docs = self._document_store._bm25_retrieval(
             query=query,
-            filters=self._filters,
+            filters=filters or self._filters,
             fuzziness=self._fuzziness,
             top_k=top_k or self._top_k,
             scale_score=self._scale_score,
diff --git a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py
index a2c825d66..2ba68916f 100644
--- a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py
+++ b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py
@@ -63,17 +63,18 @@ def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchEmbeddingRetriever":
         return default_from_dict(cls, data)
 
     @component.output_types(documents=List[Document])
-    def run(self, query_embedding: List[float], top_k: Optional[int] = None):
+    def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
         """
         Retrieve documents using a vector similarity metric.
 
         :param query_embedding: Embedding of the query.
+        :param filters: Filters applied to the retrieved Documents.
         :param top_k: Maximum number of Documents to return.
         :return: List of Documents similar to `query_embedding`.
         """
         docs = self._document_store._embedding_retrieval(
             query_embedding=query_embedding,
-            filters=self._filters,
+            filters=filters or self._filters,
             top_k=top_k or self._top_k,
             num_candidates=self._num_candidates,
         )