diff --git a/docs/chat-prompting.md b/docs/chat-prompting.md index 080c5e08..552a03a1 100644 --- a/docs/chat-prompting.md +++ b/docs/chat-prompting.md @@ -1,11 +1,11 @@ # Chat Prompting -The `@chatprompt` decorator works just like `@prompt` but allows you to pass chat messages as a template rather than a single text prompt. This can be used to provide a system message or for few-shot prompting where you provide example responses to guide the model's output. Format fields denoted by curly braces `{example}` will be filled in all messages - use the `escape_braces` function to prevent a string being used as a template. +## @chatprompt + +The `@chatprompt` decorator works just like `@prompt` but allows you to pass chat messages as a template rather than a single text prompt. This can be used to provide a system message or for few-shot prompting where you provide example responses to guide the model's output. Format fields denoted by curly braces `{example}` will be filled in all messages (except `FunctionResultMessage`). ```python from magentic import chatprompt, AssistantMessage, SystemMessage, UserMessage -from magentic.chatprompt import escape_braces - from pydantic import BaseModel @@ -32,3 +32,94 @@ def get_movie_quote(movie: str) -> Quote: get_movie_quote("Iron Man") # Quote(quote='I am Iron Man.', character='Tony Stark') ``` + +### escape_braces + +To prevent curly braces from being interpreted as format fields, use the `escape_braces` function to escape them in strings. + +```python +from magentic.chatprompt import escape_braces + +string_with_braces = "Curly braces like {example} will be filled in!" +escaped_string = escape_braces(string_with_braces) +# 'Curly braces {{example}} will be filled in!' +escaped_string.format(example="test") +# 'Curly braces {example} will be filled in!' +``` + +## Placeholder + +The `Placeholder` class enables templating of `AssistantMessage` content within the `@chatprompt` decorator. 
This allows dynamic changing of the messages used to prompt the model based on the arguments provided when the function is called. + +```python +from magentic import chatprompt, AssistantMessage, Placeholder, UserMessage +from pydantic import BaseModel + + +class Quote(BaseModel): + quote: str + character: str + + +@chatprompt( + UserMessage("Tell me a quote from {movie}"), + AssistantMessage(Placeholder(Quote, "quote")), + UserMessage("What is a similar quote from the same movie?"), +) +def get_similar_quote(movie: str, quote: Quote) -> Quote: + ... + + +get_similar_quote( + movie="Star Wars", + quote=Quote(quote="I am your father", character="Darth Vader"), +) +# Quote(quote='The Force will be with you, always.', character='Obi-Wan Kenobi') +``` + +`Placeholder` can also be utilized in the `format` method of custom `Message` subclasses to provide an explicit way of inserting values from the function arguments. For example, see `UserImageMessage` in (TODO: link to GPT-vision page). + +## FunctionCall + +The content of an `AssistantMessage` can be a `FunctionCall`. This can be used to demonstrate to the LLM when/how it should call a function. + +```python +from magentic import ( + chatprompt, + AssistantMessage, + FunctionCall, + UserMessage, + SystemMessage, +) + + +def change_music_volume(increment: int): + """Change music volume level. Min 1, max 10.""" + print(f"Music volume change: {increment}") + + +def order_food(food: str, amount: int): + """Order food.""" + print(f"Ordered {amount} {food}") + + +@chatprompt( + SystemMessage( + "You are hosting a party and must keep the guests happy. " + "Call functions as needed. Do not respond directly." + ), + UserMessage("It's pretty loud in here!"), + AssistantMessage(FunctionCall(change_music_volume, -2)), + UserMessage("{request}"), + functions=[change_music_volume, order_food], +) +def adjust_for_guest(request: str) -> FunctionCall[None]: + ... 
+ + +func = adjust_for_guest("Do you have any more food?") +func() +# Ordered 3 pizza +``` + +To include the result of calling the function in the messages use a `FunctionResultMessage`. diff --git a/docs/index.md b/docs/index.md index 89706056..a4c15005 100644 --- a/docs/index.md +++ b/docs/index.md @@ -26,6 +26,8 @@ Configure your OpenAI API key by setting the `OPENAI_API_KEY` environment variab ## Usage +### @prompt + The `@prompt` decorator allows you to define a template for a Large Language Model (LLM) prompt as a Python function. When this function is called, the arguments are inserted into the template, then this prompt is sent to an LLM which generates the function output. ```python @@ -64,6 +66,8 @@ create_superhero("Garden Man") # Superhero(name='Garden Man', age=30, power='Control over plants', enemies=['Pollution Man', 'Concrete Woman']) ``` +### FunctionCall + An LLM can also decide to call functions. In this case the `@prompt`-decorated function returns a `FunctionCall` object which can be called to execute the function using the arguments provided by the LLM. ```python @@ -91,6 +95,8 @@ output() # 'Preheating to 350 F with mode bake' ``` +### @prompt_chain + Sometimes the LLM requires making one or more function calls to generate a final answer. The `@prompt_chain` decorator will resolve `FunctionCall` objects automatically and pass the output back to the LLM to continue until the final answer is reached. In the following example, when `describe_weather` is called the LLM first calls the `get_current_weather` function, then uses the result of this to formulate its final answer which gets returned. 
diff --git a/mkdocs.yml b/mkdocs.yml index 52c1f577..6df79577 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -33,6 +33,7 @@ plugins: - mkdocs-jupyter: # ignore_h1_titles: true execute: false + - search markdown_extensions: - tables diff --git a/pyproject.toml b/pyproject.toml index dc4690f8..06d210aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ description = "Seamlessly integrate LLMs as Python functions" license = "MIT" authors = ["Jack Collins"] readme = "README.md" +homepage = "https://magentic.dev/" repository = "https://github.com/jackmpcollins/magentic" [tool.poetry.dependencies] diff --git a/src/magentic/__init__.py b/src/magentic/__init__.py index d2048fcd..1d5fd02e 100644 --- a/src/magentic/__init__.py +++ b/src/magentic/__init__.py @@ -1,6 +1,7 @@ from magentic.chat_model.message import ( AssistantMessage, FunctionResultMessage, + Placeholder, SystemMessage, UserMessage, ) @@ -14,6 +15,7 @@ __all__ = [ "AssistantMessage", "FunctionResultMessage", + "Placeholder", "SystemMessage", "UserMessage", "OpenaiChatModel", diff --git a/src/magentic/chat_model/message.py b/src/magentic/chat_model/message.py index 40667fd9..ea418758 100644 --- a/src/magentic/chat_model/message.py +++ b/src/magentic/chat_model/message.py @@ -1,5 +1,38 @@ from abc import ABC, abstractmethod -from typing import Any, Awaitable, Callable, Generic, TypeVar, cast, overload +from typing import ( + Any, + Awaitable, + Callable, + Generic, + TypeVar, + cast, + get_origin, + overload, +) + +T = TypeVar("T") + + +class Placeholder(Generic[T]): + """A placeholder for a value in a message. + + When formatting a message, the placeholder is replaced with the value. + This is used in combination with the `@prompt`, `@prompt_chain`, and + `@chatprompt` decorators to enable inserting function arguments into + messages 
+ """ + + def __init__(self, type_: type[T], name: str): + self.type_ = type_ + self.name = name + + def format(self, **kwargs: Any) -> T: + value = kwargs[self.name] + if not isinstance(value, get_origin(self.type_) or self.type_): + msg = f"{self.name} must be of type {self.type_}" + raise TypeError(msg) + return cast(T, value) + ContentT = TypeVar("ContentT") @@ -44,10 +77,25 @@ def format(self, **kwargs: Any) -> "UserMessage": class AssistantMessage(Message[ContentT], Generic[ContentT]): """A message received from an LLM chat model.""" - def format(self, **kwargs: Any) -> "AssistantMessage[ContentT]": + @overload + def format( + self: "AssistantMessage[Placeholder[T]]", **kwargs: Any + ) -> "AssistantMessage[T]": + ... + + @overload + def format(self: "AssistantMessage[T]", **kwargs: Any) -> "AssistantMessage[T]": + ... + + def format( + self: "AssistantMessage[Placeholder[T]] | AssistantMessage[T]", **kwargs: Any + ) -> "AssistantMessage[T]": if isinstance(self.content, str): - content = cast(ContentT, self.content.format(**kwargs)) - return AssistantMessage(content) + formatted_content = cast(T, self.content.format(**kwargs)) + return AssistantMessage(formatted_content) + if isinstance(self.content, Placeholder): + content = cast(Placeholder[T], self.content) + return AssistantMessage(content.format(**kwargs)) return AssistantMessage(self.content) diff --git a/tests/chat_model/test_message.py b/tests/chat_model/test_message.py index 0dca9909..657af2ec 100644 --- a/tests/chat_model/test_message.py +++ b/tests/chat_model/test_message.py @@ -4,10 +4,21 @@ from magentic.chat_model.message import ( AssistantMessage, FunctionResultMessage, + Placeholder, UserMessage, ) +def test_placeholder(): + class Country(BaseModel): + name: str + + placeholder = Placeholder(Country, "country") + + assert_type(placeholder, Placeholder[Country]) + assert placeholder.name == "country" + + def test_user_message_format(): user_message = UserMessage("Hello {x}") 
user_message_formatted = user_message.format(x="world") @@ -26,12 +37,12 @@ def test_assistant_message_format_str(): assert assistant_message_formatted == AssistantMessage("Hello world") -def test_assistant_message_format(): +def test_assistant_message_format_placeholder(): class Country(BaseModel): name: str - assistant_message = AssistantMessage(Country(name="USA")) - assistant_message_formatted = assistant_message.format(foo="bar") + assistant_message = AssistantMessage(Placeholder(Country, "country")) + assistant_message_formatted = assistant_message.format(country=Country(name="USA")) assert_type(assistant_message_formatted, AssistantMessage[Country]) assert_type(assistant_message_formatted.content, Country) diff --git a/tests/test_chatprompt.py b/tests/test_chatprompt.py index f9418574..5a8cb47b 100644 --- a/tests/test_chatprompt.py +++ b/tests/test_chatprompt.py @@ -9,6 +9,7 @@ from magentic.chat_model.message import ( AssistantMessage, FunctionResultMessage, + Placeholder, SystemMessage, UserMessage, ) @@ -72,6 +73,19 @@ def func(param: str) -> str: assert func.format(param="arg") == expected_messages +def test_chatpromptfunction_format_with_placeholder(): + class Country(BaseModel): + name: str + + @chatprompt( + AssistantMessage(Placeholder(Country, "country")), + ) + def func(country: Country) -> str: + ... + + assert func.format(Country(name="USA")) == [AssistantMessage(Country(name="USA"))] + + def test_chatpromptfunction_call(): mock_model = Mock() mock_model.complete.return_value = AssistantMessage(content="Hello!")