diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index eaecfbe65..2c81358ca 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -31,13 +31,7 @@ jobs:
python3 -m pip install -r requirements.txt
python3 -m pip install -r requirements-doc.txt
- name: Run Continuous Integration Action
- run: |
- set -e -o pipefail
- export CUSTOM_PACKAGES="${{ env.CUSTOM_PACKAGES }}" &&
- export CUSTOM_FLAGS="${{ env.CUSTOM_FLAGS }}" &&
- curl -sSL https://raw.githubusercontent.com/gao-hongnan/omniverse/2fd5de1b8103e955cd5f022ab016b72fa901fa8f/scripts/devops/continuous-integration/lint_ruff.sh -o lint_ruff.sh
- chmod +x lint_ruff.sh
- bash lint_ruff.sh | tee ${{ env.WORKING_DIRECTORY }}/${{ env.RUFF_OUTPUT_FILENAME }}
+ uses: astral-sh/ruff-action@v1
- name: Upload Artifacts
uses: actions/upload-artifact@v3
with:
diff --git a/docs/hooks/hide_lines.py b/docs/hooks/hide_lines.py
index 317f00605..7e346903f 100644
--- a/docs/hooks/hide_lines.py
+++ b/docs/hooks/hide_lines.py
@@ -5,9 +5,9 @@
@mkdocs.plugins.event_priority(0)
# pylint: disable=unused-argument
-def on_startup(command: str, dirty: bool) -> None:
+def on_startup(command: str, dirty: bool) -> None: # noqa: ARG001
"""Monkey patch Highlight extension to hide lines in code blocks."""
- original = highlight.Highlight.highlight
+ original = highlight.Highlight.highlight # type: ignore
def patched(self: Any, src: str, *args: Any, **kwargs: Any) -> Any:
lines = src.splitlines(keepends=True)
diff --git a/docs/tutorials/7-synthetic-data-generation.ipynb b/docs/tutorials/7-synthetic-data-generation.ipynb
index 74bbc0607..3bbdfce1f 100644
--- a/docs/tutorials/7-synthetic-data-generation.ipynb
+++ b/docs/tutorials/7-synthetic-data-generation.ipynb
@@ -565,7 +565,7 @@
"\n",
"for metric,size in product(METRICS,SIZES):\n",
" metric_name, score_fn = metric\n",
- " score_fns[f\"{metric_name}@{size}\"] = lambda predictions,labels : score_fn(predictions[:size],labels)"
+ " score_fns[f\"{metric_name}@{size}\"] = lambda predictions,labels, fn=score_fn, k=size: fn(predictions[:k],labels) # type: ignore"
]
},
{
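The notebook change above fixes Python's late-binding closure pitfall: every lambda built in the loop would otherwise read `score_fn` and `size` at call time and see only their final loop values. A minimal sketch of the pitfall and the default-argument fix (the names below are illustrative, not from the notebook):

```python
sizes = [1, 5, 10]

# Late binding: each lambda reads `size` when it is *called*, so every
# entry slices with the final loop value, 10.
broken = {size: (lambda preds: preds[:size]) for size in sizes}
assert broken[1](list(range(20))) == list(range(10))  # not [0]!

# Binding the loop variable as a default argument freezes its value at
# definition time, which is exactly what the notebook edit does.
fixed = {size: (lambda preds, k=size: preds[:k]) for size in sizes}
assert fixed[1](list(range(20))) == [0]
```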
diff --git a/instructor/client.py b/instructor/client.py
index fa452c884..d50659112 100644
--- a/instructor/client.py
+++ b/instructor/client.py
@@ -507,12 +507,7 @@ def from_openai(
instructor.Mode.MD_JSON,
}
- if provider in {Provider.DATABRICKS}:
- assert mode in {
- instructor.Mode.MD_JSON
- }, "Databricks provider only supports `MD_JSON` mode."
-
- if provider in {Provider.OPENAI}:
+ if provider in {Provider.OPENAI, Provider.DATABRICKS}:
assert mode in {
instructor.Mode.TOOLS,
instructor.Mode.JSON,
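With the two checks collapsed into one branch, Databricks' OpenAI-compatible endpoint is validated against the same mode set as OpenAI rather than being pinned to `MD_JSON`. A hedged sketch of what this enables (the env var names are placeholders, not part of this diff; provider detection is assumed to come from the client's base URL):

```python
import os
import instructor
from openai import OpenAI

# Databricks serves an OpenAI-compatible API, so the relaxed assertion
# now lets it use tool calling instead of markdown-wrapped JSON only.
client = instructor.from_openai(
    OpenAI(
        base_url=os.environ["DATABRICKS_BASE_URL"],  # placeholder env var
        api_key=os.environ["DATABRICKS_TOKEN"],      # placeholder env var
    ),
    mode=instructor.Mode.TOOLS,  # previously rejected for Databricks
)
```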
diff --git a/instructor/process_response.py b/instructor/process_response.py
index 3dd876f34..277151711 100644
--- a/instructor/process_response.py
+++ b/instructor/process_response.py
@@ -20,7 +20,11 @@
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
from instructor.function_calls import OpenAISchema, openai_schema
-from instructor.utils import merge_consecutive_messages
+from instructor.utils import (
+ merge_consecutive_messages,
+ extract_system_messages,
+ combine_system_messages,
+)
from instructor.multimodal import convert_messages
logger = logging.getLogger("instructor")
@@ -332,20 +336,15 @@ def handle_anthropic_tools(
"name": response_model.__name__,
}
- system_messages = [
- m["content"] for m in new_kwargs["messages"] if m["role"] == "system"
- ]
+ system_messages = extract_system_messages(new_kwargs.get("messages", []))
- if "system" in new_kwargs and system_messages:
- raise ValueError(
- "Only a single system message is supported - either set it as a message in the messages array or use the system parameter"
+ if system_messages:
+ new_kwargs["system"] = combine_system_messages(
+ new_kwargs.get("system"), system_messages
)
- if "system" not in new_kwargs:
- new_kwargs["system"] = "\n\n".join(system_messages)
-
new_kwargs["messages"] = [
- m for m in new_kwargs["messages"] if m["role"] != "system"
+ m for m in new_kwargs.get("messages", []) if m["role"] != "system"
]
return response_model, new_kwargs
@@ -354,25 +353,18 @@ def handle_anthropic_tools(
def handle_anthropic_json(
response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
- openai_system_messages = [
- message["content"]
- for message in new_kwargs.get("messages", [])
- if message["role"] == "system"
- ]
+ system_messages = extract_system_messages(new_kwargs.get("messages", []))
- if "system" in new_kwargs and openai_system_messages:
- raise ValueError(
- "Only a single System message is supported - either set it using the system parameter or in the list of messages"
+ if system_messages:
+ new_kwargs["system"] = combine_system_messages(
+ new_kwargs.get("system"), system_messages
)
- if not "system" in new_kwargs:
- new_kwargs["system"] = "\n\n".join(openai_system_messages)
-
new_kwargs["messages"] = [
- m for m in new_kwargs["messages"] if m["role"] != "system"
+ m for m in new_kwargs.get("messages", []) if m["role"] != "system"
]
- message = dedent(
+ json_schema_message = dedent(
f"""
As a genius expert, your task is to understand the content and provide
the parsed objects in json that match the following json_schema:\n
@@ -383,7 +375,9 @@ def handle_anthropic_json(
"""
)
- new_kwargs["system"] = f"{new_kwargs.get('system', '')}\n\n{message}".strip()
+ new_kwargs["system"] = combine_system_messages(
+ new_kwargs.get("system"), [{"type": "text", "text": json_schema_message}]
+ )
return response_model, new_kwargs
@@ -664,7 +658,7 @@ def handle_response_model(
# This is because cohere uses 'message' and 'chat_history' instead of 'messages'
return handle_cohere_modes(new_kwargs)
# Handle images without a response model
- if autodetect_images and "messages" in new_kwargs:
+ if "messages" in new_kwargs:
messages = convert_messages(
new_kwargs["messages"],
mode,
@@ -672,14 +666,11 @@ def handle_response_model(
)
if mode in {Mode.ANTHROPIC_JSON, Mode.ANTHROPIC_TOOLS}:
# Handle OpenAI style or Anthropic style messages
- new_kwargs["messages"] = [
- m for m in messages if m["role"] != "system"
- ]
+ new_kwargs["messages"] = [m for m in messages if m["role"] != "system"]
if "system" not in new_kwargs:
- system_messages = (m for m in messages if m["role"] == "system")
- system_message = next(system_messages, None)
+ system_message = extract_system_messages(messages)
if system_message:
- new_kwargs["system"] = system_message["content"]
+ new_kwargs["system"] = system_message
else:
new_kwargs["messages"] = messages
return None, new_kwargs
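Both Anthropic handlers above now merge, rather than reject, a `system=` kwarg that coexists with `role: "system"` entries in `messages`. A minimal sketch of the resulting kwargs transformation, assuming the helpers behave as defined in `instructor/utils.py` below:

```python
from instructor.utils import combine_system_messages, extract_system_messages

new_kwargs = {
    "system": "You are terse.",
    "messages": [
        {"role": "system", "content": "Answer in JSON."},
        {"role": "user", "content": "Hi"},
    ],
}

# Mirror the handler: pull system entries out of messages, fold them
# into the system kwarg, and strip them from the message list.
system_messages = extract_system_messages(new_kwargs["messages"])
new_kwargs["system"] = combine_system_messages(new_kwargs.get("system"), system_messages)
new_kwargs["messages"] = [m for m in new_kwargs["messages"] if m["role"] != "system"]

# The string and the extracted block are combined into one block list:
assert new_kwargs["system"] == [
    {"type": "text", "text": "You are terse."},
    {"type": "text", "text": "Answer in JSON."},
]
assert new_kwargs["messages"] == [{"role": "user", "content": "Hi"}]
```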
diff --git a/instructor/utils.py b/instructor/utils.py
index e9cacac25..6bee2b1c6 100644
--- a/instructor/utils.py
+++ b/instructor/utils.py
@@ -10,7 +10,10 @@
Callable,
Generic,
Protocol,
+ Union,
+ TypedDict,
TypeVar,
+ cast,
)
from pydantic import BaseModel
import os
@@ -131,7 +134,6 @@ def update_total_usage(
response: T_Model | None,
total_usage: OpenAIUsage | AnthropicUsage,
) -> T_Model | ChatCompletion | None:
-
if response is None:
return None
@@ -369,3 +371,52 @@ def update_gemini_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
def disable_pydantic_error_url():
os.environ["PYDANTIC_ERRORS_INCLUDE_URL"] = "0"
+
+
+class SystemMessage(TypedDict, total=False):
+ type: str
+ text: str
+ cache_control: dict[str, str]
+
+
+def combine_system_messages(
+ existing_system: Union[str, list[SystemMessage], None], # noqa: UP007
+ new_system: Union[str, list[SystemMessage]], # noqa: UP007
+) -> Union[str, list[SystemMessage]]: # noqa: UP007
+ if existing_system is None:
+ return new_system
+
+ if isinstance(existing_system, str) and isinstance(new_system, str):
+ return f"{existing_system}\n\n{new_system}"
+
+ if isinstance(existing_system, list) and isinstance(new_system, list):
+ return existing_system + new_system
+
+ if isinstance(existing_system, str) and isinstance(new_system, list):
+ return [SystemMessage(type="text", text=existing_system)] + new_system
+
+ if isinstance(existing_system, list) and isinstance(new_system, str):
+ return existing_system + [SystemMessage(type="text", text=new_system)]
+
+ raise ValueError("Unsupported system message type combination")
+
+
+def extract_system_messages(messages: list[dict[str, Any]]) -> list[SystemMessage]:
+ def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage: # noqa: UP007
+ if isinstance(content, str):
+ return SystemMessage(type="text", text=content)
+ elif isinstance(content, dict):
+ return SystemMessage(**content)
+ else:
+ raise ValueError(f"Unsupported content type: {type(content)}")
+
+ result: list[SystemMessage] = []
+ for m in messages:
+ if m["role"] == "system":
+ # System message content must always be a string or a list of dictionaries
+ content = cast(Union[str, list[dict[str, Any]]], m["content"]) # noqa: UP007
+ if isinstance(content, list):
+ result.extend(convert_message(item) for item in content)
+ else:
+ result.append(convert_message(content))
+ return result
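A quick sanity check of the new helpers, mirroring the prompt-caching shape exercised by the Anthropic tests further down: appending a plain string wraps it as a text block and leaves existing `cache_control` annotations untouched.

```python
from instructor.utils import combine_system_messages

existing = [
    {"type": "text", "text": "Long cached context...", "cache_control": {"type": "ephemeral"}},
]
# The plain string is promoted to a text block and appended; the
# cache_control annotation on the existing block is preserved.
combined = combine_system_messages(existing, "Extract the user.")
assert combined == [
    {"type": "text", "text": "Long cached context...", "cache_control": {"type": "ephemeral"}},
    {"type": "text", "text": "Extract the user."},
]
```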
diff --git a/make_desc.py b/make_desc.py
index 9a4a896c1..6dc9211aa 100644
--- a/make_desc.py
+++ b/make_desc.py
@@ -1,5 +1,5 @@
import os
-from typing import Optional, List, Set, Literal
+from typing import Optional, Literal
import asyncio
from openai import AsyncOpenAI
import typer
@@ -15,7 +15,7 @@
async def generate_ai_frontmatter(
- client: AsyncOpenAI, title: str, content: str, categories: List[str]
+ client: AsyncOpenAI, title: str, content: str, categories: list[str]
):
"""
Generate a description and categories for the given content using AI.
@@ -35,8 +35,8 @@ class DescriptionAndCategories(BaseModel):
reasoning: str = Field(
..., description="The reasoning for the correct categories"
)
- tags: List[str]
- categories: List[
+ tags: list[str]
+ categories: list[
Literal[
"OpenAI",
"Anthropic",
@@ -72,7 +72,7 @@ class DescriptionAndCategories(BaseModel):
return response
-def get_all_categories(root_dir: str) -> Set[str]:
+def get_all_categories(root_dir: str) -> set[str]:
"""
Read all markdown files and extract unique categories.
@@ -113,7 +113,7 @@ def preview_categories(root_dir: str) -> None:
async def process_file(
- client: AsyncOpenAI, file_path: str, categories: List[str], enable_comments: bool
+ client: AsyncOpenAI, file_path: str, categories: list[str], enable_comments: bool
) -> None:
"""
Process a single file, adding or updating the description and categories in the front matter.
@@ -143,7 +143,7 @@ async def process_file(
async def process_files(
root_dir: str,
- api_key: Optional[str] = None,
+ api_key: Optional[str] = None, # noqa: ARG001
use_categories: bool = False,
enable_comments: bool = False,
) -> None:
diff --git a/tests/llm/test_anthropic/evals/test_simple.py b/tests/llm/test_anthropic/evals/test_simple.py
index 2ebdadfd9..62134c739 100644
--- a/tests/llm/test_anthropic/evals/test_simple.py
+++ b/tests/llm/test_anthropic/evals/test_simple.py
@@ -20,18 +20,19 @@ class User(BaseModel):
@field_validator("name")
def name_is_uppercase(cls, v: str):
- assert v.isupper(), "Name must be uppercase, please fix"
+ assert v.isupper(), f"{v} is not an uppercase string. Note that all characters in {v} must be uppercase (e.g. TIM, SARAH, ADAM)."
return v
resp = client.messages.create(
model="claude-3-haiku-20240307",
- max_tokens=1024,
+ max_tokens=4096,
max_retries=2,
+ system="Make sure to follow the instructions carefully and return a response object that matches the json schema requested. Age is an integer.",
messages=[
{
"role": "user",
"content": "Extract John is 18 years old.",
- }
+ },
],
response_model=User,
) # type: ignore
@@ -53,7 +54,7 @@ class User(BaseModel):
resp = client.messages.create(
model="claude-3-haiku-20240307",
- max_tokens=1024,
+ max_tokens=4096,
max_retries=0,
messages=[
{
@@ -83,6 +84,7 @@ class User(BaseModel):
model="claude-3-haiku-20240307",
max_tokens=1024,
max_retries=0,
+ system="Make sure to follow the instructions carefully and return a response object that matches the json schema requested. Family members here is just asking for a list of names",
messages=[
{
"role": "user",
@@ -132,7 +134,7 @@ class User(BaseModel):
resp = client.messages.create(
model="claude-3-haiku-20240307",
- max_tokens=1024,
+ max_tokens=4096,
max_retries=2,
messages=[
{
@@ -185,7 +187,10 @@ class User(BaseModel):
max_tokens=1024,
max_retries=0,
messages=[
- {"role": "system", "content": "EVERYTHING MUST BE IN ALL CAPS"},
+ {
+ "role": "system",
+ "content": "Please make sure to follow the instructions carefully and return a valid response object. All strings must be fully capitalised in all caps. (Eg. THIS IS AN UPPERCASE STRING) and age is an integer.",
+ },
{
"role": "user",
"content": "Create a user for a model with a name and age.",
diff --git a/tests/llm/test_anthropic/test_system.py b/tests/llm/test_anthropic/test_system.py
new file mode 100644
index 000000000..eb3158902
--- /dev/null
+++ b/tests/llm/test_anthropic/test_system.py
@@ -0,0 +1,144 @@
+import pytest
+import instructor
+from pydantic import BaseModel
+from itertools import product
+from .util import models, modes
+from anthropic.types.message import Message
+
+
+class User(BaseModel):
+ name: str
+ age: int
+
+
+@pytest.mark.parametrize("model, mode", product(models, modes))
+def test_creation(model, mode, client):
+ client = instructor.from_anthropic(client, mode=mode)
+ response = client.chat.completions.create(
+ model=model,
+ response_model=User,
+ messages=[
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "Mike is 37 years old"}
+ ],
+ },
+ {
+ "role": "user",
+ "content": "Extract a user from the story.",
+ },
+ ],
+ temperature=1,
+ max_tokens=1000,
+ )
+
+ # Assertions to validate the response
+ assert isinstance(response, User)
+ assert response.name == "Mike"
+ assert response.age == 37
+
+
+@pytest.mark.parametrize("model, mode", product(models, modes))
+def test_creation_with_system_cache(model, mode, client):
+ client = instructor.from_anthropic(client, mode=mode, enable_prompt_caching=True)
+ response, message = client.chat.completions.create_with_completion(
+ model=model,
+ response_model=User,
+ messages=[
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "Mike is 37 years old " * 200 + "",
+ "cache_control": {"type": "ephemeral"},
+ },
+ {
+ "type": "text",
+ "text": "You are a helpful assistant who extracts users from stories.",
+ },
+ ],
+ },
+ {
+ "role": "user",
+ "content": "Extract a user from the story.",
+ },
+ ],
+ temperature=1,
+ max_tokens=1000,
+ )
+
+ # Assertions to validate the response
+ assert isinstance(response, User)
+ assert response.name == "Mike"
+ assert response.age == 37
+
+ # Assert a cache write or cache hit
+ assert (
+ message.usage.cache_creation_input_tokens > 0
+ or message.usage.cache_read_input_tokens > 0
+ )
+
+
+@pytest.mark.parametrize("model, mode", product(models, modes))
+def test_creation_with_system_cache_anthropic_style(model, mode, client):
+ client = instructor.from_anthropic(client, mode=mode, enable_prompt_caching=True)
+ response, message = client.chat.completions.create_with_completion(
+ model=model,
+ system=[
+ {
+ "type": "text",
+ "text": "Mike is 37 years old " * 200 + "",
+ "cache_control": {"type": "ephemeral"},
+ },
+ {
+ "type": "text",
+ "text": "You are a helpful assistant who extracts users from stories.",
+ },
+ ],
+ response_model=User,
+ messages=[
+ {
+ "role": "user",
+ "content": "Extract a user from the story.",
+ },
+ ],
+ temperature=1,
+ max_tokens=1000,
+ )
+
+ # Assertions to validate the response
+ assert isinstance(response, User)
+ assert response.name == "Mike"
+ assert response.age == 37
+
+ # Assert a cache write or cache hit
+ assert (
+ message.usage.cache_creation_input_tokens > 0
+ or message.usage.cache_read_input_tokens > 0
+ )
+
+
+@pytest.mark.parametrize("model, mode", product(models, modes))
+def test_creation_no_response_model(model, mode, client):
+ client = instructor.from_anthropic(client, mode=mode)
+ response = client.chat.completions.create(
+ response_model=None,
+ model=model,
+ messages=[
+ {
+ "role": "system",
+ "content": [{"type": "text", "text": "Mike is 37 years old"}],
+ },
+ {
+ "role": "user",
+ "content": "Extract a user from the story.",
+ },
+ ],
+ temperature=1,
+ max_tokens=1000,
+ )
+
+ # Assertions to validate the response
+ assert isinstance(response, Message)
diff --git a/tests/llm/test_openai/conftest.py b/tests/llm/test_openai/conftest.py
index 280da5011..583de7f43 100644
--- a/tests/llm/test_openai/conftest.py
+++ b/tests/llm/test_openai/conftest.py
@@ -13,7 +13,7 @@ def wrap_openai(x):
return x
-@pytest.fixture(scope="session")
+@pytest.fixture(scope="function")
def client():
if os.environ.get("BRAINTRUST_API_KEY"):
yield wrap_openai(
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9dead016b..6fa7ff36f 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -6,6 +6,8 @@
extract_json_from_stream,
extract_json_from_stream_async,
merge_consecutive_messages,
+ extract_system_messages,
+ combine_system_messages,
)
@@ -185,3 +187,202 @@ def my_property(cls):
return cls.clvar
assert MyClass.my_property == 1
+
+
+def test_combine_system_messages_string_string():
+ existing = "Existing message"
+ new = "New message"
+ result = combine_system_messages(existing, new)
+ assert result == "Existing message\n\nNew message"
+
+
+def test_combine_system_messages_list_list():
+ existing = [{"type": "text", "text": "Existing"}]
+ new = [{"type": "text", "text": "New"}]
+ result = combine_system_messages(existing, new)
+ assert result == [
+ {"type": "text", "text": "Existing"},
+ {"type": "text", "text": "New"},
+ ]
+
+
+def test_combine_system_messages_string_list():
+ existing = "Existing"
+ new = [{"type": "text", "text": "New"}]
+ result = combine_system_messages(existing, new)
+ assert result == [
+ {"type": "text", "text": "Existing"},
+ {"type": "text", "text": "New"},
+ ]
+
+
+def test_combine_system_messages_list_string():
+ existing = [{"type": "text", "text": "Existing"}]
+ new = "New"
+ result = combine_system_messages(existing, new)
+ assert result == [
+ {"type": "text", "text": "Existing"},
+ {"type": "text", "text": "New"},
+ ]
+
+
+def test_combine_system_messages_none_string():
+ existing = None
+ new = "New"
+ result = combine_system_messages(existing, new)
+ assert result == "New"
+
+
+def test_combine_system_messages_none_list():
+ existing = None
+ new = [{"type": "text", "text": "New"}]
+ result = combine_system_messages(existing, new)
+ assert result == [{"type": "text", "text": "New"}]
+
+
+def test_combine_system_messages_invalid_type():
+ with pytest.raises(ValueError):
+ combine_system_messages(123, "New")
+
+
+def test_extract_system_messages():
+ messages = [
+ {"role": "system", "content": "System message 1"},
+ {"role": "user", "content": "User message"},
+ {"role": "system", "content": "System message 2"},
+ ]
+ result = extract_system_messages(messages)
+ expected = [
+ {"type": "text", "text": "System message 1"},
+ {"type": "text", "text": "System message 2"},
+ ]
+ assert result == expected
+
+
+def test_extract_system_messages_no_system():
+ messages = [
+ {"role": "user", "content": "User message"},
+ {"role": "assistant", "content": "Assistant message"},
+ ]
+ result = extract_system_messages(messages)
+ assert result == []
+
+
+def test_combine_system_messages_with_cache_control():
+ existing = [
+ {
+ "type": "text",
+ "text": "You are an AI assistant.",
+ },
+ {
+ "type": "text",
+ "text": "This is some context.",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ]
+ new = "Provide insightful analysis."
+ result = combine_system_messages(existing, new)
+ expected = [
+ {
+ "type": "text",
+ "text": "You are an AI assistant.",
+ },
+ {
+ "type": "text",
+ "text": "This is some context.",
+ "cache_control": {"type": "ephemeral"},
+ },
+ {"type": "text", "text": "Provide insightful analysis."},
+ ]
+ assert result == expected
+
+
+def test_combine_system_messages_string_to_cache_control():
+ existing = "You are an AI assistant."
+ new = [
+ {
+ "type": "text",
+ "text": "Analyze this text:",
+ "cache_control": {"type": "ephemeral"},
+ },
+ {"type": "text", "text": ""},
+ ]
+ result = combine_system_messages(existing, new)
+ expected = [
+ {"type": "text", "text": "You are an AI assistant."},
+ {
+ "type": "text",
+ "text": "Analyze this text:",
+ "cache_control": {"type": "ephemeral"},
+ },
+ {"type": "text", "text": ""},
+ ]
+ assert result == expected
+
+
+def test_extract_system_messages_with_cache_control():
+ messages = [
+ {"role": "system", "content": "You are an AI assistant."},
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "Analyze this text:",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ {"role": "user", "content": "User message"},
+ {"role": "system", "content": ""},
+ ]
+ result = extract_system_messages(messages)
+ expected = [
+ {"type": "text", "text": "You are an AI assistant."},
+ {
+ "type": "text",
+ "text": "Analyze this text:",
+ "cache_control": {"type": "ephemeral"},
+ },
+ {"type": "text", "text": ""},
+ ]
+ assert result == expected
+
+
+def test_combine_system_messages_preserve_cache_control():
+ existing = [
+ {
+ "type": "text",
+ "text": "You are an AI assistant.",
+ },
+ {
+ "type": "text",
+ "text": "This is some context.",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ]
+ new = [
+ {
+ "type": "text",
+ "text": "Additional instruction.",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ]
+ result = combine_system_messages(existing, new)
+ expected = [
+ {
+ "type": "text",
+ "text": "You are an AI assistant.",
+ },
+ {
+ "type": "text",
+ "text": "This is some context.",
+ "cache_control": {"type": "ephemeral"},
+ },
+ {
+ "type": "text",
+ "text": "Additional instruction.",
+ "cache_control": {"type": "ephemeral"},
+ },
+ ]
+ assert result == expected