From 76858e3a0729af8b9a985e0b2beb5b3217e44905 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 21 Aug 2024 16:28:03 -0700 Subject: [PATCH 1/6] wip --- libs/core/langchain_core/prompts/chat.py | 94 ++-------- libs/core/langchain_core/prompts/message.py | 166 ++++++++++++++++++ .../core/langchain_core/prompts/structured.py | 4 +- .../tests/unit_tests/prompts/test_chat.py | 6 +- libs/langchain/langchain/prompts/chat.py | 4 +- 5 files changed, 188 insertions(+), 86 deletions(-) create mode 100644 libs/core/langchain_core/prompts/message.py diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index a53f1bb251cd0..fc66bbef41efa 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -22,7 +22,6 @@ ) from langchain_core._api import deprecated -from langchain_core.load import Serializable from langchain_core.messages import ( AIMessage, AnyMessage, @@ -36,6 +35,10 @@ from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.image import ImagePromptTemplate +from langchain_core.prompts.message import ( + BaseMessagePromptTemplate, + _MessageDictPromptTemplate, +) from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import StringPromptTemplate, get_template_variables from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator @@ -43,80 +46,6 @@ from langchain_core.utils.interactive_env import is_interactive_env -class BaseMessagePromptTemplate(Serializable, ABC): - """Base class for message prompt templates.""" - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether or not the class is serializable. - Returns: True""" - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" - return ["langchain", "prompts", "chat"] - - @abstractmethod - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """Format messages from kwargs. Should return a list of BaseMessages. - - Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - List of BaseMessages. - """ - - async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: - """Async format messages from kwargs. - Should return a list of BaseMessages. - - Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - List of BaseMessages. - """ - return self.format_messages(**kwargs) - - @property - @abstractmethod - def input_variables(self) -> List[str]: - """Input variables for this prompt template. - - Returns: - List of input variables. - """ - - def pretty_repr(self, html: bool = False) -> str: - """Human-readable representation. - - Args: - html: Whether to format as HTML. Defaults to False. - - Returns: - Human-readable representation. - """ - raise NotImplementedError - - def pretty_print(self) -> None: - """Print a human-readable representation.""" - print(self.pretty_repr(html=is_interactive_env())) # noqa: T201 - - def __add__(self, other: Any) -> ChatPromptTemplate: - """Combine two prompt templates. - - Args: - other: Another prompt template. - - Returns: - Combined prompt template. - """ - prompt = ChatPromptTemplate(messages=[self]) # type: ignore[call-arg] - return prompt + other - - class MessagesPlaceholder(BaseMessagePromptTemplate): """Prompt template that assumes variable is already list of messages. @@ -982,7 +911,8 @@ def __init__( """ _messages = [ - _convert_to_message(message, template_format) for message in messages + _convert_to_message_template(message, template_format) + for message in messages ] # Automatically infer input variables from messages @@ -1279,7 +1209,7 @@ def append(self, message: MessageLikeRepresentation) -> None: Args: message: representation of a message to append. """ - self.messages.append(_convert_to_message(message)) + self.messages.append(_convert_to_message_template(message)) def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None: """Extend the chat template with a sequence of messages. @@ -1287,7 +1217,9 @@ def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None: Args: messages: sequence of message representations to append. """ - self.messages.extend([_convert_to_message(message) for message in messages]) + self.messages.extend( + [_convert_to_message_template(message) for message in messages] + ) @overload def __getitem__(self, index: int) -> MessageLike: ... @@ -1404,7 +1336,7 @@ def _create_template_from_message_type( return message -def _convert_to_message( +def _convert_to_message_template( message: MessageLikeRepresentation, template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", ) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]: @@ -1453,6 +1385,10 @@ def _convert_to_message( cast(str, template), template_format=template_format ) ) + elif isinstance(message, dict): + _message = _MessageDictPromptTemplate( + template=message, template_format=template_format + ) else: raise NotImplementedError(f"Unsupported message type: {type(message)}") diff --git a/libs/core/langchain_core/prompts/message.py b/libs/core/langchain_core/prompts/message.py new file mode 100644 index 0000000000000..a247fffc7f429 --- /dev/null +++ b/libs/core/langchain_core/prompts/message.py @@ -0,0 +1,166 @@ +"""Message prompt templates.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence + +from langchain_core.load import Serializable +from langchain_core.messages import BaseMessage, convert_to_messages +from langchain_core.prompts.string import ( + DEFAULT_FORMATTER_MAPPING, + get_template_variables, +) +from langchain_core.utils.image import image_to_data_url +from langchain_core.utils.interactive_env import is_interactive_env + +if TYPE_CHECKING: + from langchain_core.prompts.chat import ChatPromptTemplate + + +class BaseMessagePromptTemplate(Serializable, ABC): + """Base class for message prompt templates.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return True if the class is serializable, else False. + + Returns: + True + """ + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "prompts", "chat"] + + @abstractmethod + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format messages from kwargs. Should return a list of BaseMessages. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + List of BaseMessages. + """ + + async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Async format messages from kwargs. + Should return a list of BaseMessages. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + List of BaseMessages. + """ + return self.format_messages(**kwargs) + + @property + @abstractmethod + def input_variables(self) -> List[str]: + """Input variables for this prompt template. + + Returns: + List of input variables. + """ + + def pretty_repr(self, html: bool = False) -> str: + """Human-readable representation. + + Args: + html: Whether to format as HTML. Defaults to False. + + Returns: + Human-readable representation. + """ + raise NotImplementedError + + def pretty_print(self) -> None: + """Print a human-readable representation.""" + print(self.pretty_repr(html=is_interactive_env())) # noqa: T201 + + def __add__(self, other: Any) -> ChatPromptTemplate: + """Combine two prompt templates. + + Args: + other: Another prompt template. + + Returns: + Combined prompt template. + """ + from langchain_core.prompts.chat import ChatPromptTemplate + + prompt = ChatPromptTemplate(messages=[self]) # type: ignore[call-arg] + return prompt + other + + +class _MessageDictPromptTemplate(BaseMessagePromptTemplate): + """Template represented by a dict that looks for input vars in all leaf vals. + + Special handling of any dict value that contains "type": "image_url". + """ + + template: Dict[str, Any] + template_format: Literal["f-string", "mustache"] + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + msg_dict = _insert_input_variables(self.template, kwargs, self.template_format) + return convert_to_messages([msg_dict]) + + @property + def input_variables(self) -> List[str]: + return _get_input_variables(self.template, self.template_format) + + @property + def _prompt_type(self) -> str: + return "message-dict-prompt" + + +def _get_input_variables( + template: dict, template_format: Literal["f-string", "mustache"] +) -> List[str]: + input_variables = [] + for k, v in template.items(): + if isinstance(v, str): + input_variables += get_template_variables(v, template_format) + elif isinstance(v, dict): + input_variables += _get_input_variables(v, template_format) + elif isinstance(v, (list, tuple)): + for x in v: + if isinstance(x, str): + input_variables += get_template_variables(x, template_format) + elif isinstance(x, dict): + input_variables += _get_input_variables(x, template_format) + return list(set(input_variables)) + + +def _insert_input_variables( + template: Dict[str, Any], + inputs: Dict[str, Any], + template_format: Literal["f-string", "mustache"], +) -> Dict[str, Any]: + formatted = {} + formatter = DEFAULT_FORMATTER_MAPPING[template_format] + for k, v in template.items(): + if isinstance(v, str): + formatted[k] = formatter(v, **inputs) + elif isinstance(v, dict): + # Special handling for loading local images. + if k == "image_url" and "path" in v: + formatted_path = formatter(v.pop("path"), **inputs) + v["url"] = image_to_data_url(formatted_path) + formatted[k] = _insert_input_variables(v, inputs, template_format) + elif isinstance(v, (list, tuple)): + formatted_v = [] + for x in v: + if isinstance(x, str): + formatted_v.append(formatter(x, **inputs)) + elif isinstance(x, dict): + formatted_v.append( + _insert_input_variables(x, inputs, template_format) + ) + formatted[k] = type(v)(formatted_v) + return formatted diff --git a/libs/core/langchain_core/prompts/structured.py b/libs/core/langchain_core/prompts/structured.py index 5176b483d96da..0793b17a51e55 100644 --- a/libs/core/langchain_core/prompts/structured.py +++ b/libs/core/langchain_core/prompts/structured.py @@ -20,7 +20,7 @@ ChatPromptTemplate, MessageLikeRepresentation, MessagesPlaceholder, - _convert_to_message, + _convert_to_message_template, ) from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables.base import ( @@ -86,7 +86,7 @@ class OutputSchema(BaseModel): Returns: a structured prompt template """ - _messages = [_convert_to_message(message) for message in messages] + _messages = [_convert_to_message_template(message) for message in messages] # Automatically infer input variables from messages input_vars: Set[str] = set() diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index a280ff1dbb670..00f22672f4bd1 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -26,7 +26,7 @@ HumanMessagePromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate, - _convert_to_message, + _convert_to_message_template, ) from langchain_core.pydantic_v1 import ValidationError from tests.unit_tests.pydantic_utils import _schema @@ -432,7 +432,7 @@ def test_convert_to_message( args: Any, expected: Union[BaseMessage, BaseMessagePromptTemplate] ) -> None: """Test convert to message.""" - assert _convert_to_message(args) == expected + assert _convert_to_message_template(args) == expected def test_chat_prompt_template_indexing() -> None: @@ -477,7 +477,7 @@ def test_convert_to_message_is_strict() -> None: # meow does not correspond to a valid message type. # this test is here to ensure that functionality to interpret `meow` # as a role is NOT added. - _convert_to_message(("meow", "question")) + _convert_to_message_template(("meow", "question")) def test_chat_message_partial() -> None: diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index 35de60b41610f..445cba1e47e2b 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -12,7 +12,7 @@ MessagePromptTemplateT, MessagesPlaceholder, SystemMessagePromptTemplate, - _convert_to_message, + _convert_to_message_template, _create_template_from_message_type, ) @@ -28,7 +28,7 @@ "ChatPromptTemplate", "ChatPromptValue", "ChatPromptValueConcrete", - "_convert_to_message", + "_convert_to_message_template", "_create_template_from_message_type", "MessagePromptTemplateT", "MessageLike", From 27c8ba6ffcb5bd58b61e7727e67ca8ec96c06e49 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 22 Aug 2024 11:39:20 -0700 Subject: [PATCH 2/6] core[minor]: dict chat prompt template support --- libs/core/langchain_core/prompts/chat.py | 17 ++-- libs/core/langchain_core/prompts/message.py | 7 +- .../unit_tests/prompts/favicon-16x16.png | Bin 0 -> 542 bytes .../tests/unit_tests/prompts/test_chat.py | 85 ++++++++++++++++++ .../tests/unit_tests/prompts/test_message.py | 69 ++++++++++++++ 5 files changed, 170 insertions(+), 8 deletions(-) create mode 100644 libs/core/tests/unit_tests/prompts/favicon-16x16.png create mode 100644 libs/core/tests/unit_tests/prompts/test_message.py diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index fc66bbef41efa..77ed7c90a6375 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -37,7 +37,7 @@ from langchain_core.prompts.image import ImagePromptTemplate from langchain_core.prompts.message import ( BaseMessagePromptTemplate, - _MessageDictPromptTemplate, + _DictMessagePromptTemplate, ) from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import StringPromptTemplate, get_template_variables @@ -739,6 +739,7 @@ def pretty_print(self) -> None: Union[str, List[dict], List[object]], ], str, + Dict[str, Any], ] @@ -1039,7 +1040,7 @@ def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate: return cls.from_messages([message]) @classmethod - @deprecated("0.0.1", alternative="from_messages classmethod", pending=True) + @deprecated("0.0.1", alternative="from_messages", pending=True) def from_role_strings( cls, string_messages: List[Tuple[str, str]] ) -> ChatPromptTemplate: @@ -1059,7 +1060,7 @@ def from_role_strings( ) @classmethod - @deprecated("0.0.1", alternative="from_messages classmethod", pending=True) + @deprecated("0.0.1", alternative="from_messages", pending=True) def from_strings( cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]] ) -> ChatPromptTemplate: @@ -1386,8 +1387,14 @@ def _convert_to_message_template( ) ) elif isinstance(message, dict): - _message = _MessageDictPromptTemplate( - template=message, template_format=template_format + if template_format == "jinja": + raise ValueError( + f"{template_format} is unsafe and is not supported for templates " + f"expressed as dicts. Please use 'f-string' or 'mustache' format." + ) + _message = _DictMessagePromptTemplate( + template=message, + template_format=template_format, # type: ignore[arg-type] ) else: raise NotImplementedError(f"Unsupported message type: {type(message)}") diff --git a/libs/core/langchain_core/prompts/message.py b/libs/core/langchain_core/prompts/message.py index a247fffc7f429..73deb6f0d5240 100644 --- a/libs/core/langchain_core/prompts/message.py +++ b/libs/core/langchain_core/prompts/message.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Literal from langchain_core.load import Serializable from langchain_core.messages import BaseMessage, convert_to_messages @@ -97,10 +97,11 @@ def __add__(self, other: Any) -> ChatPromptTemplate: return prompt + other -class _MessageDictPromptTemplate(BaseMessagePromptTemplate): +class _DictMessagePromptTemplate(BaseMessagePromptTemplate): """Template represented by a dict that looks for input vars in all leaf vals. - Special handling of any dict value that contains "type": "image_url". + Special handling of any dict value that contains + ``{"type": "image_url", "image_url": {"path": "..."}}`` """ template: Dict[str, Any] diff --git a/libs/core/tests/unit_tests/prompts/favicon-16x16.png b/libs/core/tests/unit_tests/prompts/favicon-16x16.png new file mode 100644 index 0000000000000000000000000000000000000000..c6c21a961b90c04c9e765202cca807b813659c0e GIT binary patch literal 542 zcmeAS@N?(olHy`uVBq!ia0vp^0wB!61|;P_|4#%`jKx9jP7LeL$-D$|Tv8)E(|mmy zw18|52FCVG1{RPKAeI7R1_tH@j10^`nh_+nfC(-uuz(rC1}QWNE&K#j*5T>m7-AuK zcCxj%a-hiZ`p*SM!tRUnIFdzD1;SDUFX=?@ym2y5F*+jp#xKSkoj+pJ>NvEmw+a~M zMwFJjmS3D9p<)oouGu&qV zJSwFv&1LBPs*i8;#O5&f1t-4Tv^`-eYaH{TKxmEN%;vZ6%&a)>{I~n~QDRY)O6>mB zGixV%a<1bN+W*h3ExIaJrp4)MQ+MXakJ-OjJQS#4|a>&I+b z8>gIn@9^-+nsdkN4NBN=&JtQ_u0MC)p(R4^&TDJ`xw+H*jzrG8S32{vBvd$@?2>>`l(jfo< literal 0 HcmV?d00001 diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 00f22672f4bd1..efc289795333b 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -13,6 +13,7 @@ BaseMessage, HumanMessage, SystemMessage, + ToolMessage, get_buffer_string, ) from langchain_core.prompt_values import ChatPromptValue @@ -29,8 +30,11 @@ _convert_to_message_template, ) from langchain_core.pydantic_v1 import ValidationError +from langchain_core.utils.image import image_to_data_url from tests.unit_tests.pydantic_utils import _schema +CUR_DIR = Path(__file__).parent.absolute().resolve() + @pytest.fixture def messages() -> List[BaseMessagePromptTemplate]: @@ -863,3 +867,84 @@ async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None: ) assert dumpd(template) == snapshot() assert load(dumpd(template)) == template + + +def test_chat_tmpl_dict_msg() -> None: + template = ChatPromptTemplate( + [ + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "{text1}", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "image_url", "image_url": {"path": "{local_image_path}"}}, + ], + "name": "{name1}", + "tool_calls": [ + { + "name": "{tool_name1}", + "args": {"arg1": "{tool_arg1}"}, + "id": "1", + "type": "tool_call", + } + ], + }, + { + "role": "tool", + "content": "{tool_content2}", + "tool_call_id": "1", + "name": "{tool_name1}", + }, + ] + ) + image_path = str(CUR_DIR / "favicon-16x16.png") + image_url = image_to_data_url(image_path) + expected = [ + AIMessage( + [ + { + "type": "text", + "text": "important message", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + name="foo", + tool_calls=[ + { + "name": "do_stuff", + "args": {"arg1": "important arg1"}, + "id": "1", + "type": "tool_call", + } + ], + ), + ToolMessage("foo", name="do_stuff", tool_call_id="1"), + ] + + actual = template.invoke( + { + "local_image_path": image_path, + "text1": "important message", + "name1": "foo", + "tool_arg1": "important arg1", + "tool_name1": "do_stuff", + "tool_content2": "foo", + } + ).to_messages() + assert actual == expected + + partial_ = template.partial( **{ "local_image_path": image_path } ) + actual = partial_.invoke( + { + "text1": "important message", + "name1": "foo", + "tool_arg1": "important arg1", + "tool_name1": "do_stuff", + "tool_content2": "foo", + } + ).to_messages() + assert actual == expected diff --git a/libs/core/tests/unit_tests/prompts/test_message.py b/libs/core/tests/unit_tests/prompts/test_message.py new file mode 100644 index 0000000000000..aad3d82dfa354 --- /dev/null +++ b/libs/core/tests/unit_tests/prompts/test_message.py @@ -0,0 +1,69 @@ +from pathlib import Path + +from langchain_core.messages import AIMessage, BaseMessage, ToolMessage +from langchain_core.prompts.message import _DictMessagePromptTemplate +from langchain_core.utils.image import image_to_data_url + +CUR_DIR = Path(__file__).parent.absolute().resolve() + + +def test__dict_message_prompt_template_fstring() -> None: + template = { + "role": "assistant", + "content": [ + {"type": "text", "text": "{text1}", "cache_control": {"type": "ephemeral"}}, + {"type": "image_url", "image_url": {"path": "{local_image_path}"}}, + ], + "name": "{name1}", + "tool_calls": [ + { + "name": "{tool_name1}", + "args": {"arg1": "{tool_arg1}"}, + "id": "1", + "type": "tool_call", + } + ], + } + prompt = _DictMessagePromptTemplate(template=template, template_format="f-string") + image_path = str(CUR_DIR / "favicon-16x16.png") + image_url = image_to_data_url(image_path) + expected: BaseMessage = AIMessage( + [ + { + "type": "text", + "text": "important message", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + name="foo", + tool_calls=[ + { + "name": "do_stuff", + "args": {"arg1": "important arg1"}, + "id": "1", + "type": "tool_call", + } + ], + ) + actual = prompt.format_messages( + **{ + "local_image_path": image_path, + "text1": "important message", + "name1": "foo", + "tool_arg1": "important arg1", + "tool_name1": "do_stuff", + } + )[0] + assert actual == expected + + template = { + "role": "tool", + "content": "{content1}", + "tool_call_id": "1", + "name": "{name1}", + } + prompt = _DictMessagePromptTemplate(template=template, template_format="f-string") + expected = ToolMessage("foo", name="bar", tool_call_id="1") + actual = prompt.format_messages(**{"content1": "foo", "name1": "bar"})[0] + assert actual == expected From 10ef356ab40d169e08746521f07c8866bf282e69 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 22 Aug 2024 11:48:18 -0700 Subject: [PATCH 3/6] fmt --- libs/core/tests/unit_tests/prompts/test_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index efc289795333b..6e3e193a5f594 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -937,7 +937,7 @@ def test_chat_tmpl_dict_msg() -> None: ).to_messages() assert actual == expected - partial_ = template.partial( **{ "local_image_path": image_path } ) + partial_ = template.partial(**{"local_image_path": image_path}) actual = partial_.invoke( { "text1": "important message", From 100dc2e623414eb5e782a857c04242e662640bd7 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Fri, 25 Oct 2024 15:14:41 -0700 Subject: [PATCH 4/6] fmt --- libs/core/langchain_core/prompts/message.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/prompts/message.py b/libs/core/langchain_core/prompts/message.py index ba2450b0fcc8c..23ca8c0c2a83c 100644 --- a/libs/core/langchain_core/prompts/message.py +++ b/libs/core/langchain_core/prompts/message.py @@ -90,14 +90,16 @@ def __add__(self, other: Any) -> ChatPromptTemplate: Returns: Combined prompt template. """ + from langchain_core.prompts.chat import ChatPromptTemplate + prompt = ChatPromptTemplate(messages=[self]) # type: ignore[call-arg] return prompt + other class _DictMessagePromptTemplate(BaseMessagePromptTemplate): - """Template represented by a dict that looks for input vars in all leaf vals. + """Template represented by a dict that recursively fills input vars in string vals. - Special handling of any dict value that contains + Special handling of image_url dicts to load local paths. These look like: ``{"type": "image_url", "image_url": {"path": "..."}}`` """ From d7dd172b566586668b5e34c5110a5ac4209aa490 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Sat, 26 Oct 2024 14:19:55 -0700 Subject: [PATCH 5/6] fmt --- libs/core/tests/unit_tests/prompts/test_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index a628028a971bd..7b9dc835d7897 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -5,6 +5,7 @@ from typing import Any, Union, cast import pytest +from pydantic import ValidationError from syrupy import SnapshotAssertion from langchain_core._api.deprecation import ( @@ -33,7 +34,6 @@ _convert_to_message_template, ) from langchain_core.prompts.string import PromptTemplateFormat -from langchain_core.pydantic_v1 import ValidationError from langchain_core.utils.image import image_to_data_url from tests.unit_tests.pydantic_utils import _normalize_schema From 76b252011a6bdf9a72a97b30aa55be9276b6841d Mon Sep 17 00:00:00 2001 From: Bagatur Date: Sat, 26 Oct 2024 14:26:18 -0700 Subject: [PATCH 6/6] fmt --- libs/langchain/langchain/prompts/chat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index 445cba1e47e2b..faf7fbcb03924 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -12,9 +12,11 @@ MessagePromptTemplateT, MessagesPlaceholder, SystemMessagePromptTemplate, - _convert_to_message_template, _create_template_from_message_type, ) +from langchain_core.prompts.chat import ( + _convert_to_message_template as _convert_to_message, +) __all__ = [ "BaseMessagePromptTemplate", @@ -28,7 +30,7 @@ "ChatPromptTemplate", "ChatPromptValue", "ChatPromptValueConcrete", - "_convert_to_message_template", + "_convert_to_message", "_create_template_from_message_type", "MessagePromptTemplateT", "MessageLike",