Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for the new ChatMessage data class in ChatPromptBuilder #141

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/pydoc/config/builders_api.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../]
modules: ["haystack_experimental.components.builders.chat_prompt_builder"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
expression:
documented_only: true
do_not_filter_modules: false
skip_empty_modules: true
- type: smart
- type: crossref
renderer:
type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer
excerpt: Extract the output of a Generator to an Answer format, and build prompts.
category_slug: experiments-api
title: Builders
slug: experimental-builders-api
order: 160
markdown:
descriptive_class_title: false
classdef_code_block: false
descriptive_module_title: true
add_method_class_prefix: true
add_member_class_prefix: false
filename: experimental_builders_api.md
7 changes: 7 additions & 0 deletions haystack_experimental/components/builders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .chat_prompt_builder import ChatPromptBuilder

__all__ = ["ChatPromptBuilder"]
276 changes: 276 additions & 0 deletions haystack_experimental/components/builders/chat_prompt_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Any, Dict, List, Literal, Optional, Set, Union

from haystack import component, default_from_dict, default_to_dict, logging
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment

from haystack_experimental.dataclasses.chat_message import (
ChatMessage,
ChatRole,
TextContent,
)

logger = logging.getLogger(__name__)


@component
class ChatPromptBuilder:
"""
Renders a chat prompt from a template string using Jinja2 syntax.

It constructs prompts using static or dynamic templates, which you can update for each pipeline run.

Template variables in the template are optional unless specified otherwise.
If an optional variable isn't provided, it defaults to an empty string. Use `variable` and `required_variables`
to define input types and required variables.

### Usage examples

#### With static prompt template

```python
template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
builder = ChatPromptBuilder(template=template)
builder.run(target_language="spanish", snippet="I can't speak spanish.")
```

#### Overriding static template at runtime

```python
template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
builder = ChatPromptBuilder(template=template)
builder.run(target_language="spanish", snippet="I can't speak spanish.")

msg = "Translate to {{ target_language }} and summarize. Context: {{ snippet }}; Summary:"
summary_template = [ChatMessage.from_user(msg)]
builder.run(target_language="spanish", snippet="I can't speak spanish.", template=summary_template)
```

#### With dynamic prompt template

```python
from haystack.components.builders import ChatPromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack import Pipeline
from haystack.utils import Secret

# no parameter init, we don't use any runtime template variables
prompt_builder = ChatPromptBuilder()
llm = OpenAIChatGenerator(api_key=Secret.from_token("<your-api-key>"), model="gpt-4o-mini")

pipe = Pipeline()
pipe.add_component("prompt_builder", prompt_builder)
pipe.add_component("llm", llm)
pipe.connect("prompt_builder.prompt", "llm.messages")

location = "Berlin"
language = "English"
system_message = ChatMessage.from_system("You are an assistant giving information to tourists in {{language}}")
messages = [system_message, ChatMessage.from_user("Tell me about {{location}}")]

res = pipe.run(data={"prompt_builder": {"template_variables": {"location": location, "language": language},
"template": messages}})
print(res)

>> {'llm': {'replies': [ChatMessage(content="Berlin is the capital city of Germany and one of the most vibrant
and diverse cities in Europe. Here are some key things to know...Enjoy your time exploring the vibrant and dynamic
capital of Germany!", role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4o-mini',
'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 27, 'completion_tokens': 681, 'total_tokens':
708}})]}}


messages = [system_message, ChatMessage.from_user("What's the weather forecast for {{location}} in the next
{{day_count}} days?")]

res = pipe.run(data={"prompt_builder": {"template_variables": {"location": location, "day_count": "5"},
"template": messages}})

print(res)
>> {'llm': {'replies': [ChatMessage(content="Here is the weather forecast for Berlin in the next 5
days:\\n\\nDay 1: Mostly cloudy with a high of 22°C (72°F) and...so it's always a good idea to check for updates
closer to your visit.", role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4o-mini',
'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 37, 'completion_tokens': 201,
'total_tokens': 238}})]}}
```

"""

def __init__(
self,
template: Optional[List[ChatMessage]] = None,
required_variables: Optional[Union[List[str], Literal["*"]]] = None,
variables: Optional[List[str]] = None,
):
"""
Constructs a ChatPromptBuilder component.

:param template:
A list of `ChatMessage` objects. The component looks for Jinja2 template syntax and
renders the prompt with the provided variables. Provide the template in either
the `init` method` or the `run` method.
:param required_variables:
List variables that must be provided as input to ChatPromptBuilder.
If a variable listed as required is not provided, an exception is raised.
If set to "*", all variables found in the prompt are required.
:param variables:
List input variables to use in prompt templates instead of the ones inferred from the
`template` parameter. For example, to use more variables during prompt engineering than the ones present
in the default template, you can provide them here.
"""
self._variables = variables
self._required_variables = required_variables
self.required_variables = required_variables or []
self.template = template
variables = variables or []
self._env = SandboxedEnvironment()
if template and not variables:
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
# infer variables from template
if message.text is None:
raise ValueError(
f"The {self.__class__.__name__} requires a non-empty list of ChatMessage"
" instances with text content."
)
ast = self._env.parse(message.text)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
self.variables = variables

# setup inputs
for var in self.variables:
if self.required_variables == "*" or var in self.required_variables:
component.set_input_type(self, var, Any)
else:
component.set_input_type(self, var, Any, "")

@component.output_types(prompt=List[ChatMessage])
def run(
self,
template: Optional[List[ChatMessage]] = None,
template_variables: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Renders the prompt template with the provided variables.

It applies the template variables to render the final prompt. You can provide variables with pipeline kwargs.
To overwrite the default template, you can set the `template` parameter.
To overwrite pipeline kwargs, you can set the `template_variables` parameter.

:param template:
An optional list of `ChatMessage` objects to overwrite ChatPromptBuilder's default template.
If `None`, the default template provided at initialization is used.
:param template_variables:
An optional dictionary of template variables to overwrite the pipeline variables.
:param kwargs:
Pipeline variables used for rendering the prompt.

:returns: A dictionary with the following keys:
- `prompt`: The updated list of `ChatMessage` objects after rendering the templates.
:raises ValueError:
If `chat_messages` is empty or contains elements that are not instances of `ChatMessage`.
"""
kwargs = kwargs or {}
template_variables = template_variables or {}
template_variables_combined = {**kwargs, **template_variables}

if template is None:
template = self.template

if not template:
raise ValueError(
f"The {self.__class__.__name__} requires a non-empty list of ChatMessage instances. "
f"Please provide a valid list of ChatMessage instances to render the prompt."
)

if not all(isinstance(message, ChatMessage) and message.text is not None for message in template):
raise ValueError(
f"The {self.__class__.__name__} expects a list containing only ChatMessage instances "
f"with text content. The provided list contains other types. Please ensure that all "
"elements in the list are ChatMessage instances."
)

processed_messages = []
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
self._validate_variables(set(template_variables_combined.keys()))

assert message.text is not None
compiled_template = self._env.from_string(message.text)
rendered_content = compiled_template.render(template_variables_combined)

rendered_message = ChatMessage(
message.role,
[TextContent(rendered_content)],
deepcopy(message.meta),
)
processed_messages.append(rendered_message)
else:
processed_messages.append(message)

return {"prompt": processed_messages}

def _validate_variables(self, provided_variables: Set[str]):
"""
Checks if all the required template variables are provided.

:param provided_variables:
A set of provided template variables.
:raises ValueError:
If no template is provided or if all the required template variables are not provided.
"""
if self.required_variables == "*":
required_variables = sorted(self.variables)
else:
required_variables = self.required_variables
missing_variables = [var for var in required_variables if var not in provided_variables]
if missing_variables:
missing_vars_str = ", ".join(missing_variables)
raise ValueError(
f"Missing required input variables in ChatPromptBuilder: {missing_vars_str}. "
f"Required variables: {required_variables}. Provided variables: {provided_variables}."
)

def to_dict(self) -> Dict[str, Any]:
"""
Returns a dictionary representation of the component.

:returns:
Serialized dictionary representation of the component.
"""
if self.template is not None:
template = [m.to_dict() for m in self.template]
else:
template = None

return default_to_dict(
self,
template=template,
variables=self._variables,
required_variables=self._required_variables,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChatPromptBuilder":
"""
Deserialize this component from a dictionary.

:param data:
The dictionary to deserialize and create the component.

:returns:
The deserialized component.
"""
init_parameters = data["init_parameters"]
template = init_parameters.get("template")
if template:
init_parameters["template"] = [ChatMessage.from_dict(d) for d in template]

return default_from_dict(cls, data)
33 changes: 27 additions & 6 deletions haystack_experimental/dataclasses/chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,24 +158,34 @@ def is_from(self, role: ChatRole) -> bool:
return self._role == role

@classmethod
def from_user(cls, text: str) -> "ChatMessage":
def from_user(
cls,
text: str,
meta: Optional[Dict[str, Any]] = None,
) -> "ChatMessage":
"""
Create a message from the user.

:param text: The text content of the message.
:param meta: Additional metadata associated with the message.
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
:returns: A new ChatMessage instance.
"""
return cls(_role=ChatRole.USER, _content=[TextContent(text=text)])
return cls(_role=ChatRole.USER, _content=[TextContent(text=text)], _meta=meta or {})

@classmethod
def from_system(cls, text: str) -> "ChatMessage":
def from_system(
cls,
text: str,
meta: Optional[Dict[str, Any]] = None,
) -> "ChatMessage":
"""
Create a message from the system.

:param text: The text content of the message.
:param meta: Additional metadata associated with the message.
:returns: A new ChatMessage instance.
"""
return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)])
return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)], _meta=meta or {})

@classmethod
def from_assistant(
Expand All @@ -201,16 +211,27 @@ def from_assistant(
return cls(_role=ChatRole.ASSISTANT, _content=content, _meta=meta or {})

@classmethod
def from_tool(cls, tool_result: str, origin: ToolCall, error: bool = False) -> "ChatMessage":
def from_tool(
cls,
tool_result: str,
origin: ToolCall,
error: bool = False,
meta: Optional[Dict[str, Any]] = None,
) -> "ChatMessage":
"""
Create a message from a Tool.

:param tool_result: The result of the Tool invocation.
:param origin: The Tool call that produced this result.
:param error: Whether the Tool invocation resulted in an error.
:param meta: Additional metadata associated with the message.
:returns: A new ChatMessage instance.
"""
return cls(_role=ChatRole.TOOL, _content=[ToolCallResult(result=tool_result, origin=origin, error=error)])
return cls(
_role=ChatRole.TOOL,
_content=[ToolCallResult(result=tool_result, origin=origin, error=error)],
_meta=meta or {},
)

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down
3 changes: 3 additions & 0 deletions test/components/builders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
Loading
Loading