Skip to content

Commit

Permalink
feat: Add route output type validation in ConditionalRouter (#8500)
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza authored Oct 29, 2024
1 parent 33675b4 commit 8a35e79
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 38 deletions.
116 changes: 96 additions & 20 deletions haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import ast
import contextlib
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union, get_args, get_origin
from warnings import warn

from jinja2 import Environment, TemplateSyntaxError, meta
Expand Down Expand Up @@ -107,7 +107,13 @@ class ConditionalRouter:
```
"""

def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None, unsafe: bool = False):
def __init__(
self,
routes: List[Dict],
custom_filters: Optional[Dict[str, Callable]] = None,
unsafe: bool = False,
validate_output_type: bool = False,
):
"""
Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.
Expand All @@ -127,10 +133,14 @@ def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callab
:param unsafe:
Enable execution of arbitrary code in the Jinja template.
This should only be used if you trust the source of the template as it can lead to remote code execution.
:param validate_output_type:
Enable validation of routes' output.
If a route output doesn't match the declared type, a ValueError is raised when the router runs.
"""
self.routes: List[dict] = routes
self.custom_filters = custom_filters or {}
self._unsafe = unsafe
self._validate_output_type = validate_output_type

# Create a Jinja environment to inspect variables in the condition templates
if self._unsafe:
Expand Down Expand Up @@ -170,7 +180,13 @@ def to_dict(self) -> Dict[str, Any]:
# output_type needs to be serialized to a string
route["output_type"] = serialize_type(route["output_type"])
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
return default_to_dict(self, routes=self.routes, custom_filters=se_filters, unsafe=self._unsafe)
return default_to_dict(
self,
routes=self.routes,
custom_filters=se_filters,
unsafe=self._unsafe,
validate_output_type=self._validate_output_type,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
Expand Down Expand Up @@ -210,9 +226,12 @@ def run(self, **kwargs):
:returns: A dictionary where the key is the `output_name` of the selected route and the value is the `output`
of the selected route.
:raises NoRouteSelectedException: If no `condition' in the routes is `True`.
:raises RouteConditionException: If there is an error parsing or evaluating the `condition` expression in the
routes.
:raises NoRouteSelectedException:
If no `condition` in the routes is `True`.
:raises RouteConditionException:
If there is an error parsing or evaluating the `condition` expression in the routes.
:raises ValueError:
If type validation is enabled and route type doesn't match actual value type.
"""
# Create a Jinja native environment to evaluate the condition templates as Python expressions
for route in self.routes:
Expand All @@ -221,21 +240,28 @@ def run(self, **kwargs):
rendered = t.render(**kwargs)
if not self._unsafe:
rendered = ast.literal_eval(rendered)
if rendered:
# We now evaluate the `output` expression to determine the route output
t_output = self._env.from_string(route["output"])
output = t_output.render(**kwargs)
# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe:
output = ast.literal_eval(output)
# and return the output as a dictionary under the output_name key
return {route["output_name"]: output}
if not rendered:
continue
# We now evaluate the `output` expression to determine the route output
t_output = self._env.from_string(route["output"])
output = t_output.render(**kwargs)
# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe:
output = ast.literal_eval(output)
except Exception as e:
raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e
msg = f"Error evaluating condition for route '{route}': {e}"
raise RouteConditionException(msg) from e

if self._validate_output_type and not self._output_matches_type(output, route["output_type"]):
msg = f"""Route '{route["output_name"]}' type doesn't match expected type"""
raise ValueError(msg)

# and return the output as a dictionary under the output_name key
return {route["output_name"]: output}

raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")

Expand Down Expand Up @@ -288,3 +314,53 @@ def _validate_template(self, env: Environment, template_text: str):
return True
except TemplateSyntaxError:
return False

def _output_matches_type(self, value: Any, expected_type: type): # noqa: PLR0911 # pylint: disable=too-many-return-statements
"""
Checks whether `value` type matches the `expected_type`.
"""
# Handle Any type
if expected_type is Any:
return True

# Get the origin type (List, Dict, etc) and type arguments
origin = get_origin(expected_type)
args = get_args(expected_type)

# Handle basic types (int, str, etc)
if origin is None:
return isinstance(value, expected_type)

# Handle Sequence types (List, Tuple, etc)
if isinstance(origin, type) and issubclass(origin, Sequence):
if not isinstance(value, Sequence):
return False
# Empty sequence is valid
if not value:
return True
# Check each element against the sequence's type parameter
return all(self._output_matches_type(item, args[0]) for item in value)

# Handle basic types (int, str, etc)
if origin is None:
return isinstance(value, expected_type)

# Handle Mapping types (Dict, etc)
if isinstance(origin, type) and issubclass(origin, Mapping):
if not isinstance(value, Mapping):
return False
# Empty mapping is valid
if not value:
return True
key_type, value_type = args
# Check all keys and values match their respective types
return all(
self._output_matches_type(k, key_type) and self._output_matches_type(v, value_type)
for k, v in value.items()
)

# Handle Union types (including Optional)
if origin is Union:
return any(self._output_matches_type(value, arg) for arg in args)

return False
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
enhancements:
- |
Add output type validation in `ConditionalRouter`.
Setting `validate_output_type` to `True` will enable a check to verify if
the actual output of a route returns the declared type.
If it doesn't match, a `ValueError` is raised.
75 changes: 57 additions & 18 deletions test/components/routers/test_conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,6 @@ def custom_filter_to_sede(value):


class TestRouter:
@pytest.fixture
def routes(self):
    """
    Two routes keyed on the length of `streams`: fewer than two forwards `query`, otherwise `streams`.
    """
    query_route = {
        "condition": "{{streams|length < 2}}",
        "output": "{{query}}",
        "output_type": str,
        "output_name": "query",
    }
    streams_route = {
        "condition": "{{streams|length >= 2}}",
        "output": "{{streams}}",
        "output_type": List[int],
        "output_name": "streams",
    }
    return [query_route, streams_route]

@pytest.fixture
def router(self, routes):
    """
    A `ConditionalRouter` built from the `routes` fixture.
    """
    conditional_router = ConditionalRouter(routes)
    return conditional_router

def test_missing_mandatory_fields(self):
"""
Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys
Expand Down Expand Up @@ -90,7 +74,16 @@ def test_mandatory_and_optional_fields_with_extra_fields(self):
with pytest.raises(ValueError):
ConditionalRouter(routes)

def test_router_initialized(self, routes):
def test_router_initialized(self):
routes = [
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
{
"condition": "{{streams|length >= 2}}",
"output": "{{streams}}",
"output_type": List[int],
"output_name": "streams",
},
]
router = ConditionalRouter(routes)

assert router.routes == routes
Expand Down Expand Up @@ -166,7 +159,7 @@ def test_complex_condition(self):
result = router.run(messages=[message], streams=[1, 2, 3], query="my query")
assert result == {"streams": [1, 2, 3]}

def test_router_no_route(self, router):
def test_router_no_route(self):
# should raise an exception
router = ConditionalRouter(
[
Expand Down Expand Up @@ -358,3 +351,49 @@ def test_unsafe(self):
message = ChatMessage.from_user(content="This is a message")
res = router.run(streams=streams, message=message)
assert res == {"message": message}

def test_validate_output_type_without_unsafe(self):
    """
    With validation enabled and `unsafe` left at its default, selecting the `message` route
    (declared as `ChatMessage`) is expected to raise a `ValueError`.
    """
    message_route = {
        "condition": "{{streams|length < 2}}",
        "output": "{{message}}",
        "output_type": ChatMessage,
        "output_name": "message",
    }
    streams_route = {
        "condition": "{{streams|length >= 2}}",
        "output": "{{streams}}",
        "output_type": List[int],
        "output_name": "streams",
    }
    router = ConditionalRouter([message_route, streams_route], validate_output_type=True)
    user_message = ChatMessage.from_user(content="This is a message")

    # A single stream selects the `message` route
    with pytest.raises(ValueError, match="Route 'message' type doesn't match expected type"):
        router.run(streams=[1], message=user_message)

def test_validate_output_type_with_unsafe(self):
    """
    With `unsafe` and validation both enabled, the `message` route returns a `ChatMessage`,
    while a list of strings on the `streams` route (declared as `List[int]`) raises a `ValueError`.
    """
    message_route = {
        "condition": "{{streams|length < 2}}",
        "output": "{{message}}",
        "output_type": ChatMessage,
        "output_name": "message",
    }
    streams_route = {
        "condition": "{{streams|length >= 2}}",
        "output": "{{streams}}",
        "output_type": List[int],
        "output_name": "streams",
    }
    router = ConditionalRouter([message_route, streams_route], unsafe=True, validate_output_type=True)
    user_message = ChatMessage.from_user(content="This is a message")

    # A single stream selects the `message` route and the output passes validation
    result = router.run(streams=[1], message=user_message)
    assert isinstance(result["message"], ChatMessage)

    # Four string items select the `streams` route, whose items don't match `List[int]`
    string_streams = ["1", "2", "3", "4"]
    with pytest.raises(ValueError, match="Route 'streams' type doesn't match expected type"):
        router.run(streams=string_streams, message=user_message)

0 comments on commit 8a35e79

Please sign in to comment.