Skip to content

Commit

Permalink
feat: Extend core component machinery to support an async run method …
Browse files Browse the repository at this point in the history
…(experimental) (#8279)

* feat: Extend core component machinery to support an async run method

* Add reno

* Fix incorrect docstring

* Make `async_run` a coroutine

* Make `supports_async` a dunder field
  • Loading branch information
shadeMe authored Aug 27, 2024
1 parent 1fa30d4 commit f0b45c8
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 31 deletions.
111 changes: 80 additions & 31 deletions haystack/core/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Optional, Protocol, runtime_checkable
from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable

from haystack import logging
from haystack.core.errors import ComponentError
Expand Down Expand Up @@ -166,7 +166,7 @@ def run(self, **kwargs):

class ComponentMeta(type):
@staticmethod
def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
"""
Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
"""
Expand All @@ -184,6 +184,66 @@ def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
out[name] = arg
return out

@staticmethod
def _parse_and_set_output_sockets(instance: Any):
has_async_run = hasattr(instance, "async_run")

# If `component.set_output_types()` was called in the component constructor,
# `__haystack_output__` is already populated, no need to do anything.
if not hasattr(instance, "__haystack_output__"):
# If that's not the case, we need to populate `__haystack_output__`
#
# If either of the run methods were decorated, they'll have a field assigned that
# stores the output specification. If both run methods were decorated, we ensure that
# outputs are the same. We deepcopy the content of the cache to transfer ownership from
# the class method to the actual instance, so that different instances of the same class
# won't share this data.

run_output_types = getattr(instance.run, "_output_types_cache", {})
async_run_output_types = getattr(instance.async_run, "_output_types_cache", {}) if has_async_run else {}

if has_async_run and run_output_types != async_run_output_types:
raise ComponentError("Output type specifications of 'run' and 'async_run' methods must be the same")
output_types_cache = run_output_types

instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket)

@staticmethod
def _parse_and_set_input_sockets(component_cls: Type, instance: Any):
def inner(method, sockets):
run_signature = inspect.signature(method)

# First is 'self' and it doesn't matter.
for param in list(run_signature.parameters)[1:]:
if run_signature.parameters[param].kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
): # ignore variable args
socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation}
if run_signature.parameters[param].default != inspect.Parameter.empty:
socket_kwargs["default_value"] = run_signature.parameters[param].default
sockets[param] = InputSocket(**socket_kwargs)

# Create the sockets if set_input_types() wasn't called in the constructor.
# If it was called and there are some parameters also in the `run()` method, these take precedence.
if not hasattr(instance, "__haystack_input__"):
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)

inner(getattr(component_cls, "run"), instance.__haystack_input__)

# Ensure that the sockets are the same for the async method, if it exists.
async_run = getattr(component_cls, "async_run", None)
if async_run is not None:
run_sockets = Sockets(instance, {}, InputSocket)
async_run_sockets = Sockets(instance, {}, InputSocket)

# Can't use the sockets from above as they might contain
# values set with set_input_types().
inner(getattr(component_cls, "run"), run_sockets)
inner(async_run, async_run_sockets)
if async_run_sockets != run_sockets:
raise ComponentError("Parameters of 'run' and 'async_run' methods must be the same")

