Skip to content

Commit

Permalink
fix: Prevent set_output_types from being called when the `output_ty…
Browse files Browse the repository at this point in the history
…pes` decorator is used (#8376)
  • Loading branch information
shadeMe authored Sep 18, 2024
1 parent 117c298 commit b22014b
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def __init__(
self.token = token
self.labels = labels
self.multi_label = multi_label
component.set_output_types(self, **{label: str for label in labels})

huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
Expand Down Expand Up @@ -229,7 +228,7 @@ def run(self, documents: List[Document], batch_size: int = 1):
)

texts = [
doc.content if self.classification_field is None else doc.meta[self.classification_field]
(doc.content if self.classification_field is None else doc.meta[self.classification_field])
for doc in documents
]

Expand Down
7 changes: 7 additions & 0 deletions haystack/core/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,13 @@ def run(self, value: int):
return {"output_1": 1, "output_2": "2"}
```
"""
has_decorator = hasattr(instance.run, "_output_types_cache")
if has_decorator:
raise ComponentError(
"Cannot call `set_output_types` on a component that already has "
"the 'output_types' decorator on its `run` method"
)

instance.__haystack_output__ = Sockets(
instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Prevent `set_output_types`` from being called when the `output_types`` decorator is used.
14 changes: 14 additions & 0 deletions test/core/component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,20 @@ def from_dict(cls, data):
return cls()


def test_output_types_decorator_and_set_output_types():
@component
class MockComponent:
def __init__(self) -> None:
component.set_output_types(self, value=int)

@component.output_types(value=int)
def run(self, value: int):
return {"value": 1}

with pytest.raises(ComponentError, match="Cannot call `set_output_types`"):
comp = MockComponent()


def test_output_types_decorator_mismatch_run_async_run():
@component
class MockComponent:
Expand Down

0 comments on commit b22014b

Please sign in to comment.