Commit 914bd46

Add Placeholder class for templating (#118)
* Add Placeholder class for templating

* Support Placeholder in AssistantMessage

* Replace AssistantMessage str type hints with cast

* Add homepage to pyproject.toml

* Add docs for Placeholder

* Make Placeholder importable from top level

* Add section on escape_braces

* Add FunctionCall section to chat prompting docs

* Add sections to chatprompt page

* Add headings to index page

* Add search plugin in mkdocs config
jackmpcollins authored Feb 28, 2024
1 parent ea67293 commit 914bd46
Showing 8 changed files with 184 additions and 10 deletions.
97 changes: 94 additions & 3 deletions docs/chat-prompting.md
@@ -1,11 +1,11 @@
# Chat Prompting

## @chatprompt

The `@chatprompt` decorator works just like `@prompt` but allows you to pass chat messages as a template rather than a single text prompt. This can be used to provide a system message or for few-shot prompting where you provide example responses to guide the model's output. Format fields denoted by curly braces `{example}` will be filled in all messages (except `FunctionResultMessage`).

```python
from magentic import chatprompt, AssistantMessage, SystemMessage, UserMessage
from magentic.chatprompt import escape_braces

from pydantic import BaseModel


@@ -32,3 +32,94 @@ def get_movie_quote(movie: str) -> Quote:
get_movie_quote("Iron Man")
# Quote(quote='I am Iron Man.', character='Tony Stark')
```

### escape_braces

To prevent curly braces from being interpreted as format fields, use the `escape_braces` function to escape them in strings.

```python
from magentic.chatprompt import escape_braces

string_with_braces = "Curly braces like {example} will be filled in!"
escaped_string = escape_braces(string_with_braces)
# 'Curly braces like {{example}} will be filled in!'
escaped_string.format(example="test")
# 'Curly braces like {example} will be filled in!'
```

## Placeholder

The `Placeholder` class enables templating of `AssistantMessage` content within the `@chatprompt` decorator. This allows the messages used to prompt the model to change dynamically based on the arguments provided when the function is called.

```python
from magentic import chatprompt, AssistantMessage, Placeholder, UserMessage
from pydantic import BaseModel


class Quote(BaseModel):
    quote: str
    character: str


@chatprompt(
    UserMessage("Tell me a quote from {movie}"),
    AssistantMessage(Placeholder(Quote, "quote")),
    UserMessage("What is a similar quote from the same movie?"),
)
def get_similar_quote(movie: str, quote: Quote) -> Quote:
    ...


get_similar_quote(
    movie="Star Wars",
    quote=Quote(quote="I am your father", character="Darth Vader"),
)
# Quote(quote='The Force will be with you, always.', character='Obi-Wan Kenobi')
```

`Placeholder` can also be used in the `format` method of custom `Message` subclasses to provide an explicit way of inserting values from the function arguments. For example, see `UserImageMessage` in (TODO: link to GPT-vision page).
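To illustrate the pattern, here is a minimal self-contained sketch: the `Placeholder` body mirrors the implementation added in this commit, while `UserImageMessage` is a simplified hypothetical stand-in (not magentic's real class) whose `format` method resolves a `Placeholder` from the function arguments.

```python
from typing import Any, Generic, TypeVar, cast, get_origin

T = TypeVar("T")


class Placeholder(Generic[T]):
    """Simplified copy of the Placeholder class added in this commit."""

    def __init__(self, type_: type[T], name: str):
        self.type_ = type_
        self.name = name

    def format(self, **kwargs: Any) -> T:
        value = kwargs[self.name]
        if not isinstance(value, get_origin(self.type_) or self.type_):
            msg = f"{self.name} must be of type {self.type_}"
            raise TypeError(msg)
        return cast(T, value)


class UserImageMessage:
    """Hypothetical custom message: content is image bytes or a Placeholder."""

    def __init__(self, content: "bytes | Placeholder[bytes]"):
        self.content = content

    def format(self, **kwargs: Any) -> "UserImageMessage":
        # Resolve a Placeholder against the function arguments; pass
        # already-concrete content through unchanged.
        if isinstance(self.content, Placeholder):
            return UserImageMessage(self.content.format(**kwargs))
        return UserImageMessage(self.content)


message = UserImageMessage(Placeholder(bytes, "image"))
formatted = message.format(image=b"\x89PNG\r\n")
assert formatted.content == b"\x89PNG\r\n"
```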

## FunctionCall

The content of an `AssistantMessage` can be a `FunctionCall`. This can be used to demonstrate to the LLM when and how it should call a function.

```python
from magentic import (
    chatprompt,
    AssistantMessage,
    FunctionCall,
    UserMessage,
    SystemMessage,
)


def change_music_volume(increment: int):
    """Change music volume level. Min 1, max 10."""
    print(f"Music volume change: {increment}")


def order_food(food: str, amount: int):
    """Order food."""
    print(f"Ordered {amount} {food}")


@chatprompt(
    SystemMessage(
        "You are hosting a party and must keep the guests happy. "
        "Call functions as needed. Do not respond directly."
    ),
    UserMessage("It's pretty loud in here!"),
    AssistantMessage(FunctionCall(change_music_volume, -2)),
    UserMessage("{request}"),
    functions=[change_music_volume, order_food],
)
def adjust_for_guest(request: str) -> FunctionCall[None]:
    ...


func = adjust_for_guest("Do you have any more food?")
func()
# Ordered 3 pizza
```

To include the result of calling the function in the messages, use a `FunctionResultMessage`.
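Roughly, the idea is that the result message is tied to the specific call that produced it. A minimal self-contained sketch (simplified stand-ins, not magentic's actual classes) of that relationship:

```python
from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")


class FunctionCall(Generic[T]):
    """Simplified stand-in: a function bundled with arguments to call it with."""

    def __init__(self, function: Callable[..., T], *args: Any, **kwargs: Any):
        self.function = function
        self.args = args
        self.kwargs = kwargs

    def __call__(self) -> T:
        return self.function(*self.args, **self.kwargs)


class FunctionResultMessage(Generic[T]):
    """Simplified stand-in: pairs a result with the FunctionCall that produced it."""

    def __init__(self, content: T, function_call: FunctionCall[T]):
        self.content = content
        self.function_call = function_call


def plus(a: int, b: int) -> int:
    return a + b


call = FunctionCall(plus, 1, 2)
# The result message carries both the value and a reference to the call
result_message = FunctionResultMessage(call(), call)
assert result_message.content == 3
```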
6 changes: 6 additions & 0 deletions docs/index.md
@@ -26,6 +26,8 @@ Configure your OpenAI API key by setting the `OPENAI_API_KEY` environment variab

## Usage

### @prompt

The `@prompt` decorator allows you to define a template for a Large Language Model (LLM) prompt as a Python function. When this function is called, the arguments are inserted into the template, then this prompt is sent to an LLM which generates the function output.

```python
@@ -64,6 +66,8 @@ create_superhero("Garden Man")
# Superhero(name='Garden Man', age=30, power='Control over plants', enemies=['Pollution Man', 'Concrete Woman'])
```

### FunctionCall

An LLM can also decide to call functions. In this case the `@prompt`-decorated function returns a `FunctionCall` object which can be called to execute the function using the arguments provided by the LLM.

```python
@@ -91,6 +95,8 @@ output()
# 'Preheating to 350 F with mode bake'
```

### @prompt_chain

Sometimes the LLM requires making one or more function calls to generate a final answer. The `@prompt_chain` decorator will resolve `FunctionCall` objects automatically and pass the output back to the LLM to continue until the final answer is reached.

In the following example, when `describe_weather` is called the LLM first calls the `get_current_weather` function, then uses the result of this to formulate its final answer which gets returned.
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -33,6 +33,7 @@ plugins:
  - mkdocs-jupyter:
      # ignore_h1_titles: true
      execute: false
  - search

markdown_extensions:
- tables
1 change: 1 addition & 0 deletions pyproject.toml
@@ -22,6 +22,7 @@ description = "Seamlessly integrate LLMs as Python functions"
license = "MIT"
authors = ["Jack Collins"]
readme = "README.md"
homepage = "https://magentic.dev/"
repository = "https://github.com/jackmpcollins/magentic"

[tool.poetry.dependencies]
2 changes: 2 additions & 0 deletions src/magentic/__init__.py
@@ -1,6 +1,7 @@
from magentic.chat_model.message import (
    AssistantMessage,
    FunctionResultMessage,
    Placeholder,
    SystemMessage,
    UserMessage,
)
@@ -14,6 +15,7 @@
__all__ = [
    "AssistantMessage",
    "FunctionResultMessage",
    "Placeholder",
    "SystemMessage",
    "UserMessage",
    "OpenaiChatModel",
56 changes: 52 additions & 4 deletions src/magentic/chat_model/message.py
@@ -1,5 +1,38 @@
from abc import ABC, abstractmethod
from typing import (
    Any,
    Awaitable,
    Callable,
    Generic,
    TypeVar,
    cast,
    get_origin,
    overload,
)

T = TypeVar("T")


class Placeholder(Generic[T]):
    """A placeholder for a value in a message.

    When formatting a message, the placeholder is replaced with the value.
    This is used in combination with the `@prompt`, `@prompt_chain`, and
    `@chatprompt` decorators to enable inserting function arguments into
    messages.
    """

    def __init__(self, type_: type[T], name: str):
        self.type_ = type_
        self.name = name

    def format(self, **kwargs: Any) -> T:
        value = kwargs[self.name]
        if not isinstance(value, get_origin(self.type_) or self.type_):
            msg = f"{self.name} must be of type {self.type_}"
            raise TypeError(msg)
        return cast(T, value)


ContentT = TypeVar("ContentT")

@@ -44,10 +77,25 @@ def format(self, **kwargs: Any) -> "UserMessage":
class AssistantMessage(Message[ContentT], Generic[ContentT]):
    """A message received from an LLM chat model."""

    @overload
    def format(
        self: "AssistantMessage[Placeholder[T]]", **kwargs: Any
    ) -> "AssistantMessage[T]":
        ...

    @overload
    def format(self: "AssistantMessage[T]", **kwargs: Any) -> "AssistantMessage[T]":
        ...

    def format(
        self: "AssistantMessage[Placeholder[T]] | AssistantMessage[T]", **kwargs: Any
    ) -> "AssistantMessage[T]":
        if isinstance(self.content, str):
            formatted_content = cast(T, self.content.format(**kwargs))
            return AssistantMessage(formatted_content)
        if isinstance(self.content, Placeholder):
            content = cast(Placeholder[T], self.content)
            return AssistantMessage(content.format(**kwargs))
        return AssistantMessage(self.content)


Expand Down
17 changes: 14 additions & 3 deletions tests/chat_model/test_message.py
@@ -4,10 +4,21 @@
from magentic.chat_model.message import (
    AssistantMessage,
    FunctionResultMessage,
    Placeholder,
    UserMessage,
)


def test_placeholder():
    class Country(BaseModel):
        name: str

    placeholder = Placeholder(Country, "country")

    assert_type(placeholder, Placeholder[Country])
    assert placeholder.name == "country"


def test_user_message_format():
    user_message = UserMessage("Hello {x}")
    user_message_formatted = user_message.format(x="world")
@@ -26,12 +37,12 @@ def test_assistant_message_format_str():
    assert assistant_message_formatted == AssistantMessage("Hello world")


def test_assistant_message_format_placeholder():
    class Country(BaseModel):
        name: str

    assistant_message = AssistantMessage(Placeholder(Country, "country"))
    assistant_message_formatted = assistant_message.format(country=Country(name="USA"))

    assert_type(assistant_message_formatted, AssistantMessage[Country])
    assert_type(assistant_message_formatted.content, Country)
14 changes: 14 additions & 0 deletions tests/test_chatprompt.py
Expand Up @@ -9,6 +9,7 @@
from magentic.chat_model.message import (
    AssistantMessage,
    FunctionResultMessage,
    Placeholder,
    SystemMessage,
    UserMessage,
)
@@ -72,6 +73,19 @@ def func(param: str) -> str:
    assert func.format(param="arg") == expected_messages


def test_chatpromptfunction_format_with_placeholder():
    class Country(BaseModel):
        name: str

    @chatprompt(
        AssistantMessage(Placeholder(Country, "country")),
    )
    def func(country: Country) -> str:
        ...

    assert func.format(Country(name="USA")) == [AssistantMessage(Country(name="USA"))]


def test_chatpromptfunction_call():
    mock_model = Mock()
    mock_model.complete.return_value = AssistantMessage(content="Hello!")
