From 820572439553ca2d2b51cb567f1f75e2527b67b9 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:43:16 +0100 Subject: [PATCH] feat: Rework `Pipeline.run()` to better handle cycles (#8431) * draft * Enhance * Almost works * Simplify some parts and handle intermediate outputs * Handle connections with default * Handle cycles with multiple connections from two components * Update distributed outputs at the correct time * Remove Component inputs after it runs * Add agent pipeline test case * Fix infite loop test * Handle some corner cases with loops checking and inputs deletion * Fix tests * Add new behavioral test * Remove unused code in behavioural test * Fix behavioural test * Fix max run check * Simplify outputs distribution * Simplify subgraph run check * Remove unused _init_run_queue function * Remove commented code * Add some missing type hints * Simplify cycles breaking * Fix _distribute_output test * Fix _find_components_that_will_receive_no_input test * Fix validation test * Fix tracer losing Component inputs * Fix some linting issues * Remove ignore pylint rule * Rename method that break cycles and make it raise * Add docstring to _run_subgraph * Update Pipeline.run() docstring * Update comment to clarify cycles execution * Remove SelfLoop sample Component * Add behavioural test for unsupported cycles * Rename behavioural test to be more specific * Add new behavioural test * Add release notes * Remove commented out code and random pass * Use more efficient function to find cycles * Simplify _break_supported_cycles_in_graph by using defaultdict * Stop breaking edges as soon as we make the graph acyclic * Fix docstring and add some more comments * Fix _distribute_output docstring * Fix _find_receivers_from docstring * More detailed release notes * Minimize calls to networkx.is_directed_acyclic_graph * Add some more info on edges keys * Adjust components_in_cycles comment * Add new Pipeline behavioural test * Enhance _find_components_that_will_receive_no_input to cover more cases * Explain why run_queue is reset after running a subgraph cycle * Rename _init_inputs_state to _normalize_input_data * Better explain the subgraph output distribution * Remove for else * Fix some comments and docstrings * Fix linting * Add missing return type * Fix typo * Rename _normalize_input_data to _normalize_varidiac_input_data and add more documentation * Remove unused import --------- Co-authored-by: Sebastian Husch Lee --- haystack/core/pipeline/base.py | 241 +++++-- haystack/core/pipeline/pipeline.py | 348 ++++++++-- .../testing/sample_components/__init__.py | 2 - .../testing/sample_components/self_loop.py | 27 - .../pipeline-run-rework-23a972d83b792db2.yaml | 12 + .../pipeline/features/pipeline_run.feature | 7 +- test/core/pipeline/features/test_run.py | 604 +++++++++++++++++- test/core/pipeline/test_pipeline.py | 147 +++-- .../pipeline/test_validation_pipeline_io.py | 16 +- 9 files changed, 1193 insertions(+), 211 deletions(-) delete mode 100644 haystack/testing/sample_components/self_loop.py create mode 100644 releasenotes/notes/pipeline-run-rework-23a972d83b792db2.yaml diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 3d1b1c0bba..31ad2ad93c 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -5,7 +5,7 @@ import importlib import itertools from collections import defaultdict -from copy import copy, deepcopy +from copy import deepcopy from datetime import datetime from 
pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Set, TextIO, Tuple, Type, TypeVar, Union @@ -19,6 +19,7 @@ PipelineConnectError, PipelineDrawingError, PipelineError, + PipelineRuntimeError, PipelineUnmarshalError, PipelineValidationError, ) @@ -765,7 +766,10 @@ def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[ return data - def _init_inputs_state(self, data: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + def _normalize_varidiac_input_data(self, data: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + """ + Variadic inputs expect their value to be a list, this utility method creates that list from the user's input. + """ for component_name, component_inputs in data.items(): if component_name not in self.graph.nodes: # This is not a component name, it must be the name of one or more input sockets. @@ -773,8 +777,6 @@ def _init_inputs_state(self, data: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[ continue instance = self.graph.nodes[component_name]["instance"] for component_input, input_value in component_inputs.items(): - # Handle mutable input data - data[component_name][component_input] = copy(input_value) if instance.__haystack_input__._sockets_dict[component_input].is_variadic: # Components that have variadic inputs need to receive lists as input. # We don't want to force the user to always pass lists, so we convert single values to lists here. @@ -784,41 +786,6 @@ def _init_inputs_state(self, data: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[ return {**data} - def _init_run_queue(self, pipeline_inputs: Dict[str, Any]) -> List[Tuple[str, Component]]: - run_queue: List[Tuple[str, Component]] = [] - - # HACK: Quick workaround for the issue of execution order not being - # well-defined (NB - https://github.com/deepset-ai/haystack/issues/7985). - # We should fix the original execution logic instead. - if networkx.is_directed_acyclic_graph(self.graph): - # If the Pipeline is linear we can easily determine the order of execution with - # a topological sort. - # So use that to get the run order. - for node in networkx.topological_sort(self.graph): - run_queue.append((node, self.graph.nodes[node]["instance"])) - return run_queue - - for node_name in self.graph.nodes: - component = self.graph.nodes[node_name]["instance"] - - if len(component.__haystack_input__._sockets_dict) == 0: - # Component has no input, can run right away - run_queue.append((node_name, component)) - continue - - if node_name in pipeline_inputs: - # This component is in the input data, if it has enough inputs it can run right away - run_queue.append((node_name, component)) - continue - - for socket in component.__haystack_input__._sockets_dict.values(): - if not socket.senders or socket.is_variadic: - # Component has at least one input not connected or is variadic, can run right away. - run_queue.append((node_name, component)) - break - - return run_queue - @classmethod def from_template( cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None @@ -851,9 +818,27 @@ def _init_graph(self): for node in self.graph.nodes: self.graph.nodes[node]["visits"] = 0 - def _distribute_output( + def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]: + """ + Utility function to find all Components that receive input form `component_name`. 
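A minimal sketch of the idea behind the `_normalize_varidiac_input_data` helper introduced above: values aimed at variadic sockets are wrapped in a list so callers can pass scalars. The `variadic_sockets` lookup table stands in for Haystack's `__haystack_input__` socket metadata and is an assumption of this sketch.

```python
from typing import Any, Dict, Set


def normalize_variadic_inputs(
    data: Dict[str, Dict[str, Any]], variadic_sockets: Dict[str, Set[str]]
) -> Dict[str, Dict[str, Any]]:
    """Wrap scalar values in a list for sockets that accept input from multiple senders."""
    for component_name, component_inputs in data.items():
        for socket_name, value in component_inputs.items():
            if socket_name in variadic_sockets.get(component_name, set()) and not isinstance(value, list):
                component_inputs[socket_name] = [value]
    return data


# A BranchJoiner-like component exposes a variadic "value" socket; scalars become one-element lists.
print(normalize_variadic_inputs({"branch_joiner": {"value": 1}}, {"branch_joiner": {"value"}}))
# {'branch_joiner': {'value': [1]}}
```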
+ + :param component_name: + Name of the sender Component + + :returns: + List of tuples containing name of the receiver Component and sender OutputSocket + and receiver InputSocket instances + """ + res = [] + for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True): + sender_socket: OutputSocket = connection["from_socket"] + receiver_socket: InputSocket = connection["to_socket"] + res.append((receiver_name, sender_socket, receiver_socket)) + return res + + def _distribute_output( # pylint: disable=too-many-positional-arguments self, - component_name: str, + receiver_components: List[Tuple[str, OutputSocket, InputSocket]], component_result: Dict[str, Any], components_inputs: Dict[str, Dict[str, Any]], run_queue: List[Tuple[str, Component]], @@ -865,23 +850,27 @@ def _distribute_output( This also updates the queues that keep track of which Components are ready to run and which are waiting for input. - :param component_name: Name of the Component that created the output - :param component_result: The output of the Component - :paramt components_inputs: The current state of the inputs divided by Component name - :param run_queue: Queue of Components to run - :param waiting_queue: Queue of Components waiting for input + :param receiver_components: + List of tuples containing name of receiver Components and relative sender OutputSocket + and receiver InputSocket instances + :param component_result: + The output of the Component + :param components_inputs: + The current state of the inputs divided by Component name + :param run_queue: + Queue of Components to run + :param waiting_queue: + Queue of Components waiting for input - :returns: The updated output of the Component without the keys that were distributed to other Components + :returns: + The updated output of the Component without the keys that were distributed to other Components """ # We keep track of which keys to remove from component_result at the end of the loop. # This is done after the output has been distributed to the next components, so that # we're sure all components that need this output have received it. to_remove_from_component_result = set() - for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True): - sender_socket: OutputSocket = connection["from_socket"] - receiver_socket: InputSocket = connection["to_socket"] - + for receiver_name, sender_socket, receiver_socket in receiver_components: if sender_socket.name not in component_result: # This output wasn't created by the sender, nothing we can do. # @@ -929,7 +918,7 @@ def _distribute_output( run_queue.remove(pair) if pair in waiting_queue: waiting_queue.remove(pair) - run_queue.append(pair) + run_queue.insert(0, pair) else: # If the receiver Component has a variadic input that is not greedy # we put it in the waiting queue. @@ -1048,16 +1037,33 @@ def _find_components_that_will_receive_no_input( """ # Simplifies the check if a Component is Variadic and received some input from other Components. 
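The receiver lookup described above amounts to reading edge data off the pipeline's `networkx.MultiDiGraph`. A reduced illustration, with plain strings standing in for the `OutputSocket`/`InputSocket` objects stored on each edge:

```python
import networkx as nx

graph = nx.MultiDiGraph()
graph.add_edge("retriever", "prompt_builder", from_socket="documents", to_socket="documents")
graph.add_edge("prompt_builder", "llm", from_socket="prompt", to_socket="prompt")


def find_receivers_from(component_name: str):
    """Return (receiver name, sender socket, receiver socket) for every outgoing edge."""
    return [
        (receiver, data["from_socket"], data["to_socket"])
        for _, receiver, data in graph.edges(nbunch=component_name, data=True)
    ]


print(find_receivers_from("retriever"))
# [('prompt_builder', 'documents', 'documents')]
```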
- def is_variadic_with_existing_inputs(comp: Component) -> bool: - for receiver_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore - if component_name not in receiver_socket.senders: + def has_variadic_socket_with_existing_inputs( + component: Component, component_name: str, sender_name: str, components_inputs: Dict[str, Dict[str, Any]] + ) -> bool: + for socket in component.__haystack_input__._sockets_dict.values(): # type: ignore + if sender_name not in socket.senders: continue - if ( - receiver_socket.is_variadic - and len(components_inputs.get(receiver, {}).get(receiver_socket.name, [])) > 0 - ): - # This Component already received some input to its Variadic socket from other Components. - # It should be able to run even if it doesn't receive any input from component_name. + if socket.is_variadic and len(components_inputs.get(component_name, {}).get(socket.name, [])) > 0: + return True + return False + + # Makes it easier to verify if all connections between two Components are optional + def all_connections_are_optional(sender_name: str, receiver: Component) -> bool: + for socket in receiver.__haystack_input__._sockets_dict.values(): # type: ignore + if sender_name not in socket.senders: + continue + if socket.is_mandatory: + return False + return True + + # Eases checking if other connections that are not between sender_name and receiver_name + # already received inputs + def other_connections_received_input(sender_name: str, receiver_name: str) -> bool: + receiver: Component = self.graph.nodes[receiver_name]["instance"] + for receiver_socket in receiver.__haystack_input__._sockets_dict.values(): # type: ignore + if sender_name in receiver_socket.senders: + continue + if components_inputs.get(receiver_name, {}).get(receiver_socket.name) is not None: return True return False @@ -1069,7 +1075,21 @@ def is_variadic_with_existing_inputs(comp: Component) -> bool: for receiver in socket.receivers: receiver_instance: Component = self.graph.nodes[receiver]["instance"] - if is_variadic_with_existing_inputs(receiver_instance): + if has_variadic_socket_with_existing_inputs( + receiver_instance, receiver, component_name, components_inputs + ): + # Components with Variadic input that already received some input + # can still run, even if branch is skipped. + # If we remove them they won't run. + continue + + if all_connections_are_optional(component_name, receiver_instance) and other_connections_received_input( + component_name, receiver + ): + # If all the connections between component_name and receiver are optional + # and receiver received other inputs already it still has enough inputs to run. + # Even if it didn't receive input from component_name, so we can't remove it or its + # descendants. continue components.add((receiver, receiver_instance)) @@ -1078,7 +1098,18 @@ def is_variadic_with_existing_inputs(comp: Component) -> bool: # This is fine even if the Pipeline will merge back into a single Component # at a certain point. The merging Component will be put back into the run # queue at a later stage. - components |= {(d, self.graph.nodes[d]["instance"]) for d in networkx.descendants(self.graph, receiver)} + for descendant_name in networkx.descendants(self.graph, receiver): + descendant = self.graph.nodes[descendant_name]["instance"] + + # Components with Variadic input that already received some input + # can still run, even if branch is skipped. + # If we remove them they won't run. 
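The branch pruning above relies on `networkx.descendants`; a toy graph shows which components would be flagged as receiving no input when a branch is skipped, before the variadic-socket exception is applied:

```python
import networkx as nx

g = nx.MultiDiGraph()
g.add_edges_from(
    [("router", "web_search"), ("router", "direct_answer"), ("web_search", "prompt_builder"), ("prompt_builder", "llm")]
)

# If "router" only produced output for the "direct_answer" branch, the skipped receiver
# and everything downstream of it are candidates for removal from the run queue.
skipped_receiver = "web_search"
print({skipped_receiver} | nx.descendants(g, skipped_receiver))
# {'web_search', 'prompt_builder', 'llm'}
```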
+ if has_variadic_socket_with_existing_inputs( + descendant, descendant_name, receiver, components_inputs + ): + continue + + components.add((descendant_name, descendant)) return components @@ -1127,6 +1158,90 @@ def _component_has_enough_inputs_to_run(self, name: str, inputs: Dict[str, Dict[ current_inputs = inputs[name].keys() return expected_inputs == current_inputs + def _break_supported_cycles_in_graph(self) -> Tuple[networkx.MultiDiGraph, Dict[str, List[List[str]]]]: + """ + Utility function to remove supported cycles in the Pipeline's graph. + + Given that the Pipeline execution would wait to run a Component until it has received + all its mandatory inputs, it doesn't make sense for us to try and break cycles by + removing a connection to a mandatory input. The Pipeline would just get stuck at a later time. + + So we can only break connections in cycles that have a Variadic or GreedyVariadic type or a default value. + + This will raise a PipelineRuntimeError if we there are cycles that can't be broken. + That is bound to happen when at least one of the inputs in a cycle is mandatory. + + If the Pipeline's graph doesn't have any cycle it will just return that graph and an empty dictionary. + + :returns: + A tuple containing: + * A copy of the Pipeline's graph without cycles + * A dictionary of Component's names and a list of all the cycles they were part of. + The cycles are a list of Component's names that create that cycle. + """ + if networkx.is_directed_acyclic_graph(self.graph): + return self.graph, {} + + temp_graph: networkx.MultiDiGraph = self.graph.copy() + # A list of all the cycles that are found in the graph, each inner list contains + # the Component names that create that cycle. + cycles: List[List[str]] = list(networkx.simple_cycles(self.graph)) + # Maps a Component name to a list of its output socket names that have been broken + edges_removed: Dict[str, List[str]] = defaultdict(list) + # This keeps track of all the cycles that a component is part of. + # Maps a Component name to a list of cycles, each inner list contains + # the Component names that create that cycle (the key will also be + # an element in each list). The last Component in each list is implicitly + # connected to the first. + components_in_cycles: Dict[str, List[List[str]]] = defaultdict(list) + + # Used to minimize the number of time we check whether the graph has any more + # cycles left to break or not. + graph_has_cycles = True + + # Iterate all the cycles to find the least amount of connections that we can remove + # to make the Pipeline graph acyclic. + # As soon as the graph is acyclic we stop breaking connections and return. + for cycle in cycles: + for comp in cycle: + components_in_cycles[comp].append(cycle) + + # Iterate this cycle, we zip the cycle with itself so that at the last iteration + # sender_comp will be the last element of cycle and receiver_comp will be the first. + # So if cycle is [1, 2, 3, 4] we would call zip([1, 2, 3, 4], [2, 3, 4, 1]). + for sender_comp, receiver_comp in zip(cycle, cycle[1:] + cycle[:1]): + # We get the key and iterate those as we want to edit the graph data while + # iterating the edges and that would raise. + # Even though the connection key set in Pipeline.connect() uses only the + # sockets name we don't have clashes since it's only used to differentiate + # multiple edges between two nodes. 
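A condensed sketch of the cycle-breaking strategy implemented here: enumerate the simple cycles, drop edges that land on a breakable input, and stop as soon as the working copy of the graph is acyclic. The `breakable` flag on each edge replaces the real variadic/default socket checks and is an assumption of this sketch.

```python
import networkx as nx

g = nx.MultiDiGraph()
g.add_edge("prompt_builder", "llm", breakable=False)               # llm.prompt is mandatory
g.add_edge("llm", "output_validator", breakable=False)             # validator.replies is mandatory
g.add_edge("output_validator", "prompt_builder", breakable=True)   # feedback lands on a defaulted input

acyclic = g.copy()
for cycle in list(nx.simple_cycles(g)):
    # Pair each node with its successor, wrapping around so the last node points at the first.
    for sender, receiver in zip(cycle, cycle[1:] + cycle[:1]):
        for key, data in list((acyclic.get_edge_data(sender, receiver) or {}).items()):
            if data["breakable"]:
                acyclic.remove_edge(sender, receiver, key)
    if nx.is_directed_acyclic_graph(acyclic):
        break

print(list(nx.topological_sort(acyclic)))
# ['prompt_builder', 'llm', 'output_validator']
```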
+ edge_keys = list(temp_graph.get_edge_data(sender_comp, receiver_comp).keys()) + for edge_key in edge_keys: + edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)[edge_key] + receiver_socket = edge_data["to_socket"] + if not receiver_socket.is_variadic and receiver_socket.is_mandatory: + continue + + # We found a breakable edge + sender_socket = edge_data["from_socket"] + edges_removed[sender_comp].append(sender_socket.name) + temp_graph.remove_edge(sender_comp, receiver_comp, edge_key) + + graph_has_cycles = not networkx.is_directed_acyclic_graph(temp_graph) + if not graph_has_cycles: + # We removed all the cycles, we can stop + break + + if not graph_has_cycles: + # We removed all the cycles, nice + break + + if graph_has_cycles: + msg = "Pipeline contains a cycle that we can't execute" + raise PipelineRuntimeError(msg) + + return temp_graph, components_in_cycles + def _connections_status( sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket] diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index ce3b1c1f63..7f8c2cfbed 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -6,6 +6,8 @@ from typing import Any, Dict, List, Mapping, Optional, Set, Tuple from warnings import warn +import networkx as nx + from haystack import logging, tracing from haystack.core.component import Component from haystack.core.errors import PipelineMaxComponentRuns, PipelineRuntimeError @@ -62,7 +64,9 @@ def _run_component(self, name: str, inputs: Dict[str, Any]) -> Dict[str, Any]: }, }, ) as span: - span.set_content_tag("haystack.component.input", inputs) + # We deepcopy the inputs otherwise we might lose that information + # when we delete them in case they're sent to other Components + span.set_content_tag("haystack.component.input", deepcopy(inputs)) logger.info("Running component {component_name}", component_name=name) res: Dict[str, Any] = instance.run(**inputs) self.graph.nodes[name]["visits"] += 1 @@ -84,11 +88,225 @@ def _run_component(self, name: str, inputs: Dict[str, Any]) -> Dict[str, Any]: return res - def run( # noqa: PLR0915 + def _run_subgraph( # noqa: PLR0915 + self, + cycle: List[str], + component_name: str, + components_inputs: Dict[str, Dict[str, Any]], + include_outputs_from: Optional[Set[str]] = None, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Runs a `cycle` in the Pipeline starting from `component_name`. + + This will return once there are no inputs for the Components in `cycle`. + + This is an internal method meant to be used in `Pipeline.run()` only. + + :param cycle: + List of Components that are part of the cycle being run + :param component_name: + Name of the Component that will start execution of the cycle + :param components_inputs: + Components inputs, this might include inputs for Components that are not part + of the cycle but part of the wider Pipeline's graph + :param include_outputs_from: + Set of component names whose individual outputs are to be + included in the cycle's output. In case a Component is executed multiple times + only the last-produced output is included. + :returns: + Outputs of all the Components that are not connected to other Components in `cycle`. + If `include_outputs_from` is set those Components' outputs will be included. 
+ :raises PipelineMaxComponentRuns: + If a Component reaches the maximum number of times it can be run in this Pipeline + """ + waiting_queue: List[Tuple[str, Component]] = [] + run_queue: List[Tuple[str, Component]] = [] + + # Create the run queue starting with the component that needs to run first + start_index = cycle.index(component_name) + for node in cycle[start_index:]: + run_queue.append((node, self.graph.nodes[node]["instance"])) + + include_outputs_from = set() if include_outputs_from is None else include_outputs_from + + before_last_waiting_queue: Optional[Set[str]] = None + last_waiting_queue: Optional[Set[str]] = None + + subgraph_outputs = {} + # These are outputs that are sent to other Components but the user explicitly + # asked to include them in the final output. + extra_outputs = {} + + # This variable is used to keep track if we still need to run the cycle or not. + # When a Component doesn't send outputs to another Component + # that's inside the subgraph, we stop running this subgraph. + cycle_received_inputs = False + + while not cycle_received_inputs: + # Here we run the Components + name, comp = run_queue.pop(0) + if _is_lazy_variadic(comp) and not all(_is_lazy_variadic(comp) for _, comp in run_queue): + # We run Components with lazy variadic inputs only if there only Components with + # lazy variadic inputs left to run + _enqueue_waiting_component((name, comp), waiting_queue) + continue + + # As soon as a Component returns only output that is not part of the cycle, we can stop + if self._component_has_enough_inputs_to_run(name, components_inputs): + if self.graph.nodes[name]["visits"] > self._max_runs_per_component: + msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'" + raise PipelineMaxComponentRuns(msg) + + res: Dict[str, Any] = self._run_component(name, components_inputs[name]) + + # Delete the inputs that were consumed by the Component and are not received from + # the user or from Components that are part of this cycle + sockets = list(components_inputs[name].keys()) + for socket_name in sockets: + senders = comp.__haystack_input__._sockets_dict[socket_name].senders # type: ignore + if not senders: + # We keep inputs that came from the user + continue + all_senders_in_cycle = all(sender in cycle for sender in senders) + if all_senders_in_cycle: + # All senders are in the cycle, we can remove the input. + # We'll receive it later at a certain point. + del components_inputs[name][socket_name] + + if name in include_outputs_from: + # Deepcopy the outputs to prevent downstream nodes from modifying them + # We don't care about loops - Always store the last output. + extra_outputs[name] = deepcopy(res) + + # Reset the waiting for input previous states, we managed to run a component + before_last_waiting_queue = None + last_waiting_queue = None + + # Check if a component doesn't send any output to components that are part of the cycle + final_output_reached = False + for output_socket in res.keys(): + for receiver in comp.__haystack_output__._sockets_dict[output_socket].receivers: # type: ignore + if receiver in cycle: + final_output_reached = True + break + if final_output_reached: + break + + if not final_output_reached: + # We stop only if the Component we just ran doesn't send any output to sockets that + # are part of the cycle + cycle_received_inputs = True + + # We manage to run this component that was in the waiting list, we can remove it. 
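Restating the input-consumption rule from the subgraph loop in isolation: once a component inside the cycle has run, inputs whose senders all live in the cycle are dropped (they will be produced again on the next pass), while user-supplied inputs survive. The socket-to-senders mapping below is illustrative rather than Haystack's real socket data structure.

```python
cycle = {"prompt_builder", "llm", "output_validator"}

# socket name -> names of the components feeding it (an empty list means the user supplied it)
socket_senders = {"comment": [], "invalid_replies": ["output_validator"], "error_message": ["output_validator"]}
inputs = {"comment": "great meal, rude courier", "invalid_replies": ["bad JSON"], "error_message": "not a dict"}

for socket_name, senders in socket_senders.items():
    if senders and all(sender in cycle for sender in senders):
        # Consumed by this run; the sender inside the cycle will produce it again if needed.
        del inputs[socket_name]

print(inputs)
# {'comment': 'great meal, rude courier'}
```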
+ # This happens when a component was put in the waiting list but we reached it from another edge. + _dequeue_waiting_component((name, comp), waiting_queue) + for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs): + _dequeue_component(pair, run_queue, waiting_queue) + + receivers = [item for item in self._find_receivers_from(name) if item[0] in cycle] + + res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue) + + # We treat a cycle as a completely independent graph, so we keep track of output + # that is not sent inside the cycle. + # This output is going to get distributed to the wider graph after we finish running + # a cycle. + # All values that are left at this point go outside the cycle. + if len(res) > 0: + subgraph_outputs[name] = res + else: + # This component doesn't have enough inputs so we can't run it yet + _enqueue_waiting_component((name, comp), waiting_queue) + + if len(run_queue) == 0 and len(waiting_queue) > 0: + # Check if we're stuck in a loop. + # It's important to check whether previous waitings are None as it could be that no + # Component has actually been run yet. + if ( + before_last_waiting_queue is not None + and last_waiting_queue is not None + and before_last_waiting_queue == last_waiting_queue + ): + if self._is_stuck_in_a_loop(waiting_queue): + # We're stuck! We can't make any progress. + msg = ( + "Pipeline is stuck running in a loop. Partial outputs will be returned. " + "Check the Pipeline graph for possible issues." + ) + warn(RuntimeWarning(msg)) + break + + (name, comp) = self._find_next_runnable_lazy_variadic_or_default_component(waiting_queue) + _add_missing_input_defaults(name, comp, components_inputs) + _enqueue_component((name, comp), run_queue, waiting_queue) + continue + + before_last_waiting_queue = last_waiting_queue.copy() if last_waiting_queue is not None else None + last_waiting_queue = {item[0] for item in waiting_queue} + + (name, comp) = self._find_next_runnable_component(components_inputs, waiting_queue) + _add_missing_input_defaults(name, comp, components_inputs) + _enqueue_component((name, comp), run_queue, waiting_queue) + + return subgraph_outputs, extra_outputs + + def run( # noqa: PLR0915, PLR0912 self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None ) -> Dict[str, Any]: """ - Runs the pipeline with given input data. + Runs the Pipeline with given input data. + + Usage: + ```python + from haystack import Pipeline, Document + from haystack.utils import Secret + from haystack.document_stores.in_memory import InMemoryDocumentStore + from haystack.components.retrievers.in_memory import InMemoryBM25Retriever + from haystack.components.generators import OpenAIGenerator + from haystack.components.builders.answer_builder import AnswerBuilder + from haystack.components.builders.prompt_builder import PromptBuilder + + # Write documents to InMemoryDocumentStore + document_store = InMemoryDocumentStore() + document_store.write_documents([ + Document(content="My name is Jean and I live in Paris."), + Document(content="My name is Mark and I live in Berlin."), + Document(content="My name is Giorgio and I live in Rome.") + ]) + + prompt_template = \"\"\" + Given these documents, answer the question. 
+ Documents: + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + Question: {{question}} + Answer: + \"\"\" + + retriever = InMemoryBM25Retriever(document_store=document_store) + prompt_builder = PromptBuilder(template=prompt_template) + llm = OpenAIGenerator(api_key=Secret.from_token(api_key)) + + rag_pipeline = Pipeline() + rag_pipeline.add_component("retriever", retriever) + rag_pipeline.add_component("prompt_builder", prompt_builder) + rag_pipeline.add_component("llm", llm) + rag_pipeline.connect("retriever", "prompt_builder.documents") + rag_pipeline.connect("prompt_builder", "llm") + + # Ask a question + question = "Who lives in Paris?" + results = rag_pipeline.run( + { + "retriever": {"query": question}, + "prompt_builder": {"question": question}, + } + ) + + print(results["llm"]["replies"]) + # Jean lives in Paris + ``` :param data: A dictionary of inputs for the pipeline's components. Each key is a component name @@ -104,7 +322,6 @@ def run( # noqa: PLR0915 "input1": 1, "input2": 2, } ``` - :param include_outputs_from: Set of component names whose individual outputs are to be included in the pipeline's output. For components that are @@ -117,41 +334,11 @@ def run( # noqa: PLR0915 without outgoing connections. :raises PipelineRuntimeError: - If a component fails or returns unexpected output. - - Example a - Using named components: - Consider a 'Hello' component that takes a 'word' input and outputs a greeting. - - ```python - @component - class Hello: - @component.output_types(output=str) - def run(self, word: str): - return {"output": f"Hello, {word}!"} - ``` - - Create a pipeline with two 'Hello' components connected together: - - ```python - pipeline = Pipeline() - pipeline.add_component("hello", Hello()) - pipeline.add_component("hello2", Hello()) - pipeline.connect("hello.output", "hello2.word") - result = pipeline.run(data={"hello": {"word": "world"}}) - ``` - - This runs the pipeline with the specified input for 'hello', yielding - {'hello2': {'output': 'Hello, Hello, world!!'}}. - - Example b - Using flat inputs: - You can also pass inputs directly without specifying component names: - - ```python - result = pipeline.run(data={"word": "world"}) - ``` - - The pipeline resolves inputs to the correct components, returning - {'hello2': {'output': 'Hello, Hello, world!!'}}. + If the Pipeline contains cycles with unsupported connections that would cause + it to get stuck and fail running. + Or if a Component fails or returns output in an unsupported type. + :raises PipelineMaxComponentRuns: + If a Component reaches the maximum number of times it can be run in this Pipeline. """ pipeline_running(self) @@ -168,15 +355,8 @@ def run(self, word: str): # Raise if input is malformed in some way self._validate_input(data) - # Initialize the inputs state - components_inputs: Dict[str, Dict[str, Any]] = self._init_inputs_state(data) - - # Take all components that: - # - have no inputs - # - receive input from the user - # - have at least one input not connected - # - have at least one input that is variadic - run_queue: List[Tuple[str, Component]] = self._init_run_queue(data) + # Normalize the input data + components_inputs: Dict[str, Dict[str, Any]] = self._normalize_varidiac_input_data(data) # These variables are used to detect when we're stuck in a loop. 
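Since the docstring example covers the linear RAG case, here is a compact loop pipeline in the spirit of the behavioural tests added later in this patch. The generator is a fake stand-in for an LLM, and the feedback edge lands on an optional `PromptBuilder` variable so the cycle can be broken; the printed result is the expected shape, mirroring the loop test below.

```python
from typing import List

from haystack import Pipeline, component
from haystack.components.builders import PromptBuilder
from haystack.components.routers import ConditionalRouter


@component
class FakeGenerator:
    @component.output_types(replies=List[str])
    def run(self, prompt: str):
        # First call produces an invalid reply, second call a valid one.
        first = getattr(self, "first_run", True)
        self.first_run = False
        return {"replies": ["No answer"] if first else ["42"]}


routes = [
    {
        "condition": "{{ 'No answer' in replies }}",
        "output": "{{ replies }}",
        "output_name": "invalid_replies",
        "output_type": List[str],
    },
    {
        "condition": "{{ 'No answer' not in replies }}",
        "output": "{{ replies }}",
        "output_name": "valid_replies",
        "output_type": List[str],
    },
]

pipe = Pipeline(max_runs_per_component=5)
pipe.add_component("prompt_builder", PromptBuilder(template="{{ question }} {{ invalid_replies }}"))
pipe.add_component("llm", FakeGenerator())
pipe.add_component("validator", ConditionalRouter(routes=routes))

pipe.connect("prompt_builder.prompt", "llm.prompt")
pipe.connect("llm.replies", "validator.replies")
# Feedback edge: lands on PromptBuilder's optional "invalid_replies" variable, so the cycle is breakable.
pipe.connect("validator.invalid_replies", "prompt_builder.invalid_replies")

print(pipe.run({"prompt_builder": {"question": "What is the answer?"}}))
# e.g. {'validator': {'valid_replies': ['42']}}
```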
# Stuck loops can happen when one or more components are waiting for input but @@ -199,6 +379,31 @@ def run(self, word: str): # This is what we'll return at the end final_outputs: Dict[Any, Any] = {} + # Break cycles in case there are, this is a noop if no cycle is found. + # This will raise if a cycle can't be broken. + graph_without_cycles, components_in_cycles = self._break_supported_cycles_in_graph() + + run_queue: List[Tuple[str, Component]] = [] + for node in nx.topological_sort(graph_without_cycles): + run_queue.append((node, self.graph.nodes[node]["instance"])) + + # Set defaults inputs for those sockets that don't receive input neither from the user + # nor from other Components. + # If they have no default nothing is done. + # This is important to ensure correct order execution, otherwise some variadic + # Components that receive input from the user might be run before than they should. + for name, comp in self.graph.nodes(data="instance"): + if name not in components_inputs: + components_inputs[name] = {} + for socket_name, socket in comp.__haystack_input__._sockets_dict.items(): + if socket_name in components_inputs[name]: + continue + if not socket.senders: + value = socket.default_value + if socket.is_variadic: + value = [value] + components_inputs[name][socket_name] = value + with tracing.tracer.trace( "haystack.pipeline.run", tags={ @@ -219,14 +424,56 @@ def run(self, word: str): # lazy variadic inputs left to run _enqueue_waiting_component((name, comp), waiting_queue) continue + if self._component_has_enough_inputs_to_run(name, components_inputs) and components_in_cycles.get( + name, [] + ): + cycles = components_in_cycles.get(name, []) + + # This component is part of one or more cycles, let's get the first one and run it. + # We can reliably pick any of the cycles if there are multiple ones, the way cycles + # are run doesn't make a different whether we pick the first or any of the others a + # Component is part of. + subgraph_output, subgraph_extra_output = self._run_subgraph( + cycles[0], name, components_inputs, include_outputs_from + ) + + # After a cycle is run the previous run_queue can't be correct anymore cause it's + # not modified when running the subgraph. + # So we reset it given the output returned by the subgraph. 
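The defaults pass above can be pictured with a plain dictionary in place of the socket metadata; sockets with no senders and no user input fall back to their default, wrapped in a list when the socket is variadic. This is an illustrative sketch, not the real socket objects.

```python
sockets = {
    "llm": {"generation_kwargs": {"senders": [], "default": None, "variadic": False}},
    "joiner": {"value": {"senders": [], "default": None, "variadic": True}},
}
components_inputs = {"llm": {}, "joiner": {}}

for name, socket_dict in sockets.items():
    for socket_name, socket in socket_dict.items():
        if socket_name in components_inputs[name] or socket["senders"]:
            # Already filled by the user or fed by another component: leave it alone.
            continue
        value = socket["default"]
        components_inputs[name][socket_name] = [value] if socket["variadic"] else value

print(components_inputs)
# {'llm': {'generation_kwargs': None}, 'joiner': {'value': [None]}}
```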
+ run_queue = [] + + # Reset the waiting for input previous states, we managed to run at least one component + before_last_waiting_queue = None + last_waiting_queue = None + + # Merge the extra outputs + extra_outputs.update(subgraph_extra_output) - if self._component_has_enough_inputs_to_run(name, components_inputs): + for component_name, component_output in subgraph_output.items(): + receivers = self._find_receivers_from(component_name) + component_output = self._distribute_output( + receivers, component_output, components_inputs, run_queue, waiting_queue + ) + + if len(component_output) > 0: + final_outputs[component_name] = component_output + + elif self._component_has_enough_inputs_to_run(name, components_inputs): if self.graph.nodes[name]["visits"] > self._max_runs_per_component: msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'" raise PipelineMaxComponentRuns(msg) res: Dict[str, Any] = self._run_component(name, components_inputs[name]) + # Delete the inputs that were consumed by the Component and are not received from the user + sockets = list(components_inputs[name].keys()) + for socket_name in sockets: + senders = comp.__haystack_input__._sockets_dict[socket_name].senders + if senders: + # Delete all inputs that are received from other Components + del components_inputs[name][socket_name] + # We keep inputs that came from the user + if name in include_outputs_from: # Deepcopy the outputs to prevent downstream nodes from modifying them # We don't care about loops - Always store the last output. @@ -242,7 +489,8 @@ def run(self, word: str): for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs): _dequeue_component(pair, run_queue, waiting_queue) - res = self._distribute_output(name, res, components_inputs, run_queue, waiting_queue) + receivers = self._find_receivers_from(name) + res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue) if len(res) > 0: final_outputs[name] = res diff --git a/haystack/testing/sample_components/__init__.py b/haystack/testing/sample_components/__init__.py index 1a42faf7a3..011ca2ddea 100644 --- a/haystack/testing/sample_components/__init__.py +++ b/haystack/testing/sample_components/__init__.py @@ -13,7 +13,6 @@ from haystack.testing.sample_components.parity import Parity from haystack.testing.sample_components.remainder import Remainder from haystack.testing.sample_components.repeat import Repeat -from haystack.testing.sample_components.self_loop import SelfLoop from haystack.testing.sample_components.subtract import Subtract from haystack.testing.sample_components.sum import Sum from haystack.testing.sample_components.text_splitter import TextSplitter @@ -35,6 +34,5 @@ "Hello", "TextSplitter", "StringListJoiner", - "SelfLoop", "FString", ] diff --git a/haystack/testing/sample_components/self_loop.py b/haystack/testing/sample_components/self_loop.py deleted file mode 100644 index b29962ea89..0000000000 --- a/haystack/testing/sample_components/self_loop.py +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -from haystack.core.component import component -from haystack.core.component.types import Variadic - - -@component -class SelfLoop: - """ - Decreases the initial value in steps of 1 until the target value is reached. 
- - For no good reason it uses a self-loop to do so :) - """ - - def __init__(self, target: int = 0): - self.target = target - - @component.output_types(current_value=int, final_result=int) - def run(self, values: Variadic[int]): - """Decreases the input value in steps of 1 until the target value is reached.""" - value = values[0] # type: ignore - value -= 1 - if value == self.target: - return {"final_result": value} - return {"current_value": value} diff --git a/releasenotes/notes/pipeline-run-rework-23a972d83b792db2.yaml b/releasenotes/notes/pipeline-run-rework-23a972d83b792db2.yaml new file mode 100644 index 0000000000..2a1fc7d6c7 --- /dev/null +++ b/releasenotes/notes/pipeline-run-rework-23a972d83b792db2.yaml @@ -0,0 +1,12 @@ +--- +highlights: > + `Pipeline.run()` internal logic has been heavily reworked to be more robust and reliable + than before. + This new implementation makes it easier to run `Pipeline`s that have cycles in their graph. + It also fixes some corner cases in `Pipeline`s that don't have any cycle. +features: + - | + Fundamentally rework the internal logic of `Pipeline.run()`. + The rework makes it more reliable and covers more use cases. + We fixed some issues that made `Pipeline`s with cycles unpredictable + and with unclear Components execution order. diff --git a/test/core/pipeline/features/pipeline_run.feature b/test/core/pipeline/features/pipeline_run.feature index bdb6d8f30d..e05f16b570 100644 --- a/test/core/pipeline/features/pipeline_run.feature +++ b/test/core/pipeline/features/pipeline_run.feature @@ -26,7 +26,7 @@ Feature: Pipeline running | that has a greedy and variadic component after a component with default input | | that has components added in a different order from the order of execution | | that has a component with only default inputs | - | that has a component with only default inputs as first to run | + | that has a component with only default inputs as first to run and receives inputs from a loop | | that has multiple branches that merge into a component with a single variadic input | | that has multiple branches of different lengths that merge into a component with a single variadic input | | that is linear and returns intermediate outputs | @@ -37,8 +37,12 @@ Feature: Pipeline running | that has a loop and a component with default inputs that doesn't receive anything from its sender but receives input from user | | that has multiple components with only default inputs and are added in a different order from the order of execution | | that is linear with conditional branching and multiple joins | + | that is a simple agent | | that has a variadic component that receives partial inputs | | that has an answer joiner variadic component | + | that is linear and a component in the middle receives optional input from other components and input from the user | + | that has a loop in the middle | + | that has variadic component that receives a conditional input | Scenario Outline: Running a bad Pipeline Given a pipeline @@ -49,3 +53,4 @@ Feature: Pipeline running | kind | exception | | that has an infinite loop | PipelineMaxComponentRuns | | that has a component that doesn't return a dictionary | PipelineRuntimeError | + | that has a cycle that would get it stuck | PipelineRuntimeError | diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py index 36c9d84434..f5739aa690 100644 --- a/test/core/pipeline/features/test_run.py +++ b/test/core/pipeline/features/test_run.py @@ -1,12 +1,16 @@ +import json from typing 
import List, Optional, Dict, Any +import re from pytest_bdd import scenarios, given import pytest from haystack import Pipeline, Document, component +from haystack.document_stores.types import DuplicatePolicy from haystack.dataclasses import ChatMessage, GeneratedAnswer from haystack.components.routers import ConditionalRouter -from haystack.components.builders import PromptBuilder, AnswerBuilder +from haystack.components.builders import PromptBuilder, AnswerBuilder, ChatPromptBuilder +from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter from haystack.components.retrievers.in_memory import InMemoryBM25Retriever from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.components.joiners import BranchJoiner, DocumentJoiner, AnswerJoiner @@ -25,7 +29,6 @@ Hello, TextSplitter, StringListJoiner, - SelfLoop, ) from haystack.testing.factory import component_class @@ -67,18 +70,25 @@ def pipeline_that_is_linear(): @given("a pipeline that has an infinite loop", target_fixture="pipeline_data") def pipeline_that_has_an_infinite_loop(): - def custom_init(self): - component.set_input_type(self, "x", int) - component.set_input_type(self, "y", int, 1) - component.set_output_types(self, a=int, b=int) + routes = [ + {"condition": "{{number > 2}}", "output": "{{number}}", "output_name": "big_number", "output_type": int}, + {"condition": "{{number <= 2}}", "output": "{{number + 2}}", "output_name": "small_number", "output_type": int}, + ] + + main_input = BranchJoiner(int) + first_router = ConditionalRouter(routes=routes) + second_router = ConditionalRouter(routes=routes) - FakeComponent = component_class("FakeComponent", output={"a": 1, "b": 1}, extra_fields={"__init__": custom_init}) pipe = Pipeline(max_runs_per_component=1) - pipe.add_component("first", FakeComponent()) - pipe.add_component("second", FakeComponent()) - pipe.connect("first.a", "second.x") - pipe.connect("second.b", "first.y") - return pipe, [PipelineRunData({"first": {"x": 1}})] + pipe.add_component("main_input", main_input) + pipe.add_component("first_router", first_router) + pipe.add_component("second_router", second_router) + + pipe.connect("main_input", "first_router.number") + pipe.connect("first_router.big_number", "second_router.number") + pipe.connect("second_router.big_number", "main_input") + + return pipe, [PipelineRunData({"main_input": {"value": 3}})] @given("a pipeline that is really complex with lots of components, forks, and loops", target_fixture="pipeline_data") @@ -146,8 +156,11 @@ def pipeline_complex(): expected_outputs={"accumulate_3": {"value": -7}, "add_five": {"result": -6}}, expected_run_order=[ "greet_first", + "greet_enumerator", "accumulate_1", + "enumerate", "add_two", + "add_three", "parity", "add_one", "branch_joiner", @@ -159,9 +172,6 @@ def pipeline_complex(): "branch_joiner", "below_10", "accumulate_2", - "greet_enumerator", - "enumerate", - "add_three", "sum", "diff", "greet_one_last_time", @@ -837,8 +847,11 @@ def pipeline_that_has_a_component_with_only_default_inputs(): ) -@given("a pipeline that has a component with only default inputs as first to run", target_fixture="pipeline_data") -def pipeline_that_has_a_component_with_only_default_inputs_as_first_to_run(): +@given( + "a pipeline that has a component with only default inputs as first to run and receives inputs from a loop", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_component_with_only_default_inputs_as_first_to_run_and_receives_inputs_from_a_loop(): """ This tests verifies 
that a Pipeline doesn't get stuck running in a loop if it has all the following characterics: @@ -1529,6 +1542,217 @@ def run(self, query_embedding: List[float]): ) +@given("a pipeline that is a simple agent", target_fixture="pipeline_data") +def that_is_a_simple_agent(): + search_message_template = """ + Given these web search results: + + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + + Be as brief as possible, max one sentence. + Answer the question: {{search_query}} + """ + + react_message_template = """ + Solve a question answering task with interleaving Thought, Action, Observation steps. + + Thought reasons about the current situation + + Action can be: + google_search - Searches Google for the exact concept/entity (given in square brackets) and returns the results for you to use + finish - Returns the final answer (given in square brackets) and finishes the task + + Observation summarizes the Action outcome and helps in formulating the next + Thought in Thought, Action, Observation interleaving triplet of steps. + + After each Observation, provide the next Thought and next Action. + Don't execute multiple steps even though you know the answer. + Only generate Thought and Action, never Observation, you'll get Observation from Action. + Follow the pattern in the example below. + + Example: + ########################### + Question: Which magazine was started first Arthur’s Magazine or First for Women? + Thought: I need to search Arthur’s Magazine and First for Women, and find which was started + first. + Action: google_search[When was 'Arthur’s Magazine' started?] + Observation: Arthur’s Magazine was an American literary periodical ˘ + published in Philadelphia and founded in 1844. Edited by Timothy Shay Arthur, it featured work by + Edgar A. Poe, J.H. Ingraham, Sarah Josepha Hale, Thomas G. Spear, and others. In May 1846 + it was merged into Godey’s Lady’s Book. + Thought: Arthur’s Magazine was started in 1844. I need to search First for Women founding date next + Action: google_search[When was 'First for Women' magazine started?] + Observation: First for Women is a woman’s magazine published by Bauer Media Group in the + USA. The magazine was started in 1989. It is based in Englewood Cliffs, New Jersey. In 2011 + the circulation of the magazine was 1,310,696 copies. + Thought: First for Women was started in 1989. 1844 (Arthur’s Magazine) ¡ 1989 (First for + Women), so Arthur’s Magazine was started first. 
+ Action: finish[Arthur’s Magazine] + ############################ + + Let's start, the question is: {{query}} + + Thought: + """ + + routes = [ + { + "condition": "{{'search' in tool_id_and_param[0]}}", + "output": "{{tool_id_and_param[1]}}", + "output_name": "search", + "output_type": str, + }, + { + "condition": "{{'finish' in tool_id_and_param[0]}}", + "output": "{{tool_id_and_param[1]}}", + "output_name": "finish", + "output_type": str, + }, + ] + + @component + class FakeThoughtActionOpenAIChatGenerator: + run_counter = 0 + + @component.output_types(replies=List[ChatMessage]) + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + if self.run_counter == 0: + self.run_counter += 1 + return { + "replies": [ + ChatMessage.from_assistant( + "thinking\n Action: google_search[What is taller, Eiffel Tower or Leaning Tower of Pisa]\n" + ) + ] + } + + return {"replies": [ChatMessage.from_assistant("thinking\n Action: finish[Eiffel Tower]\n")]} + + @component + class FakeConclusionOpenAIChatGenerator: + @component.output_types(replies=List[ChatMessage]) + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + return {"replies": [ChatMessage.from_assistant("Tower of Pisa is 55 meters tall\n")]} + + @component + class FakeSerperDevWebSearch: + @component.output_types(documents=List[Document]) + def run(self, query: str): + return { + "documents": [ + Document(content="Eiffel Tower is 300 meters tall"), + Document(content="Tower of Pisa is 55 meters tall"), + ] + } + + # main part + pipeline = Pipeline() + pipeline.add_component("main_input", BranchJoiner(List[ChatMessage])) + pipeline.add_component("prompt_builder", ChatPromptBuilder(variables=["query"])) + pipeline.add_component("llm", FakeThoughtActionOpenAIChatGenerator()) + + @component + class ToolExtractor: + @component.output_types(output=List[str]) + def run(self, messages: List[ChatMessage]): + prompt: str = messages[-1].content + lines = prompt.strip().split("\n") + for line in reversed(lines): + pattern = r"Action:\s*(\w+)\[(.*?)\]" + + match = re.search(pattern, line) + if match: + action_name = match.group(1) + parameter = match.group(2) + return {"output": [action_name, parameter]} + return {"output": [None, None]} + + pipeline.add_component("tool_extractor", ToolExtractor()) + + @component + class PromptConcatenator: + def __init__(self, suffix: str = ""): + self._suffix = suffix + + @component.output_types(output=List[ChatMessage]) + def run(self, replies: List[ChatMessage], current_prompt: List[ChatMessage]): + content = current_prompt[-1].content + replies[-1].content + self._suffix + return {"output": [ChatMessage.from_user(content)]} + + @component + class SearchOutputAdapter: + @component.output_types(output=List[ChatMessage]) + def run(self, replies: List[ChatMessage]): + content = f"Observation: {replies[-1].content}\n" + return {"output": [ChatMessage.from_assistant(content)]} + + pipeline.add_component("prompt_concatenator_after_action", PromptConcatenator()) + + pipeline.add_component("router", ConditionalRouter(routes)) + pipeline.add_component("router_search", FakeSerperDevWebSearch()) + pipeline.add_component("search_prompt_builder", ChatPromptBuilder(variables=["documents", "search_query"])) + pipeline.add_component("search_llm", FakeConclusionOpenAIChatGenerator()) + + pipeline.add_component("search_output_adapter", SearchOutputAdapter()) + pipeline.add_component("prompt_concatenator_after_observation", 
PromptConcatenator(suffix="\nThought: ")) + + # main + pipeline.connect("main_input", "prompt_builder.template") + pipeline.connect("prompt_builder.prompt", "llm.messages") + pipeline.connect("llm.replies", "prompt_concatenator_after_action.replies") + + # tools + pipeline.connect("prompt_builder.prompt", "prompt_concatenator_after_action.current_prompt") + pipeline.connect("prompt_concatenator_after_action", "tool_extractor.messages") + + pipeline.connect("tool_extractor", "router") + pipeline.connect("router.search", "router_search.query") + pipeline.connect("router_search.documents", "search_prompt_builder.documents") + pipeline.connect("router.search", "search_prompt_builder.search_query") + pipeline.connect("search_prompt_builder.prompt", "search_llm.messages") + + pipeline.connect("search_llm.replies", "search_output_adapter.replies") + pipeline.connect("search_output_adapter", "prompt_concatenator_after_observation.replies") + pipeline.connect("prompt_concatenator_after_action", "prompt_concatenator_after_observation.current_prompt") + pipeline.connect("prompt_concatenator_after_observation", "main_input") + + search_message = [ChatMessage.from_user(search_message_template)] + messages = [ChatMessage.from_user(react_message_template)] + question = "which tower is taller: eiffel tower or tower of pisa?" + + return pipeline, [ + PipelineRunData( + inputs={ + "main_input": {"value": messages}, + "prompt_builder": {"query": question}, + "search_prompt_builder": {"template": search_message}, + }, + expected_outputs={"router": {"finish": "Eiffel Tower"}}, + expected_run_order=[ + "main_input", + "prompt_builder", + "llm", + "prompt_concatenator_after_action", + "tool_extractor", + "router", + "router_search", + "search_prompt_builder", + "search_llm", + "search_output_adapter", + "prompt_concatenator_after_observation", + "main_input", + "prompt_builder", + "llm", + "prompt_concatenator_after_action", + "tool_extractor", + "router", + ], + ) + ] + + @given("a pipeline that has a variadic component that receives partial inputs", target_fixture="pipeline_data") def that_has_a_variadic_component_that_receives_partial_inputs(): @component @@ -1566,7 +1790,7 @@ def run(self, create_document: bool = False): ] }, }, - expected_run_order=["first_creator", "third_creator", "second_creator", "documents_joiner"], + expected_run_order=["first_creator", "second_creator", "third_creator", "documents_joiner"], ), PipelineRunData( inputs={"first_creator": {"create_document": True}, "second_creator": {"create_document": True}}, @@ -1627,3 +1851,347 @@ def that_has_an_answer_joiner_variadic_component(): ) ], ) + + +@given( + "a pipeline that is linear and a component in the middle receives optional input from other components and input from the user", + target_fixture="pipeline_data", +) +def that_is_linear_and_a_component_in_the_middle_receives_optional_input_from_other_components_and_input_from_the_user(): + @component + class QueryMetadataExtractor: + @component.output_types(filters=Dict[str, str]) + def run(self, prompt: str): + metadata = json.loads(prompt) + filters = [] + for key, value in metadata.items(): + filters.append({"field": f"meta.{key}", "operator": "==", "value": value}) + + return {"filters": {"operator": "AND", "conditions": filters}} + + documents = [ + Document( + content="some publication about Alzheimer prevention research done over 2023 patients study", + meta={"year": 2022, "disease": "Alzheimer", "author": "Michael Butter"}, + id="doc1", + ), + Document( + content="some text 
about investigation and treatment of Alzheimer disease", + meta={"year": 2023, "disease": "Alzheimer", "author": "John Bread"}, + id="doc2", + ), + Document( + content="A study on the effectiveness of new therapies for Parkinson's disease", + meta={"year": 2022, "disease": "Parkinson", "author": "Alice Smith"}, + id="doc3", + ), + Document( + content="An overview of the latest research on the genetics of Parkinson's disease and its implications for treatment", + meta={"year": 2023, "disease": "Parkinson", "author": "David Jones"}, + id="doc4", + ), + ] + document_store = InMemoryDocumentStore(bm25_algorithm="BM25Plus") + document_store.write_documents(documents=documents, policy=DuplicatePolicy.OVERWRITE) + + pipeline = Pipeline() + pipeline.add_component(instance=PromptBuilder('{"disease": "Alzheimer", "year": 2023}'), name="builder") + pipeline.add_component(instance=QueryMetadataExtractor(), name="metadata_extractor") + pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever") + pipeline.add_component(instance=DocumentJoiner(), name="document_joiner") + + pipeline.connect("builder.prompt", "metadata_extractor.prompt") + pipeline.connect("metadata_extractor.filters", "retriever.filters") + pipeline.connect("retriever.documents", "document_joiner.documents") + + query = "publications 2023 Alzheimer's disease" + + return ( + pipeline, + [ + PipelineRunData( + inputs={"retriever": {"query": query}}, + expected_outputs={ + "document_joiner": { + "documents": [ + Document( + content="some text about investigation and treatment of Alzheimer disease", + meta={"year": 2023, "disease": "Alzheimer", "author": "John Bread"}, + id="doc2", + score=3.324112496100923, + ) + ] + } + }, + expected_run_order=["builder", "metadata_extractor", "retriever", "document_joiner"], + ) + ], + ) + + +@given("a pipeline that has a cycle that would get it stuck", target_fixture="pipeline_data") +def that_has_a_cycle_that_would_get_it_stuck(): + template = """ + You are an experienced and accurate Turkish CX speacialist that classifies customer comments into pre-defined categories below:\n + Negative experience labels: + - Late delivery + - Rotten/spoilt item + - Bad Courier behavior + + Positive experience labels: + - Good courier behavior + - Thanks & appreciation + - Love message to courier + - Fast delivery + - Quality of products + + Create a JSON object as a response. The fields are: 'positive_experience', 'negative_experience'. + Assign at least one of the pre-defined labels to the given customer comment under positive and negative experience fields. + If the comment has a positive experience, list the label under 'positive_experience' field. + If the comments has a negative_experience, list it under the 'negative_experience' field. + Here is the comment:\n{{ comment }}\n. Just return the category names in the list. If there aren't any, return an empty list. + + {% if invalid_replies and error_message %} + You already created the following output in a previous attempt: {{ invalid_replies }} + However, this doesn't comply with the format requirements from above and triggered this Python exception: {{ error_message }} + Correct the output and try again. Just return the corrected output without any extra explanations. 
+ {% endif %} + """ + prompt_builder = PromptBuilder( + template=template, required_variables=["comment", "invalid_replies", "error_message"] + ) + + @component + class FakeOutputValidator: + @component.output_types( + valid_replies=List[str], invalid_replies=Optional[List[str]], error_message=Optional[str] + ) + def run(self, replies: List[str]): + if not getattr(self, "called", False): + self.called = True + return {"invalid_replies": ["This is an invalid reply"], "error_message": "this is an error message"} + return {"valid_replies": replies} + + @component + class FakeGenerator: + @component.output_types(replies=List[str]) + def run(self, prompt: str): + return {"replies": ["This is a valid reply"]} + + llm = FakeGenerator() + validator = FakeOutputValidator() + + pipeline = Pipeline(max_runs_per_component=1) + pipeline.add_component("prompt_builder", prompt_builder) + + pipeline.add_component("llm", llm) + pipeline.add_component("output_validator", validator) + + pipeline.connect("prompt_builder.prompt", "llm.prompt") + pipeline.connect("llm.replies", "output_validator.replies") + pipeline.connect("output_validator.invalid_replies", "prompt_builder.invalid_replies") + + pipeline.connect("output_validator.error_message", "prompt_builder.error_message") + + comment = "I loved the quality of the meal but the courier was rude" + return (pipeline, [PipelineRunData(inputs={"prompt_builder": {"comment": comment}})]) + + +@given("a pipeline that has a loop in the middle", target_fixture="pipeline_data") +def that_has_a_loop_in_the_middle(): + @component + class FakeGenerator: + @component.output_types(replies=List[str]) + def run(self, prompt: str): + replies = [] + if getattr(self, "first_run", True): + self.first_run = False + replies.append("No answer") + else: + replies.append("42") + return {"replies": replies} + + @component + class PromptCleaner: + @component.output_types(clean_prompt=str) + def run(self, prompt: str): + return {"clean_prompt": prompt.strip()} + + routes = [ + { + "condition": "{{ 'No answer' in replies }}", + "output": "{{ replies }}", + "output_name": "invalid_replies", + "output_type": List[str], + }, + { + "condition": "{{ 'No answer' not in replies }}", + "output": "{{ replies }}", + "output_name": "valid_replies", + "output_type": List[str], + }, + ] + + pipeline = Pipeline(max_runs_per_component=20) + pipeline.add_component("prompt_cleaner", PromptCleaner()) + pipeline.add_component("prompt_builder", PromptBuilder(template="", variables=["question", "invalid_replies"])) + pipeline.add_component("llm", FakeGenerator()) + pipeline.add_component("answer_validator", ConditionalRouter(routes=routes)) + pipeline.add_component("answer_builder", AnswerBuilder()) + + pipeline.connect("prompt_cleaner.clean_prompt", "prompt_builder.template") + pipeline.connect("prompt_builder.prompt", "llm.prompt") + pipeline.connect("llm.replies", "answer_validator.replies") + pipeline.connect("answer_validator.invalid_replies", "prompt_builder.invalid_replies") + pipeline.connect("answer_validator.valid_replies", "answer_builder.replies") + + question = "What is the answer?" 
+ return ( + pipeline, + [ + PipelineRunData( + inputs={ + "prompt_cleaner": {"prompt": "Random template"}, + "prompt_builder": {"question": question}, + "answer_builder": {"query": question}, + }, + expected_outputs={ + "answer_builder": {"answers": [GeneratedAnswer(data="42", query=question, documents=[])]} + }, + expected_run_order=[ + "prompt_cleaner", + "prompt_builder", + "llm", + "answer_validator", + "prompt_builder", + "llm", + "answer_validator", + "answer_builder", + ], + ) + ], + ) + + +@given("a pipeline that has variadic component that receives a conditional input", target_fixture="pipeline_data") +def that_has_variadic_component_that_receives_a_conditional_input(): + pipe = Pipeline(max_runs_per_component=1) + routes = [ + { + "condition": "{{ documents|length > 1 }}", + "output": "{{ documents }}", + "output_name": "long", + "output_type": List[Document], + }, + { + "condition": "{{ documents|length <= 1 }}", + "output": "{{ documents }}", + "output_name": "short", + "output_type": List[Document], + }, + ] + + @component + class NoOp: + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]): + return {"documents": documents} + + @component + class CommaSplitter: + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]): + res = [] + current_id = 0 + for doc in documents: + for split in doc.content.split(","): + res.append(Document(content=split, id=str(current_id))) + current_id += 1 + return {"documents": res} + + pipe.add_component("conditional_router", ConditionalRouter(routes, unsafe=True)) + pipe.add_component( + "empty_lines_cleaner", DocumentCleaner(remove_empty_lines=True, remove_extra_whitespaces=False, keep_id=True) + ) + pipe.add_component("comma_splitter", CommaSplitter()) + pipe.add_component("document_cleaner", DocumentCleaner(keep_id=True)) + pipe.add_component("document_joiner", DocumentJoiner()) + + pipe.add_component("noop2", NoOp()) + pipe.add_component("noop3", NoOp()) + + pipe.connect("noop2", "noop3") + pipe.connect("noop3", "conditional_router") + + pipe.connect("conditional_router.long", "empty_lines_cleaner") + pipe.connect("empty_lines_cleaner", "document_joiner") + + pipe.connect("comma_splitter", "document_cleaner") + pipe.connect("document_cleaner", "document_joiner") + pipe.connect("comma_splitter", "document_joiner") + + document = Document( + id="1000", content="This document has so many, sentences. Like this one, or this one. Or even this other one." + ) + + return pipe, [ + PipelineRunData( + inputs={"noop2": {"documents": [document]}, "comma_splitter": {"documents": [document]}}, + expected_outputs={ + "conditional_router": { + "short": [ + Document( + id="1000", + content="This document has so many, sentences. Like this one, or this one. Or even this other one.", + ) + ] + }, + "document_joiner": { + "documents": [ + Document(id="0", content="This document has so many"), + Document(id="1", content=" sentences. Like this one"), + Document(id="2", content=" or this one. Or even this other one."), + ] + }, + }, + expected_run_order=[ + "comma_splitter", + "noop2", + "document_cleaner", + "noop3", + "conditional_router", + "document_joiner", + ], + ), + PipelineRunData( + inputs={ + "noop2": {"documents": [document, document]}, + "comma_splitter": {"documents": [document, document]}, + }, + expected_outputs={ + "document_joiner": { + "documents": [ + Document(id="0", content="This document has so many"), + Document(id="1", content=" sentences. 
Like this one"), + Document(id="2", content=" or this one. Or even this other one."), + Document(id="3", content="This document has so many"), + Document(id="4", content=" sentences. Like this one"), + Document(id="5", content=" or this one. Or even this other one."), + Document( + id="1000", + content="This document has so many, sentences. Like this one, or this one. Or even this other one.", + ), + ] + } + }, + expected_run_order=[ + "comma_splitter", + "noop2", + "document_cleaner", + "noop3", + "conditional_router", + "empty_lines_cleaner", + "document_joiner", + ], + ), + ] diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 3a07c049fc..1cd57a5b5b 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -11,7 +11,7 @@ from haystack.components.builders import PromptBuilder, AnswerBuilder from haystack.components.joiners import BranchJoiner from haystack.core.component import component -from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic +from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic, _empty from haystack.core.errors import DeserializationError, PipelineConnectError, PipelineDrawingError, PipelineError from haystack.core.pipeline import Pipeline, PredefinedPipeline from haystack.core.pipeline.base import ( @@ -788,43 +788,7 @@ def test__init_graph(self): for node in pipe.graph.nodes: assert pipe.graph.nodes[node]["visits"] == 0 - def test__init_run_queue(self): - ComponentWithVariadic = component_class( - "ComponentWithVariadic", input_types={"in": Variadic[int]}, output_types={"out": int} - ) - ComponentWithNoInputs = component_class("ComponentWithNoInputs", input_types={}, output_types={"out": int}) - ComponentWithSingleInput = component_class( - "ComponentWithSingleInput", input_types={"in": int}, output_types={"out": int} - ) - ComponentWithMultipleInputs = component_class( - "ComponentWithMultipleInputs", input_types={"in1": int, "in2": int}, output_types={"out": int} - ) - - pipe = Pipeline() - pipe.add_component("with_variadic", ComponentWithVariadic()) - pipe.add_component("with_no_inputs", ComponentWithNoInputs()) - pipe.add_component("with_single_input", ComponentWithSingleInput()) - pipe.add_component("another_with_single_input", ComponentWithSingleInput()) - pipe.add_component("yet_another_with_single_input", ComponentWithSingleInput()) - pipe.add_component("with_multiple_inputs", ComponentWithMultipleInputs()) - - pipe.connect("yet_another_with_single_input.out", "with_variadic.in") - pipe.connect("with_no_inputs.out", "with_variadic.in") - pipe.connect("with_single_input.out", "another_with_single_input.in") - pipe.connect("another_with_single_input.out", "with_multiple_inputs.in1") - pipe.connect("with_multiple_inputs.out", "with_variadic.in") - - data = {"yet_another_with_single_input": {"in": 1}} - run_queue = pipe._init_run_queue(data) - assert len(run_queue) == 6 - assert run_queue[0][0] == "with_no_inputs" - assert run_queue[1][0] == "with_single_input" - assert run_queue[2][0] == "yet_another_with_single_input" - assert run_queue[3][0] == "another_with_single_input" - assert run_queue[4][0] == "with_multiple_inputs" - assert run_queue[5][0] == "with_variadic" - - def test__init_inputs_state(self): + def test__normalize_varidiac_input_data(self): pipe = Pipeline() template = """ Answer the following questions: @@ -838,13 +802,12 @@ def test__init_inputs_state(self): "branch_joiner": {"value": 1}, 
"not_a_component": "some input data", } - res = pipe._init_inputs_state(data) + res = pipe._normalize_varidiac_input_data(data) assert res == { "prompt_builder": {"questions": ["What is the capital of Italy?", "What is the capital of France?"]}, "branch_joiner": {"value": [1]}, "not_a_component": "some input data", } - assert id(questions) != id(res["prompt_builder"]["questions"]) def test__prepare_component_input_data(self): MockComponent = component_class("MockComponent", input_types={"x": List[str], "y": str}) @@ -1165,6 +1128,30 @@ def test__find_components_that_will_receive_no_input(self): ) assert res == set() + multiple_outputs = component_class("MultipleOutputs", output_types={"first": int, "second": int})() + + def custom_init(self): + component.set_input_type(self, "first", Optional[int], 1) + component.set_input_type(self, "second", Optional[int], 2) + + multiple_optional_inputs = component_class("MultipleOptionalInputs", extra_fields={"__init__": custom_init})() + + pipe = Pipeline() + pipe.add_component("multiple_outputs", multiple_outputs) + pipe.add_component("multiple_optional_inputs", multiple_optional_inputs) + pipe.connect("multiple_outputs.second", "multiple_optional_inputs.first") + + res = pipe._find_components_that_will_receive_no_input("multiple_outputs", {"first": 1}, {}) + assert res == {("multiple_optional_inputs", multiple_optional_inputs)} + + res = pipe._find_components_that_will_receive_no_input( + "multiple_outputs", {"first": 1}, {"multiple_optional_inputs": {"second": 200}} + ) + assert res == set() + + res = pipe._find_components_that_will_receive_no_input("multiple_outputs", {"second": 1}, {}) + assert res == set() + def test__distribute_output(self): document_builder = component_class( "DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document} @@ -1184,12 +1171,20 @@ def test__distribute_output(self): inputs = {"document_builder": {"text": "some text"}} run_queue = [] waiting_queue = [("document_joiner", document_joiner)] + receivers = [ + ( + "document_cleaner", + OutputSocket("doc", Document, ["document_cleaner"]), + InputSocket("doc", Document, _empty, ["document_builder"]), + ), + ( + "document_joiner", + OutputSocket("another_doc", Document, ["document_joiner"]), + InputSocket("docs", Variadic[Document], _empty, ["document_builder"]), + ), + ] res = pipe._distribute_output( - "document_builder", - {"doc": Document("some text"), "another_doc": Document()}, - inputs, - run_queue, - waiting_queue, + receivers, {"doc": Document("some text"), "another_doc": Document()}, inputs, run_queue, waiting_queue ) assert res == {} @@ -1524,3 +1519,65 @@ def test__is_lazy_variadic(self): assert not _is_lazy_variadic(NonVariadic()) assert _is_lazy_variadic(VariadicNonGreedyVariadic()) assert not _is_lazy_variadic(NonVariadicAndGreedyVariadic()) + + def test__find_receivers_from(self): + sentence_builder = component_class( + "SentenceBuilder", input_types={"words": List[str]}, output_types={"text": str} + )() + document_builder = component_class( + "DocumentBuilder", input_types={"text": str}, output_types={"doc": Document} + )() + conditional_document_builder = component_class( + "ConditionalDocumentBuilder", output_types={"doc": Document, "noop": None} + )() + + document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})() + + pipe = Pipeline() + pipe.add_component("sentence_builder", sentence_builder) + pipe.add_component("document_builder", document_builder) + 
pipe.add_component("document_joiner", document_joiner) + pipe.add_component("conditional_document_builder", conditional_document_builder) + pipe.connect("sentence_builder.text", "document_builder.text") + pipe.connect("document_builder.doc", "document_joiner.docs") + pipe.connect("conditional_document_builder.doc", "document_joiner.docs") + + res = pipe._find_receivers_from("sentence_builder") + assert res == [ + ( + "document_builder", + OutputSocket(name="text", type=str, receivers=["document_builder"]), + InputSocket(name="text", type=str, default_value=_empty, senders=["sentence_builder"]), + ) + ] + + res = pipe._find_receivers_from("document_builder") + assert res == [ + ( + "document_joiner", + OutputSocket(name="doc", type=Document, receivers=["document_joiner"]), + InputSocket( + name="docs", + type=Variadic[Document], + default_value=_empty, + senders=["document_builder", "conditional_document_builder"], + ), + ) + ] + + res = pipe._find_receivers_from("document_joiner") + assert res == [] + + res = pipe._find_receivers_from("conditional_document_builder") + assert res == [ + ( + "document_joiner", + OutputSocket(name="doc", type=Document, receivers=["document_joiner"]), + InputSocket( + name="docs", + type=Variadic[Document], + default_value=_empty, + senders=["document_builder", "conditional_document_builder"], + ), + ) + ] diff --git a/test/core/pipeline/test_validation_pipeline_io.py b/test/core/pipeline/test_validation_pipeline_io.py index f9160799fe..5ca1c08098 100644 --- a/test/core/pipeline/test_validation_pipeline_io.py +++ b/test/core/pipeline/test_validation_pipeline_io.py @@ -6,9 +6,9 @@ import pytest from haystack.core.component.types import InputSocket, OutputSocket, Variadic -from haystack.core.errors import PipelineValidationError from haystack.core.pipeline import Pipeline from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs +from haystack.testing.factory import component_class from haystack.testing.sample_components import AddFixedValue, Double, Parity, Sum @@ -119,10 +119,16 @@ def test_find_pipeline_some_outputs_different_components(): def test_validate_pipeline_input_pipeline_with_no_inputs(): pipe = Pipeline() - pipe.add_component("comp1", Double()) - pipe.add_component("comp2", Double()) - pipe.connect("comp1", "comp2") - pipe.connect("comp2", "comp1") + NoInputs = component_class("NoInputs", input_types={}, output={"value": 10}) + pipe.add_component("no_inputs", NoInputs()) + res = pipe.run({}) + assert res == {"no_inputs": {"value": 10}} + + +def test_validate_pipeline_input_pipeline_with_no_inputs_no_outputs(): + pipe = Pipeline() + NoIO = component_class("NoIO", input_types={}, output={}) + pipe.add_component("no_inputs", NoIO()) res = pipe.run({}) assert res == {}