diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index 12e3643c1e..e5036554ff 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -77,7 +77,7 @@ from copy import deepcopy from dataclasses import dataclass from types import new_class -from typing import Any, Dict, Optional, Protocol, runtime_checkable +from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable from haystack import logging from haystack.core.errors import ComponentError @@ -166,7 +166,7 @@ def run(self, **kwargs): class ComponentMeta(type): @staticmethod - def positional_to_kwargs(cls_type, args) -> Dict[str, Any]: + def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]: """ Convert positional arguments to keyword arguments based on the signature of the `__init__` method. """ @@ -184,6 +184,66 @@ def positional_to_kwargs(cls_type, args) -> Dict[str, Any]: out[name] = arg return out + @staticmethod + def _parse_and_set_output_sockets(instance: Any): + has_async_run = hasattr(instance, "async_run") + + # If `component.set_output_types()` was called in the component constructor, + # `__haystack_output__` is already populated, no need to do anything. + if not hasattr(instance, "__haystack_output__"): + # If that's not the case, we need to populate `__haystack_output__` + # + # If either of the run methods were decorated, they'll have a field assigned that + # stores the output specification. If both run methods were decorated, we ensure that + # outputs are the same. We deepcopy the content of the cache to transfer ownership from + # the class method to the actual instance, so that different instances of the same class + # won't share this data. + + run_output_types = getattr(instance.run, "_output_types_cache", {}) + async_run_output_types = getattr(instance.async_run, "_output_types_cache", {}) if has_async_run else {} + + if has_async_run and run_output_types != async_run_output_types: + raise ComponentError("Output type specifications of 'run' and 'async_run' methods must be the same") + output_types_cache = run_output_types + + instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket) + + @staticmethod + def _parse_and_set_input_sockets(component_cls: Type, instance: Any): + def inner(method, sockets): + run_signature = inspect.signature(method) + + # First is 'self' and it doesn't matter. + for param in list(run_signature.parameters)[1:]: + if run_signature.parameters[param].kind not in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): # ignore variable args + socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation} + if run_signature.parameters[param].default != inspect.Parameter.empty: + socket_kwargs["default_value"] = run_signature.parameters[param].default + sockets[param] = InputSocket(**socket_kwargs) + + # Create the sockets if set_input_types() wasn't called in the constructor. + # If it was called and there are some parameters also in the `run()` method, these take precedence. + if not hasattr(instance, "__haystack_input__"): + instance.__haystack_input__ = Sockets(instance, {}, InputSocket) + + inner(getattr(component_cls, "run"), instance.__haystack_input__) + + # Ensure that the sockets are the same for the async method, if it exists. + async_run = getattr(component_cls, "async_run", None) + if async_run is not None: + run_sockets = Sockets(instance, {}, InputSocket) + async_run_sockets = Sockets(instance, {}, InputSocket) + + # Can't use the sockets from above as they might contain + # values set with set_input_types(). + inner(getattr(component_cls, "run"), run_sockets) + inner(async_run, async_run_sockets) + if async_run_sockets != run_sockets: + raise ComponentError("Parameters of 'run' and 'async_run' methods must be the same") + def __call__(cls, *args, **kwargs): """ This method is called when clients instantiate a Component and runs before __new__ and __init__. @@ -195,7 +255,7 @@ def __call__(cls, *args, **kwargs): else: try: pre_init_hook.in_progress = True - named_positional_args = ComponentMeta.positional_to_kwargs(cls, args) + named_positional_args = ComponentMeta._positional_to_kwargs(cls, args) assert ( set(named_positional_args.keys()).intersection(kwargs.keys()) == set() ), "positional and keyword arguments overlap" @@ -207,34 +267,13 @@ def __call__(cls, *args, **kwargs): # Before returning, we have the chance to modify the newly created # Component instance, so we take the chance and set up the I/O sockets + has_async_run = hasattr(instance, "async_run") + if has_async_run and not inspect.iscoroutinefunction(instance.async_run): + raise ComponentError(f"Method 'async_run' of component '{cls.__name__}' must be a coroutine") + instance.__haystack_supports_async__ = has_async_run - # If `component.set_output_types()` was called in the component constructor, - # `__haystack_output__` is already populated, no need to do anything. - if not hasattr(instance, "__haystack_output__"): - # If that's not the case, we need to populate `__haystack_output__` - # - # If the `run` method was decorated, it has a `_output_types_cache` field assigned - # that stores the output specification. - # We deepcopy the content of the cache to transfer ownership from the class method - # to the actual instance, so that different instances of the same class won't share this data. - instance.__haystack_output__ = Sockets( - instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket - ) - - # Create the sockets if set_input_types() wasn't called in the constructor. - # If it was called and there are some parameters also in the `run()` method, these take precedence. - if not hasattr(instance, "__haystack_input__"): - instance.__haystack_input__ = Sockets(instance, {}, InputSocket) - run_signature = inspect.signature(getattr(cls, "run")) - for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter. - if run_signature.parameters[param].kind not in ( - inspect.Parameter.VAR_POSITIONAL, - inspect.Parameter.VAR_KEYWORD, - ): # ignore variable args - socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation} - if run_signature.parameters[param].default != inspect.Parameter.empty: - socket_kwargs["default_value"] = run_signature.parameters[param].default - instance.__haystack_input__[param] = InputSocket(**socket_kwargs) + ComponentMeta._parse_and_set_input_sockets(cls, instance) + ComponentMeta._parse_and_set_output_sockets(instance) # Since a Component can't be used in multiple Pipelines at the same time # we need to know if it's already owned by a Pipeline when adding it to one. @@ -290,7 +329,13 @@ class _Component: def __init__(self): self.registry = {} - def set_input_type(self, instance, name: str, type: Any, default: Any = _empty): # noqa: A002 + def set_input_type( + self, + instance, + name: str, + type: Any, # noqa: A002 + default: Any = _empty, + ): """ Add a single input socket to the component instance. @@ -395,6 +440,10 @@ class available here, we temporarily store the output types as an attribute of the decorated method. The ComponentMeta metaclass will use this data to create sockets at instance creation time. """ + method_name = run_method.__name__ + if method_name not in ("run", "async_run"): + raise ComponentError("'output_types' decorator can only be used on 'run' and `async_run` methods") + setattr( run_method, "_output_types_cache", diff --git a/haystack/core/component/sockets.py b/haystack/core/component/sockets.py index 19a3f798f7..22289ae9f8 100644 --- a/haystack/core/component/sockets.py +++ b/haystack/core/component/sockets.py @@ -80,6 +80,16 @@ def __init__( self._sockets_dict = sockets_dict self.__dict__.update(sockets_dict) + def __eq__(self, value: object) -> bool: + if not isinstance(value, Sockets): + return False + + return ( + self._sockets_io_type == value._sockets_io_type + and self._component == value._component + and self._sockets_dict == value._sockets_dict + ) + def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]): """ Adds a new socket to this Sockets object. diff --git a/releasenotes/notes/async-component-support-machinery-6ea4496241aeb3b2.yaml b/releasenotes/notes/async-component-support-machinery-6ea4496241aeb3b2.yaml new file mode 100644 index 0000000000..bc4b8f55d0 --- /dev/null +++ b/releasenotes/notes/async-component-support-machinery-6ea4496241aeb3b2.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Extend core component machinery to support an optional asynchronous `async_run` method in components. + If it's present, it should have the same parameters (and output types) as the run method and must be + implemented as a coroutine. diff --git a/test/core/component/test_component.py b/test/core/component/test_component.py index 42d2cdb1a0..c81fbf8ae1 100644 --- a/test/core/component/test_component.py +++ b/test/core/component/test_component.py @@ -31,6 +31,33 @@ def run(self, input_value: int): # Verifies also instantiation works with no issues assert MockComponent() assert component.registry["test_component.MockComponent"] == MockComponent + assert isinstance(MockComponent(), Component) + assert MockComponent().__haystack_supports_async__ is False + + +def test_correct_declaration_with_async(): + @component + class MockComponent: + def to_dict(self): + return {} + + @classmethod + def from_dict(cls, data): + return cls() + + @component.output_types(output_value=int) + def run(self, input_value: int): + return {"output_value": input_value} + + @component.output_types(output_value=int) + async def async_run(self, input_value: int): + return {"output_value": input_value} + + # Verifies also instantiation works with no issues + assert MockComponent() + assert component.registry["test_component.MockComponent"] == MockComponent + assert isinstance(MockComponent(), Component) + assert MockComponent().__haystack_supports_async__ is True def test_correct_declaration_with_additional_readonly_property(): @@ -95,6 +122,50 @@ def another_method(self, input_value: int): return {"output_value": input_value} +def test_async_run_not_async(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + @component.output_types(value=int) + def async_run(self, value: int): + return {"value": 1} + + with pytest.raises(ComponentError): + comp = MockComponent() + + +def test_async_run_not_coroutine(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + @component.output_types(value=int) + async def async_run(self, value: int): + yield {"value": 1} + + with pytest.raises(ComponentError): + comp = MockComponent() + + +def test_parameters_mismatch_run_and_async_run(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + async def async_run(self, value: str): + yield {"value": "1"} + + with pytest.raises(ComponentError): + comp = MockComponent() + + def test_set_input_types(): @component class MockComponent: @@ -155,6 +226,52 @@ def from_dict(cls, data): assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)} +def test_output_types_decorator_wrong_method(): + with pytest.raises(ComponentError): + + @component + class MockComponent: + def run(self, value: int): + return {"value": 1} + + @component.output_types(value=int) + def to_dict(self): + return {} + + @classmethod + def from_dict(cls, data): + return cls() + + +def test_output_types_decorator_mismatch_run_async_run(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + @component.output_types(value=str) + async def async_run(self, value: int): + return {"value": "1"} + + with pytest.raises(ComponentError): + comp = MockComponent() + + +def test_output_types_decorator_missing_async_run(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + async def async_run(self, value: int): + return {"value": "1"} + + with pytest.raises(ComponentError): + comp = MockComponent() + + def test_component_decorator_set_it_as_component(): @component class MockComponent: