Skip to content

Commit

Permalink
simplify tests - use mocks
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Nov 13, 2024
1 parent ad44c12 commit 225f436
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
self.local_files_only = local_files_only
self.meta_fields_to_embed = meta_fields_to_embed or []
self.meta_data_separator = meta_data_separator
self._model = None

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -112,8 +113,8 @@ def warm_up(self):
"""
Initializes the component.
"""
if not hasattr(self, "ranker"):
self.ranker = TextCrossEncoder(
if self._model is None:
self._model = TextCrossEncoder(
model_name=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
Expand Down Expand Up @@ -170,14 +171,14 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None
msg = f"top_k must be > 0, but got {top_k}"
raise ValueError(msg)

if not hasattr(self, "ranker"):
if self._model is None:
msg = "The ranker model has not been loaded. Please call warm_up() before running."
raise RuntimeError(msg)

fastembed_input_docs = self._prepare_fastembed_input_docs(documents)

scores = list(
self.ranker.rerank(
self._model.rerank(
query=query,
documents=fastembed_input_docs,
batch_size=self.batch_size,
Expand Down
9 changes: 4 additions & 5 deletions integrations/fastembed/tests/test_fastembed_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_run_incorrect_input_format(self):
Test for checking incorrect input format.
"""
ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2")
ranker.warm_up()
ranker._model = "mock_model"

query = "query"
string_input = "text"
Expand Down Expand Up @@ -222,7 +222,6 @@ def test_run_no_warmup(self):

with pytest.raises(
RuntimeError,
match=r"The ranker model has not been loaded. Please call warm_up\(\) before running.",
):
ranker.run(query=query, documents=list_document)

Expand All @@ -231,7 +230,7 @@ def test_run_empty_document_list(self):
Test for no error when sending no documents.
"""
ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2")
ranker.warm_up()
ranker._model = "mock_model"

query = "query"
list_document = []
Expand All @@ -247,13 +246,13 @@ def test_embed_metadata(self):
model_name="model_name",
meta_fields_to_embed=["meta_field"],
)
ranker.ranker = MagicMock()
ranker._model = MagicMock()

documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
query = "test"
ranker.run(query=query, documents=documents)

ranker.ranker.rerank.assert_called_once_with(
ranker._model.rerank.assert_called_once_with(
query=query,
documents=[
"meta_value 0\ndocument-number 0",
Expand Down

0 comments on commit 225f436

Please sign in to comment.