From b557f3035e2f2abad4d3eb57fe80e72beeffc68c Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 23 Nov 2023 10:28:08 +0100 Subject: [PATCH] feat: Add `ConditionalRouter` Haystack 2.x component (#6147) Co-authored-by: Massimiliano Pippi Co-authored-by: Daria Fokina --- .../preview/components/routers/__init__.py | 4 +- .../components/routers/conditional_router.py | 347 ++++++++++++++++++ .../notes/add-router-f1f0cec79b1efe9a.yaml | 6 + .../routers/test_conditional_router.py | 324 ++++++++++++++++ 4 files changed, 679 insertions(+), 2 deletions(-) create mode 100644 haystack/preview/components/routers/conditional_router.py create mode 100644 releasenotes/notes/add-router-f1f0cec79b1efe9a.yaml create mode 100644 test/preview/components/routers/test_conditional_router.py diff --git a/haystack/preview/components/routers/__init__.py b/haystack/preview/components/routers/__init__.py index 32e9a422c5..2da95625fc 100644 --- a/haystack/preview/components/routers/__init__.py +++ b/haystack/preview/components/routers/__init__.py @@ -1,7 +1,7 @@ from haystack.preview.components.routers.document_joiner import DocumentJoiner from haystack.preview.components.routers.file_type_router import FileTypeRouter from haystack.preview.components.routers.metadata_router import MetadataRouter +from haystack.preview.components.routers.conditional_router import ConditionalRouter from haystack.preview.components.routers.text_language_router import TextLanguageRouter - -__all__ = ["DocumentJoiner", "FileTypeRouter", "MetadataRouter", "TextLanguageRouter"] +__all__ = ["DocumentJoiner", "FileTypeRouter", "MetadataRouter", "TextLanguageRouter", "ConditionalRouter"] diff --git a/haystack/preview/components/routers/conditional_router.py b/haystack/preview/components/routers/conditional_router.py new file mode 100644 index 0000000000..af96d96ff3 --- /dev/null +++ b/haystack/preview/components/routers/conditional_router.py @@ -0,0 +1,347 @@ +import importlib +import inspect +import logging +import sys +from typing import List, Dict, Any, Set, get_origin + +from jinja2 import meta, Environment, TemplateSyntaxError +from jinja2.nativetypes import NativeEnvironment + +from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError + +logger = logging.getLogger(__name__) + + +class NoRouteSelectedException(Exception): + """Exception raised when no route is selected in ConditionalRouter.""" + + +class RouteConditionException(Exception): + """Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter.""" + + +def serialize_type(target: Any) -> str: + """ + Serializes a type or an instance to its string representation, including the module name. + + This function handles types, instances of types, and special typing objects. + It assumes that non-typing objects will have a '__name__' attribute and raises + an error if a type cannot be serialized. + + :param target: The object to serialize, can be an instance or a type. + :type target: Any + :return: The string representation of the type. + :raises ValueError: If the type cannot be serialized. + """ + # If the target is a string and contains a dot, treat it as an already serialized type + if isinstance(target, str) and "." in target: + return target + + # Determine if the target is a type or an instance of a typing object + is_type_or_typing = isinstance(target, type) or bool(get_origin(target)) + type_obj = target if is_type_or_typing else type(target) + module = inspect.getmodule(type_obj) + type_obj_repr = repr(type_obj) + + if type_obj_repr.startswith("typing."): + # e.g., typing.List[int] -> List[int], we'll add the module below + type_name = type_obj_repr.split(".", 1)[1] + elif hasattr(type_obj, "__name__"): + type_name = type_obj.__name__ + else: + # If type cannot be serialized, raise an error + raise ValueError(f"Could not serialize type: {type_obj_repr}") + + # Construct the full path with module name if available + if module and hasattr(module, "__name__"): + if module.__name__ == "builtins": + # omit the module name for builtins, it just clutters the output + # e.g. instead of 'builtins.str', we'll just return 'str' + full_path = type_name + else: + full_path = f"{module.__name__}.{type_name}" + else: + full_path = type_name + + return full_path + + +def deserialize_type(type_str: str) -> Any: + """ + Deserializes a type given its full import path as a string, including nested generic types. + + This function will dynamically import the module if it's not already imported + and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'. + + :param type_str: The string representation of the type's full import path. + :return: The deserialized type object. + :raises DeserializationError: If the type cannot be deserialized due to missing module or type. + """ + + def parse_generic_args(args_str): + args = [] + bracket_count = 0 + current_arg = "" + + for char in args_str: + if char == "[": + bracket_count += 1 + elif char == "]": + bracket_count -= 1 + + if char == "," and bracket_count == 0: + args.append(current_arg.strip()) + current_arg = "" + else: + current_arg += char + + if current_arg: + args.append(current_arg.strip()) + + return args + + if "[" in type_str and type_str.endswith("]"): + # Handle generics + main_type_str, generics_str = type_str.split("[", 1) + generics_str = generics_str[:-1] + + main_type = deserialize_type(main_type_str) + generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str)) + + # Reconstruct + return main_type[generic_args] + + else: + # Handle non-generics + parts = type_str.split(".") + module_name = ".".join(parts[:-1]) or "builtins" + type_name = parts[-1] + + module = sys.modules.get(module_name) + if not module: + try: + module = importlib.import_module(module_name) + except ImportError as e: + raise DeserializationError(f"Could not import the module: {module_name}") from e + + deserialized_type = getattr(module, type_name, None) + if not deserialized_type: + raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}") + + return deserialized_type + + +@component +class ConditionalRouter: + """ + ConditionalRouter in Haystack 2.x pipelines is designed to manage data routing based on specific conditions. + This is achieved by defining a list named 'routes'. Each element in this list is a dictionary representing a + single route. + + A route dictionary comprises four key elements: + - 'condition': A Jinja2 string expression that determines if the route is selected. + - 'output': A Jinja2 expression defining the route's output value. + - 'output_type': The type of the output data (e.g., str, List[int]). + - 'output_name': The name under which the `output` value of the route is published. This name is used to connect + the router to other components in the pipeline. + + Here's an example: + + ```python + from haystack.preview.components.routers import ConditionalRouter + + routes = [ + { + "condition": "{{streams|length > 2}}", + "output": "{{streams}}", + "output_name": "enough_streams", + "output_type": List[int], + }, + { + "condition": "{{streams|length <= 2}}", + "output": "{{streams}}", + "output_name": "insufficient_streams", + "output_type": List[int], + }, + ] + router = ConditionalRouter(routes) + # When 'streams' has more than 2 items, 'enough_streams' output will activate, emitting the list [1, 2, 3] + kwargs = {"streams": [1, 2, 3], "query": "Haystack"} + result = router.run(**kwargs) + assert result == {"enough_streams": [1, 2, 3]} + ``` + + In this example, we configure two routes. The first route sends the 'streams' value to 'enough_streams' if the + stream count exceeds two. Conversely, the second route directs 'streams' to 'insufficient_streams' when there + are two or fewer streams. + + In the pipeline setup, the router is connected to other components using the output names. For example, the + 'enough_streams' output might be connected to another component that processes the streams, while the + 'insufficient_streams' output might be connected to a component that fetches more streams, and so on. + + Here is a pseudocode example of a pipeline that uses the ConditionalRouter and routes fetched ByteStreams to + different components depending on the number of streams fetched: + + ``` + from typing import List + from haystack import Pipeline + from haystack.preview.dataclasses import ByteStream + from haystack.preview.components.routers import ConditionalRouter + + routes = [ + { + "condition": "{{streams|length > 2}}", + "output": "{{streams}}", + "output_name": "enough_streams", + "output_type": List[ByteStream], + }, + { + "condition": "{{streams|length <= 2}}", + "output": "{{streams}}", + "output_name": "insufficient_streams", + "output_type": List[ByteStream], + }, + ] + + pipe = Pipeline() + pipe.add_component("router", router) + ... + pipe.connect("router.enough_streams", "some_component_a.streams") + pipe.connect("router.insufficient_streams", "some_component_b.streams_or_some_other_input") + ... + ``` + """ + + def __init__(self, routes: List[Dict]): + """ + Initializes the ConditionalRouter with a list of routes detailing the conditions for routing. + + :param routes: A list of dictionaries, each defining a route with a boolean condition expression + ('condition'), an output value ('output'), the output type ('output_type') and + ('output_name') that defines the output name for the variable defined in 'output'. + """ + self._validate_routes(routes) + self.routes: List[dict] = routes + + # Create a Jinja native environment to inspect variables in the condition templates + env = NativeEnvironment() + + # Inspect the routes to determine input and output types. + input_types: Set[str] = set() # let's just store the name, type will always be Any + output_types: Dict[str, str] = {} + + for route in routes: + # extract inputs + route_input_names = self._extract_variables(env, [route["output"], route["condition"]]) + input_types.update(route_input_names) + + # extract outputs + output_types.update({route["output_name"]: route["output_type"]}) + + component.set_input_types(self, **{var: Any for var in input_types}) + component.set_output_types(self, **output_types) + + def to_dict(self) -> Dict[str, Any]: + for route in self.routes: + # output_type needs to be serialized to a string + route["output_type"] = serialize_type(route["output_type"]) + + return default_to_dict(self, routes=self.routes) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter": + init_params = data.get("init_parameters", {}) + routes = init_params.get("routes") + for route in routes: + # output_type needs to be deserialized from a string to a type + route["output_type"] = deserialize_type(route["output_type"]) + return default_from_dict(cls, data) + + def run(self, **kwargs): + """ + Executes the routing logic by evaluating the specified boolean condition expressions + for each route in the order they are listed. The method directs the flow + of data to the output specified in the first route, whose expression + evaluates to True. If no route's expression evaluates to True, an exception + is raised. + + :param kwargs: A dictionary containing the pipeline variables, which should + include all variables used in the "condition" templates. + + :return: A dictionary containing the output and the corresponding result, + based on the first route whose expression evaluates to True. + + :raises NoRouteSelectedException: If no route's expression evaluates to True. + """ + # Create a Jinja native environment to evaluate the condition templates as Python expressions + env = NativeEnvironment() + + for route in self.routes: + try: + t = env.from_string(route["condition"]) + if t.render(**kwargs): + # We now evaluate the `output` expression to determine the route output + t_output = env.from_string(route["output"]) + output = t_output.render(**kwargs) + # and return the output as a dictionary under the output_name key + return {route["output_name"]: output} + except Exception as e: + raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e + + raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}") + + def _validate_routes(self, routes: List[Dict]): + """ + Validates a list of routes. + + :param routes: A list of routes. + :type routes: List[Dict] + """ + env = NativeEnvironment() + for route in routes: + try: + keys = set(route.keys()) + except AttributeError: + raise ValueError(f"Route must be a dictionary, got: {route}") + + mandatory_fields = {"condition", "output", "output_type", "output_name"} + has_all_mandatory_fields = mandatory_fields.issubset(keys) + if not has_all_mandatory_fields: + raise ValueError( + f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}" + ) + for field in ["condition", "output"]: + if not self._validate_template(env, route[field]): + raise ValueError(f"Invalid template for field '{field}': {route[field]}") + + def _extract_variables(self, env: NativeEnvironment, templates: List[str]) -> Set[str]: + """ + Extracts all variables from a list of Jinja template strings. + + :param env: A Jinja environment. + :type env: Environment + :param templates: A list of Jinja template strings. + :type templates: List[str] + :return: A set of variable names. + """ + variables = set() + for template in templates: + ast = env.parse(template) + variables.update(meta.find_undeclared_variables(ast)) + return variables + + def _validate_template(self, env: Environment, template_text: str): + """ + Validates a template string by parsing it with Jinja. + + :param env: A Jinja environment. + :type env: Environment + :param template_text: A Jinja template string. + :type template_text: str + :return: True if the template is valid, False otherwise. + """ + try: + env.parse(template_text) + return True + except TemplateSyntaxError: + return False diff --git a/releasenotes/notes/add-router-f1f0cec79b1efe9a.yaml b/releasenotes/notes/add-router-f1f0cec79b1efe9a.yaml new file mode 100644 index 0000000000..7d6084f508 --- /dev/null +++ b/releasenotes/notes/add-router-f1f0cec79b1efe9a.yaml @@ -0,0 +1,6 @@ +--- +preview: + - | + Add `ConditionalRouter` component to enhance the conditional pipeline routing capabilities. + The `ConditionalRouter` component orchestrates the flow of data by evaluating specified route conditions + to determine the appropriate route among a set of provided route alternatives. diff --git a/test/preview/components/routers/test_conditional_router.py b/test/preview/components/routers/test_conditional_router.py new file mode 100644 index 0000000000..501fb530d1 --- /dev/null +++ b/test/preview/components/routers/test_conditional_router.py @@ -0,0 +1,324 @@ +import copy +import typing +from typing import List, Dict +from unittest import mock + +import pytest + +from haystack.preview.components.routers import ConditionalRouter +from haystack.preview.components.routers.conditional_router import ( + NoRouteSelectedException, + serialize_type, + deserialize_type, +) +from haystack.preview.dataclasses import ChatMessage + + +class TestRouter: + @pytest.fixture + def routes(self): + return [ + {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"}, + { + "condition": "{{streams|length >= 2}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + ] + + @pytest.fixture + def router(self, routes): + return ConditionalRouter(routes) + + def test_missing_mandatory_fields(self): + """ + Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys + """ + routes = [ + {"condition": "{{streams|length < 2}}", "output": "{{query}}"}, + {"condition": "{{streams|length < 2}}", "output_type": str}, + ] + with pytest.raises(ValueError): + ConditionalRouter(routes) + + def test_invalid_condition_field(self): + """ + ConditionalRouter init raises a ValueError if one of the routes contains invalid condition + """ + # invalid condition field + routes = [{"condition": "{{streams|length < 2", "output": "query", "output_type": str, "output_name": "test"}] + with pytest.raises(ValueError, match="Invalid template"): + ConditionalRouter(routes) + + def test_no_vars_in_output_route_but_with_output_name(self): + """ + Router can't accept a route with no variables used in the output field + """ + routes = [ + { + "condition": "{{streams|length > 2}}", + "output": "This is a constant", + "output_name": "enough_streams", + "output_type": str, + } + ] + router = ConditionalRouter(routes) + kwargs = {"streams": [1, 2, 3], "query": "Haystack"} + result = router.run(**kwargs) + assert result == {"enough_streams": "This is a constant"} + + def test_mandatory_and_optional_fields_with_extra_fields(self): + """ + Router accepts a list of routes with mandatory and optional fields but not if some new field is added + """ + + routes = [ + { + "condition": "{{streams|length < 2}}", + "output": "{{query}}", + "output_type": str, + "output_name": "test", + "bla": "bla", + }, + {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str}, + ] + + with pytest.raises(ValueError): + ConditionalRouter(routes) + + def test_router_initialized(self, routes): + router = ConditionalRouter(routes) + + assert router.routes == routes + assert set(router.__canals_input__.keys()) == {"query", "streams"} + assert set(router.__canals_output__.keys()) == {"query", "streams"} + + def test_router_evaluate_condition_expressions(self, router): + # first route should be selected + kwargs = {"streams": [1, 2, 3], "query": "test"} + result = router.run(**kwargs) + assert result == {"streams": [1, 2, 3]} + + # second route should be selected + kwargs = {"streams": [1], "query": "test"} + result = router.run(**kwargs) + assert result == {"query": "test"} + + def test_router_evaluate_condition_expressions_using_output_slot(self): + routes = [ + { + "condition": "{{streams|length > 2}}", + "output": "{{streams}}", + "output_name": "enough_streams", + "output_type": List[int], + }, + { + "condition": "{{streams|length <= 2}}", + "output": "{{streams}}", + "output_name": "insufficient_streams", + "output_type": List[int], + }, + ] + router = ConditionalRouter(routes) + # enough_streams output slot will be selected with [1, 2, 3] list being outputted + kwargs = {"streams": [1, 2, 3], "query": "Haystack"} + result = router.run(**kwargs) + assert result == {"enough_streams": [1, 2, 3]} + + def test_complex_condition(self): + routes = [ + { + "condition": "{{messages[-1].metadata.finish_reason == 'function_call'}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + { + "condition": "{{True}}", + "output": "{{query}}", + "output_type": str, + "output_name": "query", + }, # catch-all condition + ] + router = ConditionalRouter(routes) + message = mock.MagicMock() + message.metadata.finish_reason = "function_call" + result = router.run(messages=[message], streams=[1, 2, 3], query="my query") + assert result == {"streams": [1, 2, 3]} + + def test_router_no_route(self, router): + # should raise an exception + router = ConditionalRouter( + [ + { + "condition": "{{streams|length < 2}}", + "output": "{{query}}", + "output_type": str, + "output_name": "query", + }, + { + "condition": "{{streams|length >= 5}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + ] + ) + + kwargs = {"streams": [1, 2, 3], "query": "test"} + with pytest.raises(NoRouteSelectedException): + router.run(**kwargs) + + def test_router_raises_value_error_if_route_not_dictionary(self): + """ + Router raises a ValueError if each route is not a dictionary + """ + routes = [ + {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"}, + ["{{streams|length >= 2}}", "streams", List[int]], + ] + + with pytest.raises(ValueError): + ConditionalRouter(routes) + + def test_router_raises_value_error_if_route_missing_keys(self): + """ + Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys + """ + routes = [ + {"condition": "{{streams|length < 2}}", "output": "{{query}}"}, + {"condition": "{{streams|length < 2}}", "output_type": str}, + ] + + with pytest.raises(ValueError): + ConditionalRouter(routes) + + def test_output_type_serialization(self): + assert serialize_type(str) == "str" + assert serialize_type(List[int]) == "typing.List[int]" + assert serialize_type(List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]" + assert serialize_type(ChatMessage) == "haystack.preview.dataclasses.chat_message.ChatMessage" + assert serialize_type(typing.List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]" + assert serialize_type(List[ChatMessage]) == "typing.List[haystack.preview.dataclasses.chat_message.ChatMessage]" + assert ( + serialize_type(typing.Dict[int, ChatMessage]) + == "typing.Dict[int, haystack.preview.dataclasses.chat_message.ChatMessage]" + ) + assert serialize_type(int) == "int" + assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.preview.dataclasses.chat_message.ChatMessage" + + def test_output_type_deserialization(self): + assert deserialize_type("str") == str + assert deserialize_type("typing.List[int]") == typing.List[int] + assert deserialize_type("typing.List[typing.Dict[str, int]]") == typing.List[Dict[str, int]] + assert deserialize_type("typing.Dict[str, int]") == Dict[str, int] + assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]] + assert deserialize_type("typing.List[typing.Dict[str, typing.List[int]]]") == List[Dict[str, List[int]]] + assert ( + deserialize_type("typing.List[haystack.preview.dataclasses.chat_message.ChatMessage]") + == typing.List[ChatMessage] + ) + assert ( + deserialize_type("typing.Dict[int, haystack.preview.dataclasses.chat_message.ChatMessage]") + == typing.Dict[int, ChatMessage] + ) + assert deserialize_type("haystack.preview.dataclasses.chat_message.ChatMessage") == ChatMessage + assert deserialize_type("int") == int + + def test_router_de_serialization(self): + routes = [ + {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"}, + { + "condition": "{{streams|length >= 2}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + ] + router = ConditionalRouter(routes) + router_dict = router.to_dict() + + # assert that the router dict is correct, with all keys and values being strings + for route in router_dict["init_parameters"]["routes"]: + for key in route.keys(): + assert isinstance(key, str) + assert isinstance(route[key], str) + + new_router = ConditionalRouter.from_dict(router_dict) + assert router.routes == new_router.routes + + # now use both routers with the same input + kwargs = {"streams": [1, 2, 3], "query": "Haystack"} + result1 = router.run(**kwargs) + result2 = new_router.run(**kwargs) + + # check that the result is the same and correct + assert result1 == result2 and result1 == {"streams": [1, 2, 3]} + + def test_router_de_serialization_user_type(self): + routes = [ + { + "condition": "{{streams|length < 2}}", + "output": "{{message}}", + "output_type": ChatMessage, + "output_name": "message", + }, + { + "condition": "{{streams|length >= 2}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + ] + router = ConditionalRouter(routes) + router_dict = router.to_dict() + + # assert that the router dict is correct, with all keys and values being strings + for route in router_dict["init_parameters"]["routes"]: + for key in route.keys(): + assert isinstance(key, str) + assert isinstance(route[key], str) + + # check that the output_type is a string and a proper class name + assert ( + router_dict["init_parameters"]["routes"][0]["output_type"] + == "haystack.preview.dataclasses.chat_message.ChatMessage" + ) + + # deserialize the router + new_router = ConditionalRouter.from_dict(router_dict) + + # check that the output_type is the right class + assert new_router.routes[0]["output_type"] == ChatMessage + assert router.routes == new_router.routes + + # now use both routers to run the same message + message = ChatMessage.from_user("ciao") + kwargs = {"streams": [1], "message": message} + result1 = router.run(**kwargs) + result2 = new_router.run(**kwargs) + + # check that the result is the same and correct + assert result1 == result2 and result1["message"].content == message.content + + def test_router_serialization_idempotence(self): + routes = [ + { + "condition": "{{streams|length < 2}}", + "output": "{{message}}", + "output_type": ChatMessage, + "output_name": "message", + }, + { + "condition": "{{streams|length >= 2}}", + "output": "{{streams}}", + "output_type": List[int], + "output_name": "streams", + }, + ] + router = ConditionalRouter(routes) + # invoke to_dict twice and check that the result is the same + router_dict_first_invocation = copy.deepcopy(router.to_dict()) + router_dict_second_invocation = router.to_dict() + assert router_dict_first_invocation == router_dict_second_invocation