From b22014b915d5022d601ea0bbfc33b03d2a206d5f Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Wed, 18 Sep 2024 13:05:31 +0200 Subject: [PATCH] fix: Prevent `set_output_types` from being called when the `output_types` decorator is used (#8376) --- .../classifiers/zero_shot_document_classifier.py | 3 +-- haystack/core/component/component.py | 7 +++++++ ...-set-output-type-override-852a19b3f0621fb0.yaml | 4 ++++ test/core/component/test_component.py | 14 ++++++++++++++ 4 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 releasenotes/notes/component-set-output-type-override-852a19b3f0621fb0.yaml diff --git a/haystack/components/classifiers/zero_shot_document_classifier.py b/haystack/components/classifiers/zero_shot_document_classifier.py index cff245b35b..5aa52fde80 100644 --- a/haystack/components/classifiers/zero_shot_document_classifier.py +++ b/haystack/components/classifiers/zero_shot_document_classifier.py @@ -121,7 +121,6 @@ def __init__( self.token = token self.labels = labels self.multi_label = multi_label - component.set_output_types(self, **{label: str for label in labels}) huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs( huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {}, @@ -229,7 +228,7 @@ def run(self, documents: List[Document], batch_size: int = 1): ) texts = [ - doc.content if self.classification_field is None else doc.meta[self.classification_field] + (doc.content if self.classification_field is None else doc.meta[self.classification_field]) for doc in documents ] diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index e34efdb28e..72c444fdd6 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -449,6 +449,13 @@ def run(self, value: int): return {"output_1": 1, "output_2": "2"} ``` """ + has_decorator = hasattr(instance.run, "_output_types_cache") + if has_decorator: + raise ComponentError( + "Cannot call `set_output_types` on a component that already has " + "the 'output_types' decorator on its `run` method" + ) + instance.__haystack_output__ = Sockets( instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket ) diff --git a/releasenotes/notes/component-set-output-type-override-852a19b3f0621fb0.yaml b/releasenotes/notes/component-set-output-type-override-852a19b3f0621fb0.yaml new file mode 100644 index 0000000000..2a06faddeb --- /dev/null +++ b/releasenotes/notes/component-set-output-type-override-852a19b3f0621fb0.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Prevent `set_output_types`` from being called when the `output_types`` decorator is used. diff --git a/test/core/component/test_component.py b/test/core/component/test_component.py index 320e2d65ca..8b4266dbb3 100644 --- a/test/core/component/test_component.py +++ b/test/core/component/test_component.py @@ -315,6 +315,20 @@ def from_dict(cls, data): return cls() +def test_output_types_decorator_and_set_output_types(): + @component + class MockComponent: + def __init__(self) -> None: + component.set_output_types(self, value=int) + + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + with pytest.raises(ComponentError, match="Cannot call `set_output_types`"): + comp = MockComponent() + + def test_output_types_decorator_mismatch_run_async_run(): @component class MockComponent: