Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: extend OpenSearch params support #70

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,22 @@ def __init__(
fuzziness: str = "AUTO",
top_k: int = 10,
scale_score: bool = False,
all_terms_must_match: bool = False,
):
"""
Create the OpenSearchBM25Retriever component.

:param document_store: An instance of OpenSearchDocumentStore.
:param filters: Filters applied to the retrieved Documents. Defaults to None.
:param fuzziness: Fuzziness parameter for full-text queries. Defaults to "AUTO".
:param top_k: Maximum number of Documents to return, defaults to 10
:param scale_score: Whether to scale the score of retrieved documents between 0 and 1.
This is useful when comparing documents across different indexes. Defaults to False.
:param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents.
This is useful when searching for short text where even one term can make a difference. Defaults to False.
:raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore.

"""
if not isinstance(document_store, OpenSearchDocumentStore):
msg = "document_store must be an instance of OpenSearchDocumentStore"
raise ValueError(msg)
Expand All @@ -26,6 +41,7 @@ def __init__(
self._fuzziness = fuzziness
self._top_k = top_k
self._scale_score = scale_score
self._all_terms_must_match = all_terms_must_match

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
Expand All @@ -45,12 +61,44 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever":
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query: str):
def run(
self,
query: str,
filters: Optional[Dict[str, Any]] = None,
all_terms_must_match: Optional[bool] = None,
top_k: Optional[int] = None,
fuzziness: Optional[str] = None,
scale_score: Optional[bool] = None,
):
"""
Retrieve documents using BM25 retrieval.

:param query: The query string
:param filters: Optional filters to narrow down the search space.
:param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents.
:param top_k: Maximum number of Documents to return.
:param fuzziness: Fuzziness parameter for full-text queries.
:param scale_score: Whether to scale the score of retrieved documents between 0 and 1.
This is useful when comparing documents across different indexes.
:return: A dictionary containing the retrieved documents.
"""
if filters is None:
filters = self._filters
if all_terms_must_match is None:
all_terms_must_match = self._all_terms_must_match
if top_k is None:
top_k = self._top_k
if fuzziness is None:
fuzziness = self._fuzziness
if scale_score is None:
scale_score = self._scale_score

docs = self._document_store._bm25_retrieval(
query=query,
filters=self._filters,
fuzziness=self._fuzziness,
top_k=self._top_k,
scale_score=self._scale_score,
filters=filters,
fuzziness=fuzziness,
top_k=top_k,
scale_score=scale_score,
all_terms_must_match=all_terms_must_match,
)
return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def _bm25_retrieval(
fuzziness: str = "AUTO",
top_k: int = 10,
scale_score: bool = False,
all_terms_must_match: bool = False,
) -> List[Document]:
"""
OpenSearch by defaults uses BM25 search algorithm.
Expand All @@ -234,13 +235,13 @@ def _bm25_retrieval(
`query` must be a non empty string, otherwise a `ValueError` will be raised.

:param query: String to search in saved Documents' text.
:param filters: Filters applied to the retrieved Documents, for more info
see `OpenSearchDocumentStore.filter_documents`, defaults to None
:param filters: Optional filters to narrow down the search space.
:param fuzziness: Fuzziness parameter passed to OpenSearch, defaults to "AUTO".
see the official documentation for valid values:
https://www.elastic.co/guide/en/OpenSearch/reference/current/common-options.html#fuzziness
:param top_k: Maximum number of Documents to return, defaults to 10
:param scale_score: If `True` scales the Document`s scores between 0 and 1, defaults to False
:param all_terms_must_match: If `True` all terms in `query` must be present in the Document, defaults to False
:raises ValueError: If `query` is an empty string
:return: List of Document that match `query`
"""
Expand All @@ -249,6 +250,7 @@ def _bm25_retrieval(
msg = "query must be a non empty string"
raise ValueError(msg)

operator = "AND" if all_terms_must_match else "OR"
body: Dict[str, Any] = {
"size": top_k,
"query": {
Expand All @@ -259,7 +261,7 @@ def _bm25_retrieval(
"query": query,
"fuzziness": fuzziness,
"type": "most_fields",
"operator": "AND",
"operator": operator,
}
}
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever":
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query_embedding: List[float]):
def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
"""
Retrieve documents using a vector similarity metric.

:param query_embedding: Embedding of the query.
:param filters: Optional filters to narrow down the search space.
:param top_k: Maximum number of Documents to return.
:return: List of Document similar to `query_embedding`.
"""
if filters is None:
filters = self._filters
if top_k is None:
top_k = self._top_k

docs = self._document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=self._filters,
top_k=self._top_k,
filters=filters,
top_k=top_k,
)
return {"documents": docs}
58 changes: 58 additions & 0 deletions integrations/opensearch/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,64 @@ def test_run():
fuzziness="AUTO",
top_k=10,
scale_score=False,
all_terms_must_match=False,
)
assert len(res) == 1
assert len(res["documents"]) == 1
assert res["documents"][0].content == "Test doc"