def __call__(cls, *args, **kwargs):
"""
This method is called when clients instantiate a Component and runs before __new__ and __init__.
Expand All @@ -195,7 +255,7 @@ def __call__(cls, *args, **kwargs):
else:
try:
pre_init_hook.in_progress = True
named_positional_args = ComponentMeta.positional_to_kwargs(cls, args)
named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
assert (
set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
), "positional and keyword arguments overlap"
Expand All @@ -207,34 +267,13 @@ def __call__(cls, *args, **kwargs):

# Before returning, we have the chance to modify the newly created
# Component instance, so we take the chance and set up the I/O sockets
has_async_run = hasattr(instance, "async_run")
if has_async_run and not inspect.iscoroutinefunction(instance.async_run):
raise ComponentError(f"Method 'async_run' of component '{cls.__name__}' must be a coroutine")
instance.__haystack_supports_async__ = has_async_run

# If `component.set_output_types()` was called in the component constructor,
# `__haystack_output__` is already populated, no need to do anything.
if not hasattr(instance, "__haystack_output__"):
# If that's not the case, we need to populate `__haystack_output__`
#
# If the `run` method was decorated, it has a `_output_types_cache` field assigned
# that stores the output specification.
# We deepcopy the content of the cache to transfer ownership from the class method
# to the actual instance, so that different instances of the same class won't share this data.
instance.__haystack_output__ = Sockets(
instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket
)

# Create the sockets if set_input_types() wasn't called in the constructor.
# If it was called and there are some parameters also in the `run()` method, these take precedence.
if not hasattr(instance, "__haystack_input__"):
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
run_signature = inspect.signature(getattr(cls, "run"))
for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter.
if run_signature.parameters[param].kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
): # ignore variable args
socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation}
if run_signature.parameters[param].default != inspect.Parameter.empty:
socket_kwargs["default_value"] = run_signature.parameters[param].default
instance.__haystack_input__[param] = InputSocket(**socket_kwargs)
ComponentMeta._parse_and_set_input_sockets(cls, instance)
ComponentMeta._parse_and_set_output_sockets(instance)

# Since a Component can't be used in multiple Pipelines at the same time
# we need to know if it's already owned by a Pipeline when adding it to one.
Expand Down Expand Up @@ -290,7 +329,13 @@ class _Component:
def __init__(self):
self.registry = {}

def set_input_type(self, instance, name: str, type: Any, default: Any = _empty): # noqa: A002
def set_input_type(
self,
instance,
name: str,
type: Any, # noqa: A002
default: Any = _empty,
):
"""
Add a single input socket to the component instance.
Expand Down Expand Up @@ -395,6 +440,10 @@ class available here, we temporarily store the output types as an attribute of
the decorated method. The ComponentMeta metaclass will use this data to create
sockets at instance creation time.
"""
method_name = run_method.__name__
if method_name not in ("run", "async_run"):
raise ComponentError("'output_types' decorator can only be used on 'run' and `async_run` methods")

setattr(
run_method,
"_output_types_cache",
Expand Down
10 changes: 10 additions & 0 deletions haystack/core/component/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ def __init__(
self._sockets_dict = sockets_dict
self.__dict__.update(sockets_dict)

def __eq__(self, value: object) -> bool:
if not isinstance(value, Sockets):
return False

return (
self._sockets_io_type == value._sockets_io_type
and self._component == value._component
and self._sockets_dict == value._sockets_dict
)

def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]):
"""
Adds a new socket to this Sockets object.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Extend core component machinery to support an optional asynchronous `async_run` method in components.
If it's present, it should have the same parameters (and output types) as the run method and must be
implemented as a coroutine.
117 changes: 117 additions & 0 deletions test/core/component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ def run(self, input_value: int):
# Verifies also instantiation works with no issues
assert MockComponent()
assert component.registry["test_component.MockComponent"] == MockComponent
assert isinstance(MockComponent(), Component)
assert MockComponent().__haystack_supports_async__ is False


def test_correct_declaration_with_async():
@component
class MockComponent:
def to_dict(self):
return {}

@classmethod
def from_dict(cls, data):
return cls()

@component.output_types(output_value=int)
def run(self, input_value: int):
return {"output_value": input_value}

@component.output_types(output_value=int)
async def async_run(self, input_value: int):
return {"output_value": input_value}

# Verifies also instantiation works with no issues
assert MockComponent()
assert component.registry["test_component.MockComponent"] == MockComponent
assert isinstance(MockComponent(), Component)
assert MockComponent().__haystack_supports_async__ is True


def test_correct_declaration_with_additional_readonly_property():
Expand Down Expand Up @@ -95,6 +122,50 @@ def another_method(self, input_value: int):
return {"output_value": input_value}


def test_async_run_not_async():
@component
class MockComponent:
@component.output_types(value=int)
def run(self, value: int):
return {"value": 1}

@component.output_types(value=int)
def async_run(self, value: int):
return {"value": 1}

with pytest.raises(ComponentError):
comp = MockComponent()


def test_async_run_not_coroutine():
@component
class MockComponent:
@component.output_types(value=int)
def run(self, value: int):
return {"value": 1}

@component.output_types(value=int)
async def async_run(self, value: int):
yield {"value": 1}

with pytest.raises(ComponentError):
comp = MockComponent()


def test_parameters_mismatch_run_and_async_run():
@component
class MockComponent:
@component.output_types(value=int)
def run(self, value: int):
return {"value": 1}

async def async_run(self, value: str):
yield {"value": "1"}

with pytest.raises(ComponentError):
comp = MockComponent()


def test_set_input_types():
@component
class MockComponent:
Expand Down Expand Up @@ -155,6 +226,52 @@ def from_dict(cls, data):
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}


def test_output_types_decorator_wrong_method():
with pytest.raises(ComponentError):

@component
class MockComponent:
def run(self, value: int):
return {"value": 1}

@component.output_types(value=int)
def to_dict(self):
return {}

@classmethod
def from_dict(cls, data):
return cls()


def test_output_types_decorator_mismatch_run_async_run():
@component
class MockComponent:
@component.output_types(value=int)
def run(self, value: int):
return {"value": 1}

@component.output_types(value=str)
async def async_run(self, value: int):
return {"value": "1"}

with pytest.raises(ComponentError):
comp = MockComponent()


def test_output_types_decorator_missing_async_run():
@component
class MockComponent:
@component.output_types(value=int)
def run(self, value: int):
return {"value": 1}

async def async_run(self, value: int):
return {"value": "1"}

with pytest.raises(ComponentError):
comp = MockComponent()


def test_component_decorator_set_it_as_component():
@component
class MockComponent:
Expand Down

0 comments on commit f0b45c8

Please sign in to comment.