From 9b7f9fb65d747bbab3b2183a07da1ce9979e7340 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Fri, 24 May 2024 20:30:43 +0200 Subject: [PATCH] fix: PromptNode run raises on empty inputs (#7734) * fix: PromptNode run raises on empty inputs * add reno * convert test into unit test --- haystack/nodes/prompt/prompt_node.py | 10 +++++----- ...promptnode-empty-inputs-c050c2040d489f9e.yaml | 4 ++++ test/prompt/test_prompt_node.py | 16 ++++++++++++++++ 3 files changed, 25 insertions(+), 5 deletions(-) create mode 100644 releasenotes/notes/fix-promptnode-empty-inputs-c050c2040d489f9e.yaml diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py index d5186519fa..01e04a8811 100644 --- a/haystack/nodes/prompt/prompt_node.py +++ b/haystack/nodes/prompt/prompt_node.py @@ -252,19 +252,19 @@ def _prepare( # type: ignore else: invocation_context = invocation_context.copy() - if query and "query" not in invocation_context: + if query is not None and "query" not in invocation_context: invocation_context["query"] = query - if file_paths and "file_paths" not in invocation_context: + if file_paths is not None and "file_paths" not in invocation_context: invocation_context["file_paths"] = file_paths - if labels and "labels" not in invocation_context: + if labels is not None and "labels" not in invocation_context: invocation_context["labels"] = labels - if documents and "documents" not in invocation_context: + if documents is not None and "documents" not in invocation_context: invocation_context["documents"] = documents - if meta and "meta" not in invocation_context: + if meta is not None and "meta" not in invocation_context: invocation_context["meta"] = meta if "prompt_template" not in invocation_context: diff --git a/releasenotes/notes/fix-promptnode-empty-inputs-c050c2040d489f9e.yaml b/releasenotes/notes/fix-promptnode-empty-inputs-c050c2040d489f9e.yaml new file mode 100644 index 0000000000..bef5427a14 --- /dev/null +++ b/releasenotes/notes/fix-promptnode-empty-inputs-c050c2040d489f9e.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + When passing empty inputs (such as `query=""`) to PromptNode, the node would raise an error. This has been fixed. diff --git a/test/prompt/test_prompt_node.py b/test/prompt/test_prompt_node.py index 594a7fffd0..aba2ed8833 100644 --- a/test/prompt/test_prompt_node.py +++ b/test/prompt/test_prompt_node.py @@ -1212,3 +1212,19 @@ def test_prompt_no_truncation(mock_model, caplog): with caplog.at_level(logging.DEBUG): _ = node.prompt(PromptTemplate(prompt)) assert prompt in caplog.text + + +@pytest.mark.unit +def test_run_with_empty_inputs(): + mock_model = MagicMock(spec=PromptModel) + mock_model.invoke.return_value = ["mock answer"] + node = PromptNode(mock_model, default_prompt_template="question-answering") + result, _ = node.run(query="", documents=[]) + + # validate output variable present + assert "answers" in result + assert len(result["answers"]) == 1 + + # and that so-called invocation context contains the right keys + assert "invocation_context" in result + assert all(item in result["invocation_context"] for item in ["query", "documents", "answers", "prompts"])