def test_run_init_params():
mock_store = Mock(spec=OpenSearchDocumentStore)
mock_store._bm25_retrieval.return_value = [Document(content="Test doc")]
retriever = OpenSearchBM25Retriever(
document_store=mock_store,
filters={"from": "init"},
all_terms_must_match=True,
scale_score=True,
top_k=11,
fuzziness="1",
)
res = retriever.run(query="some query")
mock_store._bm25_retrieval.assert_called_once_with(
query="some query",
filters={"from": "init"},
fuzziness="1",
top_k=11,
scale_score=True,
all_terms_must_match=True,
)
assert len(res) == 1
assert len(res["documents"]) == 1
assert res["documents"][0].content == "Test doc"


def test_run_time_params():
mock_store = Mock(spec=OpenSearchDocumentStore)
mock_store._bm25_retrieval.return_value = [Document(content="Test doc")]
retriever = OpenSearchBM25Retriever(
document_store=mock_store,
filters={"from": "init"},
all_terms_must_match=True,
scale_score=True,
top_k=11,
fuzziness="1",
)
res = retriever.run(
query="some query",
filters={"from": "run"},
all_terms_must_match=False,
scale_score=False,
top_k=9,
fuzziness="2",
)
mock_store._bm25_retrieval.assert_called_once_with(
query="some query",
filters={"from": "run"},
fuzziness="2",
top_k=9,
scale_score=False,
all_terms_must_match=False,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand Down
46 changes: 46 additions & 0 deletions integrations/opensearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,52 @@ def test_bm25_retrieval_pagination(self, document_store: OpenSearchDocumentStore
assert len(res) == 11
assert all("programming" in doc.content for doc in res)

def test_bm25_retrieval_all_terms_must_match(self, document_store: OpenSearchDocumentStore):
document_store.write_documents(
[
Document(content="Haskell is a functional programming language"),
Document(content="Lisp is a functional programming language"),
Document(content="Exilir is a functional programming language"),
Document(content="F# is a functional programming language"),
Document(content="C# is a functional programming language"),
Document(content="C++ is an object oriented programming language"),
Document(content="Dart is an object oriented programming language"),
Document(content="Go is an object oriented programming language"),
Document(content="Python is a object oriented programming language"),
Document(content="Ruby is a object oriented programming language"),
Document(content="PHP is a object oriented programming language"),
]
)

res = document_store._bm25_retrieval("functional Haskell", top_k=3, all_terms_must_match=True)
assert len(res) == 1
assert "Haskell is a functional programming language" in res[0].content

def test_bm25_retrieval_all_terms_must_match_false(self, document_store: OpenSearchDocumentStore):
document_store.write_documents(
[
Document(content="Haskell is a functional programming language"),
Document(content="Lisp is a functional programming language"),
Document(content="Exilir is a functional programming language"),
Document(content="F# is a functional programming language"),
Document(content="C# is a functional programming language"),
Document(content="C++ is an object oriented programming language"),
Document(content="Dart is an object oriented programming language"),
Document(content="Go is an object oriented programming language"),
Document(content="Python is a object oriented programming language"),
Document(content="Ruby is a object oriented programming language"),
Document(content="PHP is a object oriented programming language"),
]
)

res = document_store._bm25_retrieval("functional Haskell", top_k=10, all_terms_must_match=False)
assert len(res) == 5
assert "functional" in res[0].content
assert "functional" in res[1].content
assert "functional" in res[2].content
assert "functional" in res[3].content
assert "functional" in res[4].content

def test_bm25_retrieval_with_fuzziness(self, document_store: OpenSearchDocumentStore):
document_store.write_documents(
[
Expand Down
32 changes: 32 additions & 0 deletions integrations/opensearch/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,35 @@ def test_run():
assert len(res["documents"]) == 1
assert res["documents"][0].content == "Test doc"
assert res["documents"][0].embedding == [0.1, 0.2]


def test_run_init_params():
mock_store = Mock(spec=OpenSearchDocumentStore)
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11)
res = retriever.run(query_embedding=[0.5, 0.7])
mock_store._embedding_retrieval.assert_called_once_with(
query_embedding=[0.5, 0.7],
filters={"from": "init"},
top_k=11,
)
assert len(res) == 1
assert len(res["documents"]) == 1
assert res["documents"][0].content == "Test doc"
assert res["documents"][0].embedding == [0.1, 0.2]


def test_run_time_params():
mock_store = Mock(spec=OpenSearchDocumentStore)
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11)
res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9)
mock_store._embedding_retrieval.assert_called_once_with(
query_embedding=[0.5, 0.7],
filters={"from": "run"},
top_k=9,
)
assert len(res) == 1
assert len(res["documents"]) == 1
assert res["documents"][0].content == "Test doc"
assert res["documents"][0].embedding == [0.1, 0.2]