Skip to content

Commit

Permalink
Remove unnecessary Connection class (deepset-ai#6842)
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza authored Jan 29, 2024
1 parent b1ec32d commit 9211f53
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 283 deletions.
174 changes: 0 additions & 174 deletions haystack/core/component/connection.py

This file was deleted.

157 changes: 123 additions & 34 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,20 @@
#
# SPDX-License-Identifier: Apache-2.0
import importlib
import itertools
import logging
from collections import defaultdict
from copy import copy
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union

import networkx # type:ignore

from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.component.connection import Connection, parse_connect_string
from haystack.core.errors import PipelineConnectError, PipelineError, PipelineRuntimeError, PipelineValidationError
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
from haystack.core.pipeline.draw.draw import RenderingEngines, _draw
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name
from haystack.core.type_utils import _type_name, _types_are_compatible

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -52,8 +51,6 @@ def __init__(
self.metadata = metadata or {}
self.max_loops_allowed = max_loops_allowed
self.graph = networkx.MultiDiGraph()
self._connections: List[Connection] = []
self._mandatory_connections: Dict[str, List[Connection]] = defaultdict(list)
self._debug: Dict[int, Dict[str, Any]] = {}
self._debug_path = Path(debug_path)

Expand Down Expand Up @@ -264,42 +261,99 @@ def connect(self, connect_from: str, connect_to: str) -> None:
[receiver_socket] if receiver_socket else list(to_sockets.values())
)

connection = Connection.from_list_of_sockets(
sender, sender_socket_candidates, receiver, receiver_socket_candidates
# Find all possible connections between these two components
possible_connections = [
(sender_sock, receiver_sock)
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates)
if _types_are_compatible(sender_sock.type, receiver_sock.type)
]

# We need this status for error messages, since we might need it in multiple places we calculate it here
status = _connections_status(
sender_node=sender,
sender_sockets=sender_socket_candidates,
receiver_node=receiver,
receiver_sockets=receiver_socket_candidates,
)

if not possible_connections:
# There's no possible connection between these two components
if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
msg = (
f"Cannot connect '{sender}.{sender_socket_candidates[0].name}' with '{receiver}.{receiver_socket_candidates[0].name}': "
f"their declared input and output types do not match.\n{status}"
)
else:
msg = f"Cannot connect '{sender}' with '{receiver}': " f"no matching connections available.\n{status}"
raise PipelineConnectError(msg)

if len(possible_connections) == 1:
# There's only one possible connection, use it
sender_socket = possible_connections[0][0]
receiver_socket = possible_connections[0][1]

if len(possible_connections) > 1:
# There are multiple possible connection, let's try to match them by name
name_matches = [
(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
]
if len(name_matches) != 1:
# There's are either no matches or more than one, we can't pick one reliably
msg = (
f"Cannot connect '{sender}' with '{receiver}': more than one connection is possible "
"between these components. Please specify the connection name, like: "
f"pipeline.connect('{sender}.{possible_connections[0][0].name}', "
f"'{receiver}.{possible_connections[0][1].name}').\n{status}"
)
raise PipelineConnectError(msg)

# Get the only possible match
sender_socket = name_matches[0][0]
receiver_socket = name_matches[0][1]

# Connection must be valid on both sender/receiver sides
if (
not connection.sender_socket
or not connection.receiver_socket
or not connection.sender
or not connection.receiver
):
raise PipelineConnectError(f"Connection must have both sender and receiver: {connection}")

# Create the connection
logger.debug(
"Connecting '%s.%s' to '%s.%s'",
connection.sender,
connection.sender_socket.name,
connection.receiver,
connection.receiver_socket.name,
)
if not sender_socket or not receiver_socket or not sender or not receiver:
if sender and sender_socket:
sender_repr = f"{sender}.{sender_socket.name} ({_type_name(sender_socket.type)})"
else:
sender_repr = "input needed"

if receiver and receiver_socket:
receiver_repr = f"({_type_name(receiver_socket.type)}) {receiver}.{receiver_socket.name}"
else:
receiver_repr = "output"
msg = f"Connection must have both sender and receiver: {sender_repr} -> {receiver_repr}"
raise PipelineConnectError(msg)

logger.debug("Connecting '%s.%s' to '%s.%s'", sender, sender_socket.name, receiver, receiver_socket.name)

if receiver in sender_socket.receivers and sender in receiver_socket.senders:
# This is already connected, nothing to do
return

if receiver_socket.senders and not receiver_socket.is_variadic:
# Only variadic input sockets can receive from multiple senders
msg = (
f"Cannot connect '{sender}.{sender_socket.name}' with '{receiver}.{receiver_socket.name}': "
f"{receiver}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
)
raise PipelineConnectError(msg)

# Update the sockets with the new connection
sender_socket.receivers.append(receiver)
receiver_socket.senders.append(sender)

# Create the new connection
self.graph.add_edge(
connection.sender,
connection.receiver,
key=f"{connection.sender_socket.name}/{connection.receiver_socket.name}",
conn_type=_type_name(connection.sender_socket.type),
from_socket=connection.sender_socket,
to_socket=connection.receiver_socket,
mandatory=connection.is_mandatory,
sender,
receiver,
key=f"{sender_socket.name}/{receiver_socket.name}",
conn_type=_type_name(sender_socket.type),
from_socket=sender_socket,
to_socket=receiver_socket,
mandatory=receiver_socket.is_mandatory,
)

self._connections.append(connection)
if connection.is_mandatory:
self._mandatory_connections[connection.receiver].append(connection)

def get_component(self, name: str) -> Component:
"""
Returns an instance of a component.
Expand Down Expand Up @@ -648,3 +702,38 @@ def run( # noqa: C901, PLR0912 pylint: disable=too-many-branches
to_run.append((name, comp))

return final_outputs


def _connections_status(
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
):
"""
Lists the status of the sockets, for error messages.
"""
sender_sockets_entries = []
for sender_socket in sender_sockets:
sender_sockets_entries.append(f" - {sender_socket.name}: {_type_name(sender_socket.type)}")
sender_sockets_list = "\n".join(sender_sockets_entries)

receiver_sockets_entries = []
for receiver_socket in receiver_sockets:
if receiver_socket.senders:
sender_status = f"sent by {','.join(receiver_socket.senders)}"
else:
sender_status = "available"
receiver_sockets_entries.append(
f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
)
receiver_sockets_list = "\n".join(receiver_sockets_entries)

return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"


def parse_connect_string(connection: str) -> Tuple[str, Optional[str]]:
"""
Returns component-connection pairs from a connect_to/from string
"""
if "." in connection:
split_str = connection.split(".", maxsplit=1)
return (split_str[0], split_str[1])
return connection, None
Loading

0 comments on commit 9211f53

Please sign in to comment.