From da537a6dee9919758e6ade36a617e7c6b3841c0c Mon Sep 17 00:00:00 2001
From: Ivan Leo
Date: Tue, 29 Oct 2024 15:03:29 +0800
Subject: [PATCH 1/2] Patching Anthropic System (#1130)

Co-authored-by: Richie Caputo
Co-authored-by: Richie Caputo <43445060+arcaputo3@users.noreply.github.com>
---
 .github/workflows/ruff.yml                     |   8 +-
 docs/hooks/hide_lines.py                       |   4 +-
 .../7-synthetic-data-generation.ipynb          |   2 +-
 instructor/client.py                           |   7 +-
 instructor/process_response.py                 |  55 ++---
 instructor/utils.py                            |  53 ++++-
 make_desc.py                                   |  14 +-
 tests/llm/test_anthropic/evals/test_simple.py  |  17 +-
 tests/llm/test_anthropic/test_system.py        | 144 +++++++++++++
 tests/test_utils.py                            | 201 ++++++++++++++++++
 10 files changed, 443 insertions(+), 62 deletions(-)
 create mode 100644 tests/llm/test_anthropic/test_system.py

diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index eaecfbe65..2c81358ca 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -31,13 +31,7 @@ jobs:
         python3 -m pip install -r requirements.txt
         python3 -m pip install -r requirements-doc.txt
     - name: Run Continuous Integration Action
-      run: |
-        set -e -o pipefail
-        export CUSTOM_PACKAGES="${{ env.CUSTOM_PACKAGES }}" &&
-        export CUSTOM_FLAGS="${{ env.CUSTOM_FLAGS }}" &&
-        curl -sSL https://raw.githubusercontent.com/gao-hongnan/omniverse/2fd5de1b8103e955cd5f022ab016b72fa901fa8f/scripts/devops/continuous-integration/lint_ruff.sh -o lint_ruff.sh
-        chmod +x lint_ruff.sh
-        bash lint_ruff.sh | tee ${{ env.WORKING_DIRECTORY }}/${{ env.RUFF_OUTPUT_FILENAME }}
+      uses: astral-sh/ruff-action@v1
     - name: Upload Artifacts
       uses: actions/upload-artifact@v3
       with:
diff --git a/docs/hooks/hide_lines.py b/docs/hooks/hide_lines.py
index 317f00605..7e346903f 100644
--- a/docs/hooks/hide_lines.py
+++ b/docs/hooks/hide_lines.py
@@ -5,9 +5,9 @@

 @mkdocs.plugins.event_priority(0)
 # pylint: disable=unused-argument
-def on_startup(command: str, dirty: bool) -> None:
+def on_startup(command: str, dirty: bool) -> None:  # noqa: ARG001
     """Monkey patch Highlight extension to hide lines in code blocks."""
-    original = highlight.Highlight.highlight
+    original = highlight.Highlight.highlight  # type: ignore

     def patched(self: Any, src: str, *args: Any, **kwargs: Any) -> Any:
         lines = src.splitlines(keepends=True)
diff --git a/docs/tutorials/7-synthetic-data-generation.ipynb b/docs/tutorials/7-synthetic-data-generation.ipynb
index 74bbc0607..3bbdfce1f 100644
--- a/docs/tutorials/7-synthetic-data-generation.ipynb
+++ b/docs/tutorials/7-synthetic-data-generation.ipynb
@@ -565,7 +565,7 @@
     "\n",
     "for metric,size in product(METRICS,SIZES):\n",
     "    metric_name, score_fn = metric\n",
-    "    score_fns[f\"{metric_name}@{size}\"] = lambda predictions,labels : score_fn(predictions[:size],labels)"
+    "    score_fns[f\"{metric_name}@{size}\"] = lambda predictions,labels, fn=score_fn, k=size: fn(predictions[:k],labels)  # type: ignore"
    ]
   },
   {
diff --git a/instructor/client.py b/instructor/client.py
index fa452c884..d50659112 100644
--- a/instructor/client.py
+++ b/instructor/client.py
@@ -507,12 +507,7 @@ def from_openai(
             instructor.Mode.MD_JSON,
         }

-        if provider in {Provider.DATABRICKS}:
-            assert mode in {
-                instructor.Mode.MD_JSON
-            }, "Databricks provider only supports `MD_JSON` mode."
-
-        if provider in {Provider.OPENAI}:
+        if provider in {Provider.OPENAI, Provider.DATABRICKS}:
             assert mode in {
                 instructor.Mode.TOOLS,
                 instructor.Mode.JSON,
diff --git a/instructor/process_response.py b/instructor/process_response.py
index 3dd876f34..277151711 100644
--- a/instructor/process_response.py
+++ b/instructor/process_response.py
@@ -20,7 +20,11 @@
 from instructor.dsl.partial import PartialBase
 from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
 from instructor.function_calls import OpenAISchema, openai_schema
-from instructor.utils import merge_consecutive_messages
+from instructor.utils import (
+    merge_consecutive_messages,
+    extract_system_messages,
+    combine_system_messages,
+)
 from instructor.multimodal import convert_messages

 logger = logging.getLogger("instructor")
@@ -332,20 +336,15 @@ def handle_anthropic_tools(
         "name": response_model.__name__,
     }

-    system_messages = [
-        m["content"] for m in new_kwargs["messages"] if m["role"] == "system"
-    ]
+    system_messages = extract_system_messages(new_kwargs.get("messages", []))

-    if "system" in new_kwargs and system_messages:
-        raise ValueError(
-            "Only a single system message is supported - either set it as a message in the messages array or use the system parameter"
+    if system_messages:
+        new_kwargs["system"] = combine_system_messages(
+            new_kwargs.get("system"), system_messages
         )

-    if "system" not in new_kwargs:
-        new_kwargs["system"] = "\n\n".join(system_messages)
-
     new_kwargs["messages"] = [
-        m for m in new_kwargs["messages"] if m["role"] != "system"
+        m for m in new_kwargs.get("messages", []) if m["role"] != "system"
     ]

     return response_model, new_kwargs
@@ -354,25 +353,18 @@ def handle_anthropic_json(
     response_model: type[T], new_kwargs: dict[str, Any]
 ) -> tuple[type[T], dict[str, Any]]:
-    openai_system_messages = [
-        message["content"]
-        for message in new_kwargs.get("messages", [])
-        if message["role"] == "system"
-    ]
+    system_messages = extract_system_messages(new_kwargs.get("messages", []))

-    if "system" in new_kwargs and openai_system_messages:
-        raise ValueError(
-            "Only a single System message is supported - either set it using the system parameter or in the list of messages"
+    if system_messages:
+        new_kwargs["system"] = combine_system_messages(
+            new_kwargs.get("system"), system_messages
         )

-    if not "system" in new_kwargs:
-        new_kwargs["system"] = "\n\n".join(openai_system_messages)
-
     new_kwargs["messages"] = [
-        m for m in new_kwargs["messages"] if m["role"] != "system"
+        m for m in new_kwargs.get("messages", []) if m["role"] != "system"
     ]

-    message = dedent(
+    json_schema_message = dedent(
         f"""
         As a genius expert, your task is to understand the content and provide
         the parsed objects in json that match the following json_schema:\n
@@ -383,7 +375,9 @@ def handle_anthropic_json(
         """
     )

-    new_kwargs["system"] = f"{new_kwargs.get('system', '')}\n\n{message}".strip()
+    new_kwargs["system"] = combine_system_messages(
+        new_kwargs.get("system"), [{"type": "text", "text": json_schema_message}]
+    )

     return response_model, new_kwargs
@@ -664,7 +658,7 @@ def handle_response_model(
         # This is cause cohere uses 'message' and 'chat_history' instead of 'messages'
         return handle_cohere_modes(new_kwargs)
     # Handle images without a response model
-    if autodetect_images and "messages" in new_kwargs:
+    if "messages" in new_kwargs:
         messages = convert_messages(
             new_kwargs["messages"],
             mode,
         )
         if mode in {Mode.ANTHROPIC_JSON, Mode.ANTHROPIC_TOOLS}:
             # Handle OpenAI style or Anthropic style messages
-            new_kwargs["messages"] = [
-                m for m in messages if m["role"] != "system"
-            ]
+            new_kwargs["messages"] = [m for m in messages if m["role"] != "system"]
             if "system" not in new_kwargs:
-                system_messages = (m for m in messages if m["role"] == "system")
-                system_message = next(system_messages, None)
+                system_message = extract_system_messages(messages)
                 if system_message:
-                    new_kwargs["system"] = system_message["content"]
+                    new_kwargs["system"] = system_message
         else:
             new_kwargs["messages"] = messages
     return None, new_kwargs
diff --git a/instructor/utils.py b/instructor/utils.py
index e9cacac25..6bee2b1c6 100644
--- a/instructor/utils.py
+++ b/instructor/utils.py
@@ -10,7 +10,10 @@
     Callable,
     Generic,
     Protocol,
+    Union,
+    TypedDict,
     TypeVar,
+    cast
 )
 from pydantic import BaseModel
 import os
@@ -131,7 +134,6 @@ def update_total_usage(
     response: T_Model | None,
     total_usage: OpenAIUsage | AnthropicUsage,
 ) -> T_Model | ChatCompletion | None:
-
     if response is None:
         return None
@@ -369,3 +371,52 @@ def update_gemini_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:

 def disable_pydantic_error_url():
     os.environ["PYDANTIC_ERRORS_INCLUDE_URL"] = "0"
+
+
+class SystemMessage(TypedDict, total=False):
+    type: str
+    text: str
+    cache_control: dict[str, str]
+
+
+def combine_system_messages(
+    existing_system: Union[str, list[SystemMessage], None],  # noqa: UP007
+    new_system: Union[str, list[SystemMessage]],  # noqa: UP007
+) -> Union[str, list[SystemMessage]]:  # noqa: UP007
+    if existing_system is None:
+        return new_system
+
+    if isinstance(existing_system, str) and isinstance(new_system, str):
+        return f"{existing_system}\n\n{new_system}"
+
+    if isinstance(existing_system, list) and isinstance(new_system, list):
+        return existing_system + new_system
+
+    if isinstance(existing_system, str) and isinstance(new_system, list):
+        return [SystemMessage(type="text", text=existing_system)] + new_system
+
+    if isinstance(existing_system, list) and isinstance(new_system, str):
+        return existing_system + [SystemMessage(type="text", text=new_system)]
+
+    raise ValueError("Unsupported system message type combination")
+
+
+def extract_system_messages(messages: list[dict[str, Any]]) -> list[SystemMessage]:
+    def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage:  # noqa: UP007
+        if isinstance(content, str):
+            return SystemMessage(type="text", text=content)
+        elif isinstance(content, dict):
+            return SystemMessage(**content)
+        else:
+            raise ValueError(f"Unsupported content type: {type(content)}")
+
+    result: list[SystemMessage] = []
+    for m in messages:
+        if m["role"] == "system":
+            # System message must always be a string or list of dictionaries
+            content = cast(Union[str, list[dict[str, Any]]], m["content"])  # noqa: UP007
+            if isinstance(content, list):
+                result.extend(convert_message(item) for item in content)
+            else:
+                result.append(convert_message(content))
+    return result
diff --git a/make_desc.py b/make_desc.py
index 9a4a896c1..6dc9211aa 100644
--- a/make_desc.py
+++ b/make_desc.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional, List, Set, Literal
+from typing import Optional, Literal
 import asyncio
 from openai import AsyncOpenAI
 import typer
@@ -15,7 +15,7 @@


 async def generate_ai_frontmatter(
-    client: AsyncOpenAI, title: str, content: str, categories: List[str]
+    client: AsyncOpenAI, title: str, content: str, categories: list[str]
 ):
     """
     Generate a description and categories for the given content using AI.
@@ -35,8 +35,8 @@ class DescriptionAndCategories(BaseModel):
         reasoning: str = Field(
             ..., description="The reasoning for the correct categories"
         )
-        tags: List[str]
-        categories: List[
+        tags: list[str]
+        categories: list[
             Literal[
                 "OpenAI",
                 "Anthropic",
@@ -72,7 +72,7 @@ class DescriptionAndCategories(BaseModel):

     return response


-def get_all_categories(root_dir: str) -> Set[str]:
+def get_all_categories(root_dir: str) -> set[str]:
     """
     Read all markdown files and extract unique categories.
@@ -113,7 +113,7 @@ def preview_categories(root_dir: str) -> None:


 async def process_file(
-    client: AsyncOpenAI, file_path: str, categories: List[str], enable_comments: bool
+    client: AsyncOpenAI, file_path: str, categories: list[str], enable_comments: bool
 ) -> None:
     """
     Process a single file, adding or updating the description and categories in the front matter.
@@ -143,7 +143,7 @@ async def process_files(
     root_dir: str,
-    api_key: Optional[str] = None,
+    api_key: Optional[str] = None,  # noqa: ARG001
     use_categories: bool = False,
     enable_comments: bool = False,
 ) -> None:
diff --git a/tests/llm/test_anthropic/evals/test_simple.py b/tests/llm/test_anthropic/evals/test_simple.py
index 2ebdadfd9..62134c739 100644
--- a/tests/llm/test_anthropic/evals/test_simple.py
+++ b/tests/llm/test_anthropic/evals/test_simple.py
@@ -20,18 +20,19 @@ class User(BaseModel):

     @field_validator("name")
     def name_is_uppercase(cls, v: str):
-        assert v.isupper(), "Name must be uppercase, please fix"
+        assert v.isupper(), f"{v} is not an uppercased string. Note that all characters in {v} must be uppercase (EG. TIM SARAH ADAM)."
         return v

     resp = client.messages.create(
         model="claude-3-haiku-20240307",
-        max_tokens=1024,
+        max_tokens=4096,
         max_retries=2,
+        system="Make sure to follow the instructions carefully and return a response object that matches the json schema requested. Age is an integer.",
         messages=[
             {
                 "role": "user",
                 "content": "Extract John is 18 years old.",
-            }
+            },
         ],
         response_model=User,
     )  # type: ignore
@@ -53,7 +54,7 @@ class User(BaseModel):

     resp = client.messages.create(
         model="claude-3-haiku-20240307",
-        max_tokens=1024,
+        max_tokens=4096,
         max_retries=0,
         messages=[
             {
@@ -83,6 +84,7 @@ class User(BaseModel):
         model="claude-3-haiku-20240307",
         max_tokens=1024,
         max_retries=0,
+        system="Make sure to follow the instructions carefully and return a response object that matches the json schema requested. Family members here is just asking for a list of names",
         messages=[
             {
                 "role": "user",
@@ -132,7 +134,7 @@ class User(BaseModel):

     resp = client.messages.create(
         model="claude-3-haiku-20240307",
-        max_tokens=1024,
+        max_tokens=4096,
         max_retries=2,
         messages=[
             {
@@ -185,7 +187,10 @@ class User(BaseModel):
         max_tokens=1024,
         max_retries=0,
         messages=[
-            {"role": "system", "content": "EVERYTHING MUST BE IN ALL CAPS"},
+            {
+                "role": "system",
+                "content": "Please make sure to follow the instructions carefully and return a valid response object. All strings must be fully capitalised in all caps. (Eg. THIS IS AN UPPERCASE STRING) and age is an integer.",
+            },
             {
                 "role": "user",
                 "content": "Create a user for a model with a name and age.",
diff --git a/tests/llm/test_anthropic/test_system.py b/tests/llm/test_anthropic/test_system.py
new file mode 100644
index 000000000..eb3158902
--- /dev/null
+++ b/tests/llm/test_anthropic/test_system.py
@@ -0,0 +1,144 @@
+import pytest
+import instructor
+from pydantic import BaseModel
+from itertools import product
+from .util import models, modes
+from anthropic.types.message import Message
+
+
+class User(BaseModel):
+    name: str
+    age: int
+
+
+@pytest.mark.parametrize("model, mode", product(models, modes))
+def test_creation(model, mode, client):
+    client = instructor.from_anthropic(client, mode=mode)
+    response = client.chat.completions.create(
+        model=model,
+        response_model=User,
+        messages=[
+            {
+                "role": "system",
+                "content": [
+                    {"type": "text", "text": "Mike is 37 years old"}
+                ],
+            },
+            {
+                "role": "user",
+                "content": "Extract a user from the story.",
+            },
+        ],
+        temperature=1,
+        max_tokens=1000,
+    )
+
+    # Assertions to validate the response
+    assert isinstance(response, User)
+    assert response.name == "Mike"
+    assert response.age == 37
+
+
+@pytest.mark.parametrize("model, mode", product(models, modes))
+def test_creation_with_system_cache(model, mode, client):
+    client = instructor.from_anthropic(client, mode=mode, enable_prompt_caching=True)
+    response, message = client.chat.completions.create_with_completion(
+        model=model,
+        response_model=User,
+        messages=[
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Mike is 37 years old " * 200 + "",
+                        "cache_control": {"type": "ephemeral"},
+                    },
+                    {
+                        "type": "text",
+                        "text": "You are a helpful assistant who extracts users from stories.",
+                    },
+                ],
+            },
+            {
+                "role": "user",
+                "content": "Extract a user from the story.",
+            },
+        ],
+        temperature=1,
+        max_tokens=1000,
+    )
+
+    # Assertions to validate the response
+    assert isinstance(response, User)
+    assert response.name == "Mike"
+    assert response.age == 37
+
+    # Assert a cache write or cache hit
+    assert (
+        message.usage.cache_creation_input_tokens > 0
+        or message.usage.cache_read_input_tokens > 0
+    )
+
+
+@pytest.mark.parametrize("model, mode", product(models, modes))
+def test_creation_with_system_cache_anthropic_style(model, mode, client):
+    client = instructor.from_anthropic(client, mode=mode, enable_prompt_caching=True)
+    response, message = client.chat.completions.create_with_completion(
+        model=model,
+        system=[
+            {
+                "type": "text",
+                "text": "Mike is 37 years old " * 200 + "",
+                "cache_control": {"type": "ephemeral"},
+            },
+            {
+                "type": "text",
+                "text": "You are a helpful assistant who extracts users from stories.",
+            },
+        ],
+        response_model=User,
+        messages=[
+            {
+                "role": "user",
+                "content": "Extract a user from the story.",
+            },
+        ],
+        temperature=1,
+        max_tokens=1000,
+    )
+
+    # Assertions to validate the response
+    assert isinstance(response, User)
+    assert response.name == "Mike"
+    assert response.age == 37
+
+    # Assert a cache write or cache hit
+    assert (
+        message.usage.cache_creation_input_tokens > 0
+        or message.usage.cache_read_input_tokens > 0
+    )
+
+
+@pytest.mark.parametrize("model, mode", product(models, modes))
+def test_creation_no_response_model(model, mode, client):
+    client = instructor.from_anthropic(client, mode=mode)
+    response = client.chat.completions.create(
+        response_model=None,
+        model=model,
+        messages=[
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "Mike is 37 years old"}],
+            },
+            {
+                "role": "user",
+                "content": "Extract a user from the story.",
+            },
+        ],
+        temperature=1,
+        max_tokens=1000,
+    )
+
+    # Assertions to validate the response
+    assert isinstance(response, Message)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9dead016b..6fa7ff36f 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -6,6 +6,8 @@
     extract_json_from_stream,
     extract_json_from_stream_async,
     merge_consecutive_messages,
+    extract_system_messages,
+    combine_system_messages,
 )


@@ -185,3 +187,202 @@ def my_property(cls):
         return cls.clvar

     assert MyClass.my_property == 1
+
+
+def test_combine_system_messages_string_string():
+    existing = "Existing message"
+    new = "New message"
+    result = combine_system_messages(existing, new)
+    assert result == "Existing message\n\nNew message"
+
+
+def test_combine_system_messages_list_list():
+    existing = [{"type": "text", "text": "Existing"}]
+    new = [{"type": "text", "text": "New"}]
+    result = combine_system_messages(existing, new)
+    assert result == [
+        {"type": "text", "text": "Existing"},
+        {"type": "text", "text": "New"},
+    ]
+
+
+def test_combine_system_messages_string_list():
+    existing = "Existing"
+    new = [{"type": "text", "text": "New"}]
+    result = combine_system_messages(existing, new)
+    assert result == [
+        {"type": "text", "text": "Existing"},
+        {"type": "text", "text": "New"},
+    ]
+
+
+def test_combine_system_messages_list_string():
+    existing = [{"type": "text", "text": "Existing"}]
+    new = "New"
+    result = combine_system_messages(existing, new)
+    assert result == [
+        {"type": "text", "text": "Existing"},
+        {"type": "text", "text": "New"},
+    ]
+
+
+def test_combine_system_messages_none_string():
+    existing = None
+    new = "New"
+    result = combine_system_messages(existing, new)
+    assert result == "New"
+
+
+def test_combine_system_messages_none_list():
+    existing = None
+    new = [{"type": "text", "text": "New"}]
+    result = combine_system_messages(existing, new)
+    assert result == [{"type": "text", "text": "New"}]
+
+
+def test_combine_system_messages_invalid_type():
+    with pytest.raises(ValueError):
+        combine_system_messages(123, "New")
+
+
+def test_extract_system_messages():
+    messages = [
+        {"role": "system", "content": "System message 1"},
+        {"role": "user", "content": "User message"},
+        {"role": "system", "content": "System message 2"},
+    ]
+    result = extract_system_messages(messages)
+    expected = [
+        {"type": "text", "text": "System message 1"},
+        {"type": "text", "text": "System message 2"},
+    ]
+    assert result == expected
+
+
+def test_extract_system_messages_no_system():
+    messages = [
+        {"role": "user", "content": "User message"},
+        {"role": "assistant", "content": "Assistant message"},
+    ]
+    result = extract_system_messages(messages)
+    assert result == []
+
+
+def test_combine_system_messages_with_cache_control():
+    existing = [
+        {
+            "type": "text",
+            "text": "You are an AI assistant.",
+        },
+        {
+            "type": "text",
+            "text": "This is some context.",
+            "cache_control": {"type": "ephemeral"},
+        },
+    ]
+    new = "Provide insightful analysis."
+    result = combine_system_messages(existing, new)
+    expected = [
+        {
+            "type": "text",
+            "text": "You are an AI assistant.",
+        },
+        {
+            "type": "text",
+            "text": "This is some context.",
+            "cache_control": {"type": "ephemeral"},
+        },
+        {"type": "text", "text": "Provide insightful analysis."},
+    ]
+    assert result == expected
+
+
+def test_combine_system_messages_string_to_cache_control():
+    existing = "You are an AI assistant."
+    new = [
+        {
+            "type": "text",
+            "text": "Analyze this text:",
+            "cache_control": {"type": "ephemeral"},
+        },
+        {"type": "text", "text": ""},
+    ]
+    result = combine_system_messages(existing, new)
+    expected = [
+        {"type": "text", "text": "You are an AI assistant."},
+        {
+            "type": "text",
+            "text": "Analyze this text:",
+            "cache_control": {"type": "ephemeral"},
+        },
+        {"type": "text", "text": ""},
+    ]
+    assert result == expected
+
+
+def test_extract_system_messages_with_cache_control():
+    messages = [
+        {"role": "system", "content": "You are an AI assistant."},
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "Analyze this text:",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        {"role": "user", "content": "User message"},
+        {"role": "system", "content": ""},
+    ]
+    result = extract_system_messages(messages)
+    expected = [
+        {"type": "text", "text": "You are an AI assistant."},
+        {
+            "type": "text",
+            "text": "Analyze this text:",
+            "cache_control": {"type": "ephemeral"},
+        },
+        {"type": "text", "text": ""},
+    ]
+    assert result == expected
+
+
+def test_combine_system_messages_preserve_cache_control():
+    existing = [
+        {
+            "type": "text",
+            "text": "You are an AI assistant.",
+        },
+        {
+            "type": "text",
+            "text": "This is some context.",
+            "cache_control": {"type": "ephemeral"},
+        },
+    ]
+    new = [
+        {
+            "type": "text",
+            "text": "Additional instruction.",
+            "cache_control": {"type": "ephemeral"},
+        }
+    ]
+    result = combine_system_messages(existing, new)
+    expected = [
+        {
+            "type": "text",
+            "text": "You are an AI assistant.",
+        },
+        {
+            "type": "text",
+            "text": "This is some context.",
+            "cache_control": {"type": "ephemeral"},
+        },
+        {
+            "type": "text",
+            "text": "Additional instruction.",
+            "cache_control": {"type": "ephemeral"},
+        },
+    ]
+    assert result == expected

From a54ea2d01beac03fa0128112198565f55eac3f75 Mon Sep 17 00:00:00 2001
From: Ivan Leo
Date: Tue, 29 Oct 2024 16:12:44 +0800
Subject: [PATCH 2/2] fix: updated scope of pytest fixture to be function based (#1131)

---
 tests/llm/test_openai/conftest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/llm/test_openai/conftest.py b/tests/llm/test_openai/conftest.py
index 280da5011..583de7f43 100644
--- a/tests/llm/test_openai/conftest.py
+++ b/tests/llm/test_openai/conftest.py
@@ -13,7 +13,7 @@ def wrap_openai(x):
         return x


-@pytest.fixture(scope="session")
+@pytest.fixture(scope="function")
 def client():
     if os.environ.get("BRAINTRUST_API_KEY"):
         yield wrap_openai(