diff --git a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py
index ae43441a1..772ef5880 100644
--- a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py
+++ b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py
@@ -75,6 +75,7 @@ def __init__(
         self.local_files_only = local_files_only
         self.meta_fields_to_embed = meta_fields_to_embed or []
         self.meta_data_separator = meta_data_separator
+        self._model = None
 
     def to_dict(self) -> Dict[str, Any]:
         """
@@ -112,8 +113,8 @@ def warm_up(self):
         """
         Initializes the component.
         """
-        if not hasattr(self, "ranker"):
-            self.ranker = TextCrossEncoder(
+        if self._model is None:
+            self._model = TextCrossEncoder(
                 model_name=self.model_name,
                 cache_dir=self.cache_dir,
                 threads=self.threads,
@@ -170,14 +171,14 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None
             msg = f"top_k must be > 0, but got {top_k}"
             raise ValueError(msg)
 
-        if not hasattr(self, "ranker"):
+        if self._model is None:
             msg = "The ranker model has not been loaded. Please call warm_up() before running."
             raise RuntimeError(msg)
 
         fastembed_input_docs = self._prepare_fastembed_input_docs(documents)
 
         scores = list(
-            self.ranker.rerank(
+            self._model.rerank(
                 query=query,
                 documents=fastembed_input_docs,
                 batch_size=self.batch_size,
diff --git a/integrations/fastembed/tests/test_fastembed_ranker.py b/integrations/fastembed/tests/test_fastembed_ranker.py
index d2339c148..e38229c87 100644
--- a/integrations/fastembed/tests/test_fastembed_ranker.py
+++ b/integrations/fastembed/tests/test_fastembed_ranker.py
@@ -180,7 +180,7 @@ def test_run_incorrect_input_format(self):
         Test for checking incorrect input format.
         """
         ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2")
-        ranker.warm_up()
+        ranker._model = "mock_model"
 
         query = "query"
         string_input = "text"
@@ -222,7 +222,6 @@ def test_run_no_warmup(self):
 
         with pytest.raises(
             RuntimeError,
-            match=r"The ranker model has not been loaded. Please call warm_up\(\) before running.",
         ):
             ranker.run(query=query, documents=list_document)
 
@@ -231,7 +230,7 @@ def test_run_empty_document_list(self):
         """
         Test for no error when sending no documents.
         """
         ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2")
-        ranker.warm_up()
+        ranker._model = "mock_model"
         query = "query"
         list_document = []
@@ -247,13 +246,13 @@ def test_embed_metadata(self):
             model_name="model_name",
             meta_fields_to_embed=["meta_field"],
         )
-        ranker.ranker = MagicMock()
+        ranker._model = MagicMock()
 
         documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
         query = "test"
 
         ranker.run(query=query, documents=documents)
-        ranker.ranker.rerank.assert_called_once_with(
+        ranker._model.rerank.assert_called_once_with(
            query=query,
            documents=[
                "meta_value 0\ndocument-number 0",