Skip to content

Commit

Permalink
fix: PromptNode run raises on empty inputs (#7734)
Browse files Browse the repository at this point in the history
* fix: PromptNode run raises on empty inputs

* add reno

* convert test into unit test
  • Loading branch information
tstadel authored May 24, 2024
1 parent 97596c2 commit 9b7f9fb
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
10 changes: 5 additions & 5 deletions haystack/nodes/prompt/prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,19 +252,19 @@ def _prepare( # type: ignore
else:
invocation_context = invocation_context.copy()

if query and "query" not in invocation_context:
if query is not None and "query" not in invocation_context:
invocation_context["query"] = query

if file_paths and "file_paths" not in invocation_context:
if file_paths is not None and "file_paths" not in invocation_context:
invocation_context["file_paths"] = file_paths

if labels and "labels" not in invocation_context:
if labels is not None and "labels" not in invocation_context:
invocation_context["labels"] = labels

if documents and "documents" not in invocation_context:
if documents is not None and "documents" not in invocation_context:
invocation_context["documents"] = documents

if meta and "meta" not in invocation_context:
if meta is not None and "meta" not in invocation_context:
invocation_context["meta"] = meta

if "prompt_template" not in invocation_context:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
When passing empty inputs (such as `query=""`) to PromptNode, the node would raise an error. This has been fixed.
16 changes: 16 additions & 0 deletions test/prompt/test_prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,3 +1212,19 @@ def test_prompt_no_truncation(mock_model, caplog):
with caplog.at_level(logging.DEBUG):
_ = node.prompt(PromptTemplate(prompt))
assert prompt in caplog.text


@pytest.mark.unit
def test_run_with_empty_inputs():
mock_model = MagicMock(spec=PromptModel)
mock_model.invoke.return_value = ["mock answer"]
node = PromptNode(mock_model, default_prompt_template="question-answering")
result, _ = node.run(query="", documents=[])

# validate output variable present
assert "answers" in result
assert len(result["answers"]) == 1

# and that so-called invocation context contains the right keys
assert "invocation_context" in result
assert all(item in result["invocation_context"] for item in ["query", "documents", "answers", "prompts"])

0 comments on commit 9b7f9fb

Please sign in to comment.