Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

core[minor]: dict chat prompt template support #25674

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 24 additions & 81 deletions libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
)

from langchain_core._api import deprecated
from langchain_core.load import Serializable
from langchain_core.messages import (
AIMessage,
AnyMessage,
Expand All @@ -36,87 +35,17 @@
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.message import (
BaseMessagePromptTemplate,
_DictMessagePromptTemplate,
)
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env


class BaseMessagePromptTemplate(Serializable, ABC):
    """Abstract base for templates that render into chat messages.

    Subclasses implement ``format_messages`` and ``input_variables`` to turn
    keyword arguments into a list of ``BaseMessage`` objects.
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Whether this class participates in langchain serialization.

        Returns:
            Always True.
        """
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Namespace used when serializing this object."""
        return ["langchain", "prompts", "chat"]

    @abstractmethod
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Render this template into a list of messages.

        Args:
            **kwargs: Values substituted into the template.

        Returns:
            The rendered list of BaseMessages.
        """

    async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Async counterpart of ``format_messages``.

        Args:
            **kwargs: Values substituted into the template.

        Returns:
            The rendered list of BaseMessages.
        """
        # Default implementation simply delegates to the sync version.
        return self.format_messages(**kwargs)

    @property
    @abstractmethod
    def input_variables(self) -> List[str]:
        """Names of the variables this template expects.

        Returns:
            List of input variable names.
        """

    def pretty_repr(self, html: bool = False) -> str:
        """Build a human-readable representation.

        Args:
            html: If True, format the output as HTML. Defaults to False.

        Returns:
            A human-readable string.

        Raises:
            NotImplementedError: Subclasses must override to support this.
        """
        raise NotImplementedError

    def pretty_print(self) -> None:
        """Print the human-readable representation to stdout."""
        print(self.pretty_repr(html=is_interactive_env()))  # noqa: T201

    def __add__(self, other: Any) -> ChatPromptTemplate:
        """Concatenate this template with another prompt-like object.

        Args:
            other: Another prompt template or message representation.

        Returns:
            A combined ChatPromptTemplate.
        """
        combined = ChatPromptTemplate(messages=[self])  # type: ignore[call-arg]
        return combined + other


class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages.

Expand Down Expand Up @@ -810,6 +739,7 @@ def pretty_print(self) -> None:
Union[str, List[dict], List[object]],
],
str,
Dict[str, Any],
]


Expand Down Expand Up @@ -982,7 +912,8 @@ def __init__(

"""
_messages = [
_convert_to_message(message, template_format) for message in messages
_convert_to_message_template(message, template_format)
for message in messages
]

# Automatically infer input variables from messages
Expand Down Expand Up @@ -1109,7 +1040,7 @@ def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
return cls.from_messages([message])

@classmethod
@deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
@deprecated("0.0.1", alternative="from_messages", pending=True)
def from_role_strings(
cls, string_messages: List[Tuple[str, str]]
) -> ChatPromptTemplate:
Expand All @@ -1129,7 +1060,7 @@ def from_role_strings(
)

@classmethod
@deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
@deprecated("0.0.1", alternative="from_messages", pending=True)
def from_strings(
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate:
Expand Down Expand Up @@ -1279,15 +1210,17 @@ def append(self, message: MessageLikeRepresentation) -> None:
Args:
message: representation of a message to append.
"""
self.messages.append(_convert_to_message(message))
self.messages.append(_convert_to_message_template(message))

def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
    """Extend the chat template with a sequence of messages.

    Args:
        messages: sequence of message representations to append.
    """
    converted = [_convert_to_message_template(m) for m in messages]
    self.messages.extend(converted)

@overload
def __getitem__(self, index: int) -> MessageLike: ...
Expand Down Expand Up @@ -1404,7 +1337,7 @@ def _create_template_from_message_type(
return message


def _convert_to_message(
def _convert_to_message_template(
message: MessageLikeRepresentation,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
Expand Down Expand Up @@ -1453,6 +1386,16 @@ def _convert_to_message(
cast(str, template), template_format=template_format
)
)
elif isinstance(message, dict):
if template_format == "jinja":
raise ValueError(
f"{template_format} is unsafe and is not supported for templates "
f"expressed as dicts. Please use 'f-string' or 'mustache' format."
)
_message = _DictMessagePromptTemplate(
template=message,
template_format=template_format, # type: ignore[arg-type]
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")

Expand Down
167 changes: 167 additions & 0 deletions libs/core/langchain_core/prompts/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Message prompt templates."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Literal

from langchain_core.load import Serializable
from langchain_core.messages import BaseMessage, convert_to_messages
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
get_template_variables,
)
from langchain_core.utils.image import image_to_data_url
from langchain_core.utils.interactive_env import is_interactive_env

if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate


class BaseMessagePromptTemplate(Serializable, ABC):
    """Base class for message prompt templates.

    Concrete subclasses define ``format_messages`` (how keyword arguments
    become messages) and ``input_variables`` (which keywords are required).
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True if the class is serializable, else False.

        Returns:
            True
        """
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]

    @abstractmethod
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format messages from kwargs.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            List of BaseMessages.
        """

    async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Async format messages from kwargs.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            List of BaseMessages.
        """
        # By default there is no async-specific work; reuse the sync path.
        return self.format_messages(**kwargs)

    @property
    @abstractmethod
    def input_variables(self) -> List[str]:
        """Input variables for this prompt template.

        Returns:
            List of input variables.
        """

    def pretty_repr(self, html: bool = False) -> str:
        """Human-readable representation.

        Args:
            html: Whether to format as HTML. Defaults to False.

        Returns:
            Human-readable representation.

        Raises:
            NotImplementedError: If a subclass does not provide one.
        """
        raise NotImplementedError

    def pretty_print(self) -> None:
        """Print a human-readable representation."""
        print(self.pretty_repr(html=is_interactive_env()))  # noqa: T201

    def __add__(self, other: Any) -> ChatPromptTemplate:
        """Combine two prompt templates.

        Args:
            other: Another prompt template.

        Returns:
            Combined prompt template.
        """
        # Imported lazily to avoid a circular import with prompts.chat.
        from langchain_core.prompts.chat import ChatPromptTemplate

        base = ChatPromptTemplate(messages=[self])  # type: ignore[call-arg]
        return base + other


class _DictMessagePromptTemplate(BaseMessagePromptTemplate):
    """Template represented by a dict that looks for input vars in all leaf vals.

    Special handling of any dict value that contains
    ``{"type": "image_url", "image_url": {"path": "..."}}``
    """

    # Dict whose string leaf values may contain template variables.
    template: Dict[str, Any]
    # Only formats that are safe for untrusted templates are accepted.
    template_format: Literal["f-string", "mustache"]

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Substitute ``kwargs`` into the template and coerce it to messages."""
        rendered = _insert_input_variables(self.template, kwargs, self.template_format)
        return convert_to_messages([rendered])

    @property
    def input_variables(self) -> List[str]:
        """Variable names found in the template's string leaves."""
        return _get_input_variables(self.template, self.template_format)

    @property
    def _prompt_type(self) -> str:
        # Identifier used when serializing this prompt type.
        return "message-dict-prompt"


def _get_input_variables(
    template: dict, template_format: Literal["f-string", "mustache"]
) -> List[str]:
    """Collect the template variable names appearing in any string leaf.

    Recurses into nested dicts and into dicts inside lists/tuples; other
    value types contribute no variables.
    """
    found: set = set()
    for value in template.values():
        if isinstance(value, str):
            found.update(get_template_variables(value, template_format))
        elif isinstance(value, dict):
            found.update(_get_input_variables(value, template_format))
        elif isinstance(value, (list, tuple)):
            for item in value:
                if isinstance(item, str):
                    found.update(get_template_variables(item, template_format))
                elif isinstance(item, dict):
                    found.update(_get_input_variables(item, template_format))
    # Deduplicated; ordering is unspecified (set iteration order).
    return list(found)


def _insert_input_variables(
    template: Dict[str, Any],
    inputs: Dict[str, Any],
    template_format: Literal["f-string", "mustache"],
) -> Dict[str, Any]:
    """Return a new dict with ``inputs`` substituted into every string leaf.

    Recurses into nested dicts and lists/tuples. A dict value under the key
    ``"image_url"`` containing a ``"path"`` entry is treated specially: the
    path is formatted, loaded from disk, and replaced by a ``"url"`` data URL.

    Args:
        template: The dict template to render. Never mutated.
        inputs: Values substituted for template variables.
        template_format: Which formatter to use for string leaves.

    Returns:
        A freshly built dict mirroring ``template`` with variables filled in.
    """
    formatted: Dict[str, Any] = {}
    formatter = DEFAULT_FORMATTER_MAPPING[template_format]
    for k, v in template.items():
        if isinstance(v, str):
            formatted[k] = formatter(v, **inputs)
        elif isinstance(v, dict):
            # Special handling for loading local images.
            if k == "image_url" and "path" in v:
                # Work on a shallow copy so the caller's template is not
                # mutated (the original popped "path" and wrote "url" into
                # the template itself, breaking reuse across format calls).
                v = dict(v)
                formatted_path = formatter(v.pop("path"), **inputs)
                v["url"] = image_to_data_url(formatted_path)
            formatted[k] = _insert_input_variables(v, inputs, template_format)
        elif isinstance(v, (list, tuple)):
            formatted_v = []
            for x in v:
                if isinstance(x, str):
                    formatted_v.append(formatter(x, **inputs))
                elif isinstance(x, dict):
                    formatted_v.append(
                        _insert_input_variables(x, inputs, template_format)
                    )
                else:
                    # Pass non-template items (ints, bools, None, ...)
                    # through unchanged instead of silently dropping them.
                    formatted_v.append(x)
            formatted[k] = type(v)(formatted_v)
        else:
            # Preserve non-string scalar values (e.g. {"type": ..., "id": 3})
            # instead of omitting the key from the result.
            formatted[k] = v
    return formatted
4 changes: 2 additions & 2 deletions libs/core/langchain_core/prompts/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ChatPromptTemplate,
MessageLikeRepresentation,
MessagesPlaceholder,
_convert_to_message,
_convert_to_message_template,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
Expand Down Expand Up @@ -86,7 +86,7 @@ class OutputSchema(BaseModel):
Returns:
a structured prompt template
"""
_messages = [_convert_to_message(message) for message in messages]
_messages = [_convert_to_message_template(message) for message in messages]

# Automatically infer input variables from messages
input_vars: Set[str] = set()
Expand Down
Binary file added libs/core/tests/unit_tests/prompts/favicon-16x16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading