diff --git a/haystack/components/builders/chat_prompt_builder.py b/haystack/components/builders/chat_prompt_builder.py index a8824dffca..d3be3f8059 100644 --- a/haystack/components/builders/chat_prompt_builder.py +++ b/haystack/components/builders/chat_prompt_builder.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from copy import deepcopy -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Literal, Optional, Set, Union from jinja2 import meta from jinja2.sandbox import SandboxedEnvironment @@ -100,7 +100,7 @@ class ChatPromptBuilder: def __init__( self, template: Optional[List[ChatMessage]] = None, - required_variables: Optional[List[str]] = None, + required_variables: Optional[Union[List[str], Literal["*"]]] = None, variables: Optional[List[str]] = None, ): """ @@ -112,7 +112,8 @@ def __init__( the `init` method` or the `run` method. :param required_variables: List variables that must be provided as input to ChatPromptBuilder. - If a variable listed as required is not provided, an exception is raised. Optional. + If a variable listed as required is not provided, an exception is raised. + If set to "*", all variables found in the prompt are required. Optional. :param variables: List input variables to use in prompt templates instead of the ones inferred from the `template` parameter. For example, to use more variables during prompt engineering than the ones present @@ -127,14 +128,15 @@ def __init__( if template and not variables: for message in template: if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM): - # infere variables from template + # infer variables from template ast = self._env.parse(message.content) template_variables = meta.find_undeclared_variables(ast) variables += list(template_variables) + self.variables = variables # setup inputs - for var in variables: - if var in self.required_variables: + for var in self.variables: + if self.required_variables == "*" or var in self.required_variables: component.set_input_type(self, var, Any) else: component.set_input_type(self, var, Any, "") @@ -211,12 +213,16 @@ def _validate_variables(self, provided_variables: Set[str]): :raises ValueError: If no template is provided or if all the required template variables are not provided. """ - missing_variables = [var for var in self.required_variables if var not in provided_variables] + if self.required_variables == "*": + required_variables = sorted(self.variables) + else: + required_variables = self.required_variables + missing_variables = [var for var in required_variables if var not in provided_variables] if missing_variables: missing_vars_str = ", ".join(missing_variables) raise ValueError( f"Missing required input variables in ChatPromptBuilder: {missing_vars_str}. " - f"Required variables: {self.required_variables}. Provided variables: {provided_variables}." + f"Required variables: {required_variables}. Provided variables: {provided_variables}." ) def to_dict(self) -> Dict[str, Any]: diff --git a/haystack/components/builders/prompt_builder.py b/haystack/components/builders/prompt_builder.py index 3cb29e1211..d5b52cd42a 100644 --- a/haystack/components/builders/prompt_builder.py +++ b/haystack/components/builders/prompt_builder.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Literal, Optional, Set, Union from jinja2 import meta from jinja2.sandbox import SandboxedEnvironment @@ -137,7 +137,10 @@ class PromptBuilder: """ def __init__( - self, template: str, required_variables: Optional[List[str]] = None, variables: Optional[List[str]] = None + self, + template: str, + required_variables: Optional[Union[List[str], Literal["*"]]] = None, + variables: Optional[List[str]] = None, ): """ Constructs a PromptBuilder component. @@ -150,7 +153,8 @@ def __init__( unless explicitly specified. If an optional variable is not provided, it's replaced with an empty string in the rendered prompt. :param required_variables: List variables that must be provided as input to PromptBuilder. - If a variable listed as required is not provided, an exception is raised. Optional. + If a variable listed as required is not provided, an exception is raised. + If set to "*", all variables found in the prompt are required. Optional. :param variables: List input variables to use in prompt templates instead of the ones inferred from the `template` parameter. For example, to use more variables during prompt engineering than the ones present @@ -173,12 +177,12 @@ def __init__( ast = self._env.parse(template) template_variables = meta.find_undeclared_variables(ast) variables = list(template_variables) - variables = variables or [] + self.variables = variables # setup inputs - for var in variables: - if var in self.required_variables: + for var in self.variables: + if self.required_variables == "*" or var in self.required_variables: component.set_input_type(self, var, Any) else: component.set_input_type(self, var, Any, "") @@ -238,10 +242,14 @@ def _validate_variables(self, provided_variables: Set[str]): :raises ValueError: If any of the required template variables is not provided. """ - missing_variables = [var for var in self.required_variables if var not in provided_variables] + if self.required_variables == "*": + required_variables = sorted(self.variables) + else: + required_variables = self.required_variables + missing_variables = [var for var in required_variables if var not in provided_variables] if missing_variables: missing_vars_str = ", ".join(missing_variables) raise ValueError( f"Missing required input variables in PromptBuilder: {missing_vars_str}. " - f"Required variables: {self.required_variables}. Provided variables: {provided_variables}." + f"Required variables: {required_variables}. Provided variables: {provided_variables}." ) diff --git a/releasenotes/notes/prompt-builder-require-all-variables-007f87e9c7c89e8c.yaml b/releasenotes/notes/prompt-builder-require-all-variables-007f87e9c7c89e8c.yaml new file mode 100644 index 0000000000..cc899fff99 --- /dev/null +++ b/releasenotes/notes/prompt-builder-require-all-variables-007f87e9c7c89e8c.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Added a new option to the required_variables parameter to the PromptBuilder and ChatPromptBuilder. + By passing `required_variables="*"` you can automatically set all variables in the prompt to be required. diff --git a/test/components/builders/test_chat_prompt_builder.py b/test/components/builders/test_chat_prompt_builder.py index e40b47b508..cd5c2c62c8 100644 --- a/test/components/builders/test_chat_prompt_builder.py +++ b/test/components/builders/test_chat_prompt_builder.py @@ -137,6 +137,17 @@ def test_run_with_missing_required_input(self): with pytest.raises(ValueError, match="foo, bar"): builder.run() + def test_run_with_missing_required_input_using_star(self): + builder = ChatPromptBuilder( + template=[ChatMessage.from_user("This is a {{ foo }}, not a {{ bar }}")], required_variables="*" + ) + with pytest.raises(ValueError, match="foo"): + builder.run(bar="bar") + with pytest.raises(ValueError, match="bar"): + builder.run(foo="foo") + with pytest.raises(ValueError, match="bar, foo"): + builder.run() + def test_run_with_variables(self): variables = ["var1", "var2", "var3"] template = [ChatMessage.from_user("Hello, {{ name }}! {{ var1 }}")] diff --git a/test/components/builders/test_prompt_builder.py b/test/components/builders/test_prompt_builder.py index 7461327f4e..39c23735a5 100644 --- a/test/components/builders/test_prompt_builder.py +++ b/test/components/builders/test_prompt_builder.py @@ -143,6 +143,15 @@ def test_run_with_missing_required_input(self): with pytest.raises(ValueError, match="foo, bar"): builder.run() + def test_run_with_missing_required_input_using_star(self): + builder = PromptBuilder(template="This is a {{ foo }}, not a {{ bar }}", required_variables="*") + with pytest.raises(ValueError, match="foo"): + builder.run(bar="bar") + with pytest.raises(ValueError, match="bar"): + builder.run(foo="foo") + with pytest.raises(ValueError, match="bar, foo"): + builder.run() + def test_run_with_variables(self): variables = ["var1", "var2", "var3"] template = "Hello, {{ name }}! {{ var1 }}" @@ -296,7 +305,7 @@ def test_date_with_addition_offset(self) -> None: assert now_plus_2 == result - def test_date_with_substraction_offset(self) -> None: + def test_date_with_subtraction_offset(self) -> None: template = "Time after 12 days is: {% now 'UTC' - 'days=12' %}" builder = PromptBuilder(template=template)