From 22fcc4b48c483fdc760939b02ae76280172496bc Mon Sep 17 00:00:00 2001 From: shadeMe Date: Wed, 20 Nov 2024 12:30:09 +0100 Subject: [PATCH] fix: Lints --- .../core/pipeline/async_pipeline.py | 121 +++++++++++++----- 1 file changed, 92 insertions(+), 29 deletions(-) diff --git a/haystack_experimental/core/pipeline/async_pipeline.py b/haystack_experimental/core/pipeline/async_pipeline.py index 74244dea..4204e407 100644 --- a/haystack_experimental/core/pipeline/async_pipeline.py +++ b/haystack_experimental/core/pipeline/async_pipeline.py @@ -60,7 +60,9 @@ def __init__( # We only need one thread as we'll immediately block after launching it. self.executor = ( - ThreadPoolExecutor(thread_name_prefix=f"async-pipeline-executor-{id(self)}", max_workers=1) + ThreadPoolExecutor( + thread_name_prefix=f"async-pipeline-executor-{id(self)}", max_workers=1 + ) if async_executor is None else async_executor ) @@ -88,17 +90,27 @@ async def _run_component( tags={ "haystack.component.name": name, "haystack.component.type": instance.__class__.__name__, - "haystack.component.input_types": {k: type(v).__name__ for k, v in inputs.items()}, + "haystack.component.input_types": { + k: type(v).__name__ for k, v in inputs.items() + }, "haystack.component.input_spec": { key: { - "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)), + "type": ( + value.type.__name__ + if isinstance(value.type, type) + else str(value.type) + ), "senders": value.senders, } for key, value in instance.__haystack_input__._sockets_dict.items() # type: ignore }, "haystack.component.output_spec": { key: { - "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)), + "type": ( + value.type.__name__ + if isinstance(value.type, type) + else str(value.type) + ), "receivers": value.receivers, } for key, value in instance.__haystack_output__._sockets_dict.items() # type: ignore @@ -113,14 +125,18 @@ async def _run_component( res: Dict[str, Any] if instance.__haystack_supports_async__: # type: ignore - logger.info("Running async component {component_name}", component_name=name) + logger.info( + "Running async component {component_name}", component_name=name + ) res = await instance.run_async(**inputs) # type: ignore else: logger.info( "Running sync component {component_name} on executor", component_name=name, ) - res = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: instance.run(**inputs)) + res = await asyncio.get_event_loop().run_in_executor( + self.executor, lambda: instance.run(**inputs) + ) self.graph.nodes[name]["visits"] += 1 # After a Component that has variadic inputs is run, we need to reset the variadic inputs that were consumed @@ -187,7 +203,9 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912 while not cycle_received_inputs: # Here we run the Components name, comp = run_queue.pop(0) - if _is_lazy_variadic(comp) and not all(_is_lazy_variadic(comp) for _, comp in run_queue): + if _is_lazy_variadic(comp) and not all( + _is_lazy_variadic(comp) for _, comp in run_queue + ): # We run Components with lazy variadic inputs only if there only Components with # lazy variadic inputs left to run _enqueue_waiting_component((name, comp), waiting_queue) @@ -199,7 +217,9 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912 msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'" raise PipelineMaxComponentRuns(msg) - res: Dict[str, Any] = await self._run_component(name, components_inputs[name]) + res: Dict[str, Any] = await self._run_component( + name, components_inputs[name] + ) yield {name: deepcopy(res)}, False # Delete the inputs that were consumed by the Component and are not received from @@ -238,12 +258,18 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912 # We manage to run this component that was in the waiting list, we can remove it. # This happens when a component was put in the waiting list but we reached it from another edge. _dequeue_waiting_component((name, comp), waiting_queue) - for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs): + for pair in self._find_components_that_will_receive_no_input( + name, res, components_inputs + ): _dequeue_component(pair, run_queue, waiting_queue) - receivers = [item for item in self._find_receivers_from(name) if item[0] in cycle] + receivers = [ + item for item in self._find_receivers_from(name) if item[0] in cycle + ] - res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue) + res = self._distribute_output( + receivers, res, components_inputs, run_queue, waiting_queue + ) # We treat a cycle as a completely independent graph, so we keep track of output # that is not sent inside the cycle. @@ -274,21 +300,31 @@ async def _run_subgraph( # noqa: PLR0915, PLR0912 warn(RuntimeWarning(msg)) break - (name, comp) = self._find_next_runnable_lazy_variadic_or_default_component(waiting_queue) + (name, comp) = ( + self._find_next_runnable_lazy_variadic_or_default_component( + waiting_queue + ) + ) _add_missing_input_defaults(name, comp, components_inputs) _enqueue_component((name, comp), run_queue, waiting_queue) continue - before_last_waiting_queue = last_waiting_queue.copy() if last_waiting_queue is not None else None + before_last_waiting_queue = ( + last_waiting_queue.copy() + if last_waiting_queue is not None + else None + ) last_waiting_queue = {item[0] for item in waiting_queue} - (name, comp) = self._find_next_runnable_component(components_inputs, waiting_queue) + (name, comp) = self._find_next_runnable_component( + components_inputs, waiting_queue + ) _add_missing_input_defaults(name, comp, components_inputs) _enqueue_component((name, comp), run_queue, waiting_queue) yield subgraph_outputs, True - async def run( # noqa: PLR0915 + async def run( # noqa: PLR0915, PLR0912 self, data: Dict[str, Any], ) -> AsyncIterator[Dict[str, Any]]: @@ -368,7 +404,9 @@ def run(self, word: str): self._validate_input(data) # Normalize the input data - components_inputs: Dict[str, Dict[str, Any]] = self._normalize_varidiac_input_data(data) + components_inputs: Dict[str, Dict[str, Any]] = ( + self._normalize_varidiac_input_data(data) + ) # These variables are used to detect when we're stuck in a loop. # Stuck loops can happen when one or more components are waiting for input but @@ -391,7 +429,9 @@ def run(self, word: str): # Break cycles in case there are, this is a noop if no cycle is found. # This will raise if a cycle can't be broken. - graph_without_cycles, components_in_cycles = self._break_supported_cycles_in_graph() + graph_without_cycles, components_in_cycles = ( + self._break_supported_cycles_in_graph() + ) run_queue: List[Tuple[str, Component]] = [] for node in nx.topological_sort(graph_without_cycles): @@ -426,14 +466,16 @@ def run(self, word: str): while len(run_queue) > 0: name, comp = run_queue.pop(0) - if _is_lazy_variadic(comp) and not all(_is_lazy_variadic(comp) for _, comp in run_queue): + if _is_lazy_variadic(comp) and not all( + _is_lazy_variadic(comp) for _, comp in run_queue + ): # We run Components with lazy variadic inputs only if there only Components with # lazy variadic inputs left to run _enqueue_waiting_component((name, comp), waiting_queue) continue - if self._component_has_enough_inputs_to_run(name, components_inputs) and components_in_cycles.get( - name, [] - ): + if self._component_has_enough_inputs_to_run( + name, components_inputs + ) and components_in_cycles.get(name, []): cycles = components_in_cycles.get(name, []) # This component is part of one or more cycles, let's get the first one and run it. @@ -474,13 +516,17 @@ def run(self, word: str): msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'" raise PipelineMaxComponentRuns(msg) - res: Dict[str, Any] = await self._run_component(name, components_inputs[name], parent_span=span) + res: Dict[str, Any] = await self._run_component( + name, components_inputs[name], parent_span=span + ) yield {name: deepcopy(res)} # Delete the inputs that were consumed by the Component and are not received from the user sockets = list(components_inputs[name].keys()) for socket_name in sockets: - senders = comp.__haystack_input__._sockets_dict[socket_name].senders + senders = comp.__haystack_input__._sockets_dict[ + socket_name + ].senders if senders: # Delete all inputs that are received from other Components del components_inputs[name][socket_name] @@ -494,10 +540,14 @@ def run(self, word: str): # This happens when a component was put in the waiting list but we reached it from another edge. _dequeue_waiting_component((name, comp), waiting_queue) - for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs): + for pair in self._find_components_that_will_receive_no_input( + name, res, components_inputs + ): _dequeue_component(pair, run_queue, waiting_queue) receivers = self._find_receivers_from(name) - res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue) + res = self._distribute_output( + receivers, res, components_inputs, run_queue, waiting_queue + ) if len(res) > 0: final_outputs[name] = res @@ -523,15 +573,25 @@ def run(self, word: str): warn(RuntimeWarning(msg)) break - (name, comp) = self._find_next_runnable_lazy_variadic_or_default_component(waiting_queue) + (name, comp) = ( + self._find_next_runnable_lazy_variadic_or_default_component( + waiting_queue + ) + ) _add_missing_input_defaults(name, comp, components_inputs) _enqueue_component((name, comp), run_queue, waiting_queue) continue - before_last_waiting_queue = last_waiting_queue.copy() if last_waiting_queue is not None else None + before_last_waiting_queue = ( + last_waiting_queue.copy() + if last_waiting_queue is not None + else None + ) last_waiting_queue = {item[0] for item in waiting_queue} - (name, comp) = self._find_next_runnable_component(components_inputs, waiting_queue) + (name, comp) = self._find_next_runnable_component( + components_inputs, waiting_queue + ) _add_missing_input_defaults(name, comp, components_inputs) _enqueue_component((name, comp), run_queue, waiting_queue) @@ -567,7 +627,10 @@ async def run_async_pipeline( outputs = [x async for x in pipeline.run(data)] intermediate_outputs = { - k: v for d in outputs[:-1] for k, v in d.items() if include_outputs_from is None or k in include_outputs_from + k: v + for d in outputs[:-1] + for k, v in d.items() + if include_outputs_from is None or k in include_outputs_from } final_output = outputs[-1]