diff --git a/libs/langchain/langchain/chains/llm_summarization_checker/base.py b/libs/langchain/langchain/chains/llm_summarization_checker/base.py index 2e58f9ef9e7a6..ab85dc8e167a8 100644 --- a/libs/langchain/langchain/chains/llm_summarization_checker/base.py +++ b/libs/langchain/langchain/chains/llm_summarization_checker/base.py @@ -17,18 +17,10 @@ PROMPTS_DIR = Path(__file__).parent / "prompts" -CREATE_ASSERTIONS_PROMPT = PromptTemplate.from_file( - PROMPTS_DIR / "create_facts.txt", ["summary"] -) -CHECK_ASSERTIONS_PROMPT = PromptTemplate.from_file( - PROMPTS_DIR / "check_facts.txt", ["assertions"] -) -REVISED_SUMMARY_PROMPT = PromptTemplate.from_file( - PROMPTS_DIR / "revise_summary.txt", ["checked_assertions", "summary"] -) -ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file( - PROMPTS_DIR / "are_all_true_prompt.txt", ["checked_assertions"] -) +CREATE_ASSERTIONS_PROMPT = PromptTemplate.from_file(PROMPTS_DIR / "create_facts.txt") +CHECK_ASSERTIONS_PROMPT = PromptTemplate.from_file(PROMPTS_DIR / "check_facts.txt") +REVISED_SUMMARY_PROMPT = PromptTemplate.from_file(PROMPTS_DIR / "revise_summary.txt") +ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(PROMPTS_DIR / "are_all_true_prompt.txt") def _load_sequential_chain( diff --git a/libs/langchain/tests/unit_tests/chains/test_llm_summarization_checker.py b/libs/langchain/tests/unit_tests/chains/test_llm_summarization_checker.py index aa82cead6bee4..ff1b457cd2069 100644 --- a/libs/langchain/tests/unit_tests/chains/test_llm_summarization_checker.py +++ b/libs/langchain/tests/unit_tests/chains/test_llm_summarization_checker.py @@ -14,6 +14,13 @@ from tests.unit_tests.llms.fake_llm import FakeLLM +def test_input_variables() -> None: + assert CREATE_ASSERTIONS_PROMPT.input_variables == ["summary"] + assert CHECK_ASSERTIONS_PROMPT.input_variables == ["assertions"] + assert REVISED_SUMMARY_PROMPT.input_variables == ["checked_assertions", "summary"] + assert ARE_ALL_TRUE_PROMPT.input_variables == ["checked_assertions"] + + @pytest.fixture def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain: """Fake LLMCheckerChain for testing."""