add test cases for json structured output
raspawar committed Oct 25, 2024
1 parent 52fefee · commit 1652326
Showing 1 changed file with 144 additions and 0 deletions.
libs/ai-endpoints/tests/integration_tests/test_bind_tools.py
@@ -13,6 +13,7 @@
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from pydantic import BaseModel as BaseModelProper
from pydantic import Field

from langchain_nvidia_ai_endpoints import ChatNVIDIA
@@ -145,6 +146,26 @@ def check_response_structure(response: AIMessage) -> None:
    assert len(response.tool_calls) > 0


class Joke(BaseModelProper):
    """Joke to tell user."""

    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")


class SelfEvaluation(BaseModelProper):
    score: int
    text: str


class JokeWithEvaluation(BaseModelProper):
    """Joke to tell user."""

    setup: str
    punchline: str
    self_evaluation: SelfEvaluation
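
# Joke exercises the flat-schema case; JokeWithEvaluation nests
# SelfEvaluation so the json_schema tests below also cover nested objects.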


@pytest.mark.parametrize(
    "func",
    [eval_invoke, eval_stream],
@@ -820,3 +841,126 @@ async def test_json_mode_async(tool_model: str) -> None:
    assert isinstance(full, AIMessageChunk)
    assert isinstance(full.content, str)
    assert json.loads(full.content) == {"a": 1}
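
# Note: json mode (above) yields a raw JSON string for the caller to parse;
# the structured-output tests below instead validate parsed objects against
# a schema.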


@pytest.mark.parametrize(
    ("method", "strict"),
    [("function_calling", True), ("json_schema", None)],
)
def test_structured_output_json_strict(
    tool_model: str,
    method: Literal["function_calling", "json_schema"],
    strict: Optional[bool],
) -> None:
    """Test to verify structured output via function_calling (strict=True)
    and json_schema."""

    llm = ChatNVIDIA(model=tool_model, temperature=0)

    # Pydantic class
    # Joke is a pydantic 2 model (BaseModelProper), while the interface only
    # officially supports pydantic 1 / pydantic.v1.BaseModel; the type
    # signatures still need an updating pass.
    chat = llm.with_structured_output(Joke, method=method, strict=strict)
    result = chat.invoke("Tell me a joke about cats.")
    assert isinstance(result, Joke)

    for chunk in chat.stream("Tell me a joke about cats."):
        assert isinstance(chunk, Joke)

    # Schema
    chat = llm.with_structured_output(
        Joke.model_json_schema(), method=method, strict=strict
    )
    result = chat.invoke("Tell me a joke about cats.")
    assert isinstance(result, dict)
    assert set(result.keys()) == {"setup", "punchline"}

    for chunk in chat.stream("Tell me a joke about cats."):
        assert isinstance(chunk, dict)  # for mypy
        assert set(chunk.keys()) == {"setup", "punchline"}


@pytest.mark.parametrize(("method", "strict"), [("json_schema", None)])
def test_nested_structured_output_json_strict(
    tool_model: str, method: Literal["json_schema"], strict: Optional[bool]
) -> None:
    """Test to verify structured output for a nested object via json_schema."""

    llm = ChatNVIDIA(model=tool_model, temperature=0)

    # Schema
    chat = llm.with_structured_output(
        JokeWithEvaluation.model_json_schema(), method=method, strict=strict
    )
    result = chat.invoke("Tell me a joke about cats.")
    assert isinstance(result, dict)
    assert set(result.keys()) == {"setup", "punchline", "self_evaluation"}
    assert set(result["self_evaluation"].keys()) == {"score", "text"}

    for chunk in chat.stream("Tell me a joke about cats."):
        assert isinstance(chunk, dict)  # for mypy
        assert set(chunk.keys()) == {"setup", "punchline", "self_evaluation"}
        assert set(chunk["self_evaluation"].keys()) == {"score", "text"}
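
# The streaming assertions above expect every chunk to carry the complete
# key set, i.e. the stream yields whole objects rather than partial deltas.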


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("method", "strict"),
    [("function_calling", True), ("json_schema", None)],
)
async def test_structured_output_json_strict_async(
    tool_model: str,
    method: Literal["function_calling", "json_schema"],
    strict: Optional[bool],
) -> None:
    """Test to verify structured output via function_calling (strict=True)
    and json_schema (async)."""

    llm = ChatNVIDIA(model=tool_model, temperature=0)

    # Pydantic class
    chat = llm.with_structured_output(Joke, method=method, strict=strict)
    result = await chat.ainvoke("Tell me a joke about cats.")
    assert isinstance(result, Joke)

    async for chunk in chat.astream("Tell me a joke about cats."):
        assert isinstance(chunk, Joke)

    # Schema
    chat = llm.with_structured_output(
        Joke.model_json_schema(), method=method, strict=strict
    )
    result = await chat.ainvoke("Tell me a joke about cats.")
    assert isinstance(result, dict)
    assert set(result.keys()) == {"setup", "punchline"}

    async for chunk in chat.astream("Tell me a joke about cats."):
        assert isinstance(chunk, dict)  # for mypy
        assert set(chunk.keys()) == {"setup", "punchline"}


@pytest.mark.asyncio
@pytest.mark.parametrize(("method", "strict"), [("json_schema", None)])
async def test_nested_structured_output_json_strict_async(
    tool_model: str, method: Literal["json_schema"], strict: Optional[bool]
) -> None:
    """Test to verify structured output for a nested object via json_schema
    (async)."""

    llm = ChatNVIDIA(model=tool_model, temperature=0)

    # Schema
    chat = llm.with_structured_output(
        JokeWithEvaluation.model_json_schema(), method=method, strict=strict
    )
    result = await chat.ainvoke("Tell me a joke about cats.")
    assert isinstance(result, dict)
    assert set(result.keys()) == {"setup", "punchline", "self_evaluation"}
    assert set(result["self_evaluation"].keys()) == {"score", "text"}

    async for chunk in chat.astream("Tell me a joke about cats."):
        assert isinstance(chunk, dict)  # for mypy
        assert set(chunk.keys()) == {"setup", "punchline", "self_evaluation"}
        assert set(chunk["self_evaluation"].keys()) == {"score", "text"}
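
For reference, a minimal sketch of the API exercised by these tests, outside
the test harness. The model id below is a placeholder; any tool-capable model
served by the NVIDIA endpoints should work:

from langchain_nvidia_ai_endpoints import ChatNVIDIA
from pydantic import BaseModel, Field


class Joke(BaseModel):
    """Joke to tell user."""

    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")


# Placeholder model id; substitute one available on your endpoint.
llm = ChatNVIDIA(model="meta/llama-3.1-70b-instruct", temperature=0)
structured = llm.with_structured_output(Joke, method="json_schema")
joke = structured.invoke("Tell me a joke about cats.")
print(joke.setup)
print(joke.punchline)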
