diff --git a/libs/core/langchain_core/prompts/loading.py b/libs/core/langchain_core/prompts/loading.py index 3d65f8323f1b8..e224a05ae9f6f 100644 --- a/libs/core/langchain_core/prompts/loading.py +++ b/libs/core/langchain_core/prompts/loading.py @@ -186,15 +186,21 @@ def _load_prompt_from_file( def _load_chat_prompt(config: dict) -> ChatPromptTemplate: """Load chat prompt from config""" - messages = config.pop("messages") - template = messages[0]["prompt"].pop("template") if messages else None - config.pop("input_variables") + template = config.pop("template") if not template: msg = "Can't load chat prompt without template" raise ValueError(msg) - return ChatPromptTemplate.from_template(template=template, **config) + messages = [] + if isinstance(template, str): + messages.append(("human", template)) + + elif isinstance(template, list): + for item in template: + messages.append((item["role"], item["content"])) + + return ChatPromptTemplate(messages=messages, **config) type_to_loader_dict: dict[str, Callable[[dict], BasePromptTemplate]] = { diff --git a/libs/core/tests/unit_tests/examples/simple_chat_prompt.json b/libs/core/tests/unit_tests/examples/simple_chat_prompt.json new file mode 100644 index 0000000000000..63761fcd012a3 --- /dev/null +++ b/libs/core/tests/unit_tests/examples/simple_chat_prompt.json @@ -0,0 +1,19 @@ +{ + "_type": "chat", + "input_variables": [ + "adjective" + ], + "partial_variables": { + "content": "dogs" + }, + "template": [ + { + "role": "system", + "content": "You are a comedian" + }, + { + "role": "human", + "content": "Tell me a {adjective} joke about {content}." + } + ] +} \ No newline at end of file diff --git a/libs/core/tests/unit_tests/examples/simple_chat_prompt.yaml b/libs/core/tests/unit_tests/examples/simple_chat_prompt.yaml new file mode 100644 index 0000000000000..ae6a0a90e9467 --- /dev/null +++ b/libs/core/tests/unit_tests/examples/simple_chat_prompt.yaml @@ -0,0 +1,10 @@ +_type: chat +input_variables: + ["adjective"] +partial_variables: + content: dogs +template: + - role: system + content: "You are a comedian" + - role: human + content: "Tell me a {adjective} joke about {content}." \ No newline at end of file diff --git a/libs/core/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py index 2a98a1e95ce90..752cbc58d151f 100644 --- a/libs/core/tests/unit_tests/prompts/test_loading.py +++ b/libs/core/tests/unit_tests/prompts/test_loading.py @@ -7,6 +7,7 @@ import pytest +from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.prompts.few_shot import FewShotPromptTemplate from langchain_core.prompts.loading import load_prompt from langchain_core.prompts.prompt import PromptTemplate @@ -36,6 +37,20 @@ def test_loading_from_yaml() -> None: assert prompt == expected_prompt +def test_loading_chat_from_yaml() -> None: + """Test loading from yaml file.""" + prompt = load_prompt(EXAMPLE_DIR / "simple_chat_prompt.yaml") + expected_prompt = ChatPromptTemplate( + input_variables=["adjective"], + partial_variables={"content": "dogs"}, + messages=[ + ("system", "You are a comedian"), + ("human", "Tell me a {adjective} joke about {content}."), + ], + ) + assert prompt == expected_prompt + + def test_loading_from_json() -> None: """Test loading from json file.""" prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json") @@ -46,6 +61,20 @@ def test_loading_from_json() -> None: assert prompt == expected_prompt +def test_loading_chat_from_json() -> None: + """Test loading from json file.""" + prompt = load_prompt(EXAMPLE_DIR / "simple_chat_prompt.json") + expected_prompt = ChatPromptTemplate( + input_variables=["adjective"], + partial_variables={"content": "dogs"}, + messages=[ + ("system", "You are a comedian"), + ("human", "Tell me a {adjective} joke about {content}."), + ], + ) + assert prompt == expected_prompt + + def test_loading_jinja_from_json() -> None: """Test that loading jinja2 format prompts from JSON raises ValueError.""" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json"