fix: Lints
shadeMe committed Nov 20, 2024
1 parent 8b11072 commit 22fcc4b
Showing 1 changed file with 92 additions and 29 deletions.
121 changes: 92 additions & 29 deletions haystack_experimental/core/pipeline/async_pipeline.py
@@ -60,7 +60,9 @@ def __init__(
 
         # We only need one thread as we'll immediately block after launching it.
         self.executor = (
-            ThreadPoolExecutor(thread_name_prefix=f"async-pipeline-executor-{id(self)}", max_workers=1)
+            ThreadPoolExecutor(
+                thread_name_prefix=f"async-pipeline-executor-{id(self)}", max_workers=1
+            )
             if async_executor is None
             else async_executor
         )
@@ -88,17 +90,27 @@ async def _run_component(
             tags={
                 "haystack.component.name": name,
                 "haystack.component.type": instance.__class__.__name__,
-                "haystack.component.input_types": {k: type(v).__name__ for k, v in inputs.items()},
+                "haystack.component.input_types": {
+                    k: type(v).__name__ for k, v in inputs.items()
+                },
                 "haystack.component.input_spec": {
                     key: {
-                        "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
+                        "type": (
+                            value.type.__name__
+                            if isinstance(value.type, type)
+                            else str(value.type)
+                        ),
                         "senders": value.senders,
                     }
                     for key, value in instance.__haystack_input__._sockets_dict.items()  # type: ignore
                 },
                 "haystack.component.output_spec": {
                     key: {
-                        "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
+                        "type": (
+                            value.type.__name__
+                            if isinstance(value.type, type)
+                            else str(value.type)
+                        ),
                         "receivers": value.receivers,
                     }
                     for key, value in instance.__haystack_output__._sockets_dict.items()  # type: ignore
@@ -113,14 +125,18 @@ async def _run_component(
 
             res: Dict[str, Any]
             if instance.__haystack_supports_async__:  # type: ignore
-                logger.info("Running async component {component_name}", component_name=name)
+                logger.info(
+                    "Running async component {component_name}", component_name=name
+                )
                 res = await instance.run_async(**inputs)  # type: ignore
             else:
                 logger.info(
                     "Running sync component {component_name} on executor",
                     component_name=name,
                 )
-                res = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: instance.run(**inputs))
+                res = await asyncio.get_event_loop().run_in_executor(
+                    self.executor, lambda: instance.run(**inputs)
+                )
             self.graph.nodes[name]["visits"] += 1
 
             # After a Component that has variadic inputs is run, we need to reset the variadic inputs that were consumed
@@ -187,7 +203,9 @@ async def _run_subgraph(  # noqa: PLR0915, PLR0912
         while not cycle_received_inputs:
             # Here we run the Components
             name, comp = run_queue.pop(0)
-            if _is_lazy_variadic(comp) and not all(_is_lazy_variadic(comp) for _, comp in run_queue):
+            if _is_lazy_variadic(comp) and not all(
+                _is_lazy_variadic(comp) for _, comp in run_queue
+            ):
                 # We run Components with lazy variadic inputs only if there only Components with
                 # lazy variadic inputs left to run
                 _enqueue_waiting_component((name, comp), waiting_queue)
@@ -199,7 +217,9 @@ async def _run_subgraph(  # noqa: PLR0915, PLR0912
                 msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'"
                 raise PipelineMaxComponentRuns(msg)
 
-            res: Dict[str, Any] = await self._run_component(name, components_inputs[name])
+            res: Dict[str, Any] = await self._run_component(
+                name, components_inputs[name]
+            )
             yield {name: deepcopy(res)}, False
 
             # Delete the inputs that were consumed by the Component and are not received from
@@ -238,12 +258,18 @@ async def _run_subgraph(  # noqa: PLR0915, PLR0912
                 # We manage to run this component that was in the waiting list, we can remove it.
                 # This happens when a component was put in the waiting list but we reached it from another edge.
                 _dequeue_waiting_component((name, comp), waiting_queue)
-            for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs):
+            for pair in self._find_components_that_will_receive_no_input(
+                name, res, components_inputs
+            ):
                 _dequeue_component(pair, run_queue, waiting_queue)
 
-            receivers = [item for item in self._find_receivers_from(name) if item[0] in cycle]
+            receivers = [
+                item for item in self._find_receivers_from(name) if item[0] in cycle
+            ]
 
-            res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue)
+            res = self._distribute_output(
+                receivers, res, components_inputs, run_queue, waiting_queue
+            )
 
             # We treat a cycle as a completely independent graph, so we keep track of output
             # that is not sent inside the cycle.
@@ -274,21 +300,31 @@ async def _run_subgraph(  # noqa: PLR0915, PLR0912
                         warn(RuntimeWarning(msg))
                         break
 
-                    (name, comp) = self._find_next_runnable_lazy_variadic_or_default_component(waiting_queue)
+                    (name, comp) = (
+                        self._find_next_runnable_lazy_variadic_or_default_component(
+                            waiting_queue
+                        )
+                    )
                     _add_missing_input_defaults(name, comp, components_inputs)
                     _enqueue_component((name, comp), run_queue, waiting_queue)
                     continue
 
-                before_last_waiting_queue = last_waiting_queue.copy() if last_waiting_queue is not None else None
+                before_last_waiting_queue = (
+                    last_waiting_queue.copy()
+                    if last_waiting_queue is not None
+                    else None
+                )
                 last_waiting_queue = {item[0] for item in waiting_queue}
 
-                (name, comp) = self._find_next_runnable_component(components_inputs, waiting_queue)
+                (name, comp) = self._find_next_runnable_component(
+                    components_inputs, waiting_queue
+                )
                 _add_missing_input_defaults(name, comp, components_inputs)
                 _enqueue_component((name, comp), run_queue, waiting_queue)
 
         yield subgraph_outputs, True
 
-    async def run(  # noqa: PLR0915
+    async def run(  # noqa: PLR0915, PLR0912
         self,
         data: Dict[str, Any],
     ) -> AsyncIterator[Dict[str, Any]]:
@@ -368,7 +404,9 @@ def run(self, word: str):
         self._validate_input(data)
 
         # Normalize the input data
-        components_inputs: Dict[str, Dict[str, Any]] = self._normalize_varidiac_input_data(data)
+        components_inputs: Dict[str, Dict[str, Any]] = (
+            self._normalize_varidiac_input_data(data)
+        )
 
         # These variables are used to detect when we're stuck in a loop.
         # Stuck loops can happen when one or more components are waiting for input but
@@ -391,7 +429,9 @@ def run(self, word: str):
 
         # Break cycles in case there are, this is a noop if no cycle is found.
        # This will raise if a cycle can't be broken.
-        graph_without_cycles, components_in_cycles = self._break_supported_cycles_in_graph()
+        graph_without_cycles, components_in_cycles = (
+            self._break_supported_cycles_in_graph()
+        )
 
         run_queue: List[Tuple[str, Component]] = []
         for node in nx.topological_sort(graph_without_cycles):
@@ -426,14 +466,16 @@ def run(self, word: str):
             while len(run_queue) > 0:
                 name, comp = run_queue.pop(0)
 
-                if _is_lazy_variadic(comp) and not all(_is_lazy_variadic(comp) for _, comp in run_queue):
+                if _is_lazy_variadic(comp) and not all(
+                    _is_lazy_variadic(comp) for _, comp in run_queue
+                ):
                     # We run Components with lazy variadic inputs only if there only Components with
                     # lazy variadic inputs left to run
                     _enqueue_waiting_component((name, comp), waiting_queue)
                     continue
-                if self._component_has_enough_inputs_to_run(name, components_inputs) and components_in_cycles.get(
-                    name, []
-                ):
+                if self._component_has_enough_inputs_to_run(
+                    name, components_inputs
+                ) and components_in_cycles.get(name, []):
                     cycles = components_in_cycles.get(name, [])
 
                     # This component is part of one or more cycles, let's get the first one and run it.
@@ -474,13 +516,17 @@ def run(self, word: str):
                     msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'"
                     raise PipelineMaxComponentRuns(msg)
 
-                res: Dict[str, Any] = await self._run_component(name, components_inputs[name], parent_span=span)
+                res: Dict[str, Any] = await self._run_component(
+                    name, components_inputs[name], parent_span=span
+                )
                 yield {name: deepcopy(res)}
 
                 # Delete the inputs that were consumed by the Component and are not received from the user
                 sockets = list(components_inputs[name].keys())
                 for socket_name in sockets:
-                    senders = comp.__haystack_input__._sockets_dict[socket_name].senders
+                    senders = comp.__haystack_input__._sockets_dict[
+                        socket_name
+                    ].senders
                     if senders:
                         # Delete all inputs that are received from other Components
                         del components_inputs[name][socket_name]
@@ -494,10 +540,14 @@ def run(self, word: str):
                     # This happens when a component was put in the waiting list but we reached it from another edge.
                     _dequeue_waiting_component((name, comp), waiting_queue)
 
-                for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs):
+                for pair in self._find_components_that_will_receive_no_input(
+                    name, res, components_inputs
+                ):
                     _dequeue_component(pair, run_queue, waiting_queue)
                 receivers = self._find_receivers_from(name)
-                res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue)
+                res = self._distribute_output(
+                    receivers, res, components_inputs, run_queue, waiting_queue
+                )
 
                 if len(res) > 0:
                     final_outputs[name] = res
@@ -523,15 +573,25 @@ def run(self, word: str):
                             warn(RuntimeWarning(msg))
                             break
 
-                        (name, comp) = self._find_next_runnable_lazy_variadic_or_default_component(waiting_queue)
+                        (name, comp) = (
+                            self._find_next_runnable_lazy_variadic_or_default_component(
+                                waiting_queue
+                            )
+                        )
                         _add_missing_input_defaults(name, comp, components_inputs)
                         _enqueue_component((name, comp), run_queue, waiting_queue)
                         continue
 
-                    before_last_waiting_queue = last_waiting_queue.copy() if last_waiting_queue is not None else None
+                    before_last_waiting_queue = (
+                        last_waiting_queue.copy()
+                        if last_waiting_queue is not None
+                        else None
+                    )
                     last_waiting_queue = {item[0] for item in waiting_queue}
 
-                    (name, comp) = self._find_next_runnable_component(components_inputs, waiting_queue)
+                    (name, comp) = self._find_next_runnable_component(
+                        components_inputs, waiting_queue
+                    )
                     _add_missing_input_defaults(name, comp, components_inputs)
                     _enqueue_component((name, comp), run_queue, waiting_queue)
 
@@ -567,7 +627,10 @@ async def run_async_pipeline(
     outputs = [x async for x in pipeline.run(data)]
 
     intermediate_outputs = {
-        k: v for d in outputs[:-1] for k, v in d.items() if include_outputs_from is None or k in include_outputs_from
+        k: v
+        for d in outputs[:-1]
+        for k, v in d.items()
+        if include_outputs_from is None or k in include_outputs_from
    }
    final_output = outputs[-1]
 
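Context for the last hunk above: AsyncPipeline.run is an async generator that yields one {component_name: outputs} dictionary per executed component and the aggregated final outputs as its last item, which run_async_pipeline collects into intermediate_outputs and final_output. A minimal consumption sketch under that reading; the component name "greeter", its "word" input, and the keyword-argument order of run_async_pipeline are assumptions here, and pipeline construction is elided:

import asyncio

from haystack_experimental.core.pipeline.async_pipeline import (
    AsyncPipeline,
    run_async_pipeline,
)


async def main() -> None:
    # NOTE: add_component()/connect() calls are elided; the component name
    # "greeter" and its "word" input below are illustrative only.
    pipeline = AsyncPipeline()

    # Stream results: run() yields one {component_name: outputs} dict per
    # executed component, and the aggregated final outputs as the last item.
    async for partial_output in pipeline.run({"greeter": {"word": "world"}}):
        print(partial_output)

    # Or collect everything in one call, as run_async_pipeline() does above.
    results = await run_async_pipeline(
        pipeline,
        data={"greeter": {"word": "world"}},
        include_outputs_from={"greeter"},
    )
    print(results)


asyncio.run(main())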
