Skip to content

Commit

Permalink
handle method parameter for json
Browse files Browse the repository at this point in the history
  • Loading branch information
raspawar committed Nov 13, 2024
1 parent b115909 commit 2305f00
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 66 deletions.
177 changes: 118 additions & 59 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from pydantic import BaseModel, Field, PrivateAttr

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
Expand Down Expand Up @@ -237,6 +240,28 @@ def _process_for_vlm(
return inputs, extra_headers


def _convert_to_openai_response_format(
schema: Union[Dict[str, Any], Type],
) -> Union[Dict, TypeBaseModel]:
if isinstance(schema, type) and is_basemodel_subclass(schema):
return schema

if (
isinstance(schema, dict)
and "json_schema" in schema
and schema.get("type") == "json_schema"
):
response_format = schema
elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
response_format = {"type": "json_schema", "json_schema": schema}
else:
function = convert_to_openai_function(schema)
function["schema"] = function.pop("parameters")
response_format = {"type": "json_schema", "json_schema": function}

return response_format


_DEFAULT_MODEL_NAME: str = "meta/llama3-8b-instruct"


Expand Down Expand Up @@ -647,8 +672,11 @@ def bind_functions(
# as a result need to type ignore for the schema parameter and return type.
def with_structured_output( # type: ignore
self,
schema: Union[Dict, Type],
schema: Optional[Union[Dict, Type]] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
Expand Down Expand Up @@ -789,6 +817,8 @@ class Choices(enum.Enum):
"being None when the LLM produces an incomplete response."
)

is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)

# check if the model supports structured output, warn if it does not
known_good = False
# todo: we need to store model: Model in this class
Expand All @@ -807,65 +837,94 @@ class Choices(enum.Enum):
f"Model '{self.model}' is not known to support structured output. "
"Your output may fail at inference time."
)

if isinstance(schema, dict):
output_parser: BaseOutputParser = JsonOutputParser()
nvext_param: Dict[str, Any] = {"guided_json": schema}

elif issubclass(schema, enum.Enum):
# langchain's EnumOutputParser is not in langchain_core
# and doesn't support streaming. this is a simple implementation
# that supports streaming with our semantics of returning None
# if no complete object can be constructed.
class EnumOutputParser(BaseOutputParser):
enum: Type[enum.Enum]

def parse(self, response: str) -> Any:
try:
return self.enum(response.strip())
except ValueError:
pass
return None

# guided_choice only supports string choices
choices = [choice.value for choice in schema]
if not all(isinstance(choice, str) for choice in choices):
# instead of erroring out we could coerce the enum values to
# strings, but would then need to coerce them back to their
# original type for Enum construction.
output_parser: BaseOutputParser

if method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
elif method == "json_schema":
if schema is None:
raise ValueError(
"Enum schema must only contain string choices. "
"Use StrEnum or ensure all member values are strings."
"schema must be specified when method is not 'json_mode'. "
"Received None."
)
output_parser = EnumOutputParser(enum=schema)
nvext_param = {"guided_choice": choices}

elif is_basemodel_subclass(schema):
# PydanticOutputParser does not support streaming. what we do
# instead is ignore all inputs that are incomplete wrt the
# underlying Pydantic schema. if the entire input is invalid,
# we return None.
class ForgivingPydanticOutputParser(PydanticOutputParser):
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Any:
try:
return super().parse_result(result, partial=partial)
except OutputParserException:
pass
return None

output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
if hasattr(schema, "model_json_schema"):
json_schema = schema.model_json_schema()
else:
json_schema = schema.schema()
nvext_param = {"guided_json": json_schema}

response_format = _convert_to_openai_response_format(schema)
llm = self.bind(response_format=response_format)
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
elif method == "function_calling":
pass
else:
raise ValueError(
"Schema must be a Pydantic object, a dictionary "
"representing a JSON schema, or an Enum."
f"Unrecognized method argument. Expected one of 'json_scheme' or "
f"'json_mode'. Received: '{method}'"
)

return super().bind(nvext=nvext_param) | output_parser
if schema:
if isinstance(schema, dict):
output_parser = JsonOutputParser()
llm = self.bind(nvext={"guided_json": schema})
elif issubclass(schema, enum.Enum):
# langchain's EnumOutputParser is not in langchain_core
# and doesn't support streaming. this is a simple implementation
# that supports streaming with our semantics of returning None
# if no complete object can be constructed.
class EnumOutputParser(BaseOutputParser):
enum: Type[enum.Enum]

def parse(self, response: str) -> Any:
try:
return self.enum(response.strip())
except ValueError:
pass
return None

# guided_choice only supports string choices
choices = [choice.value for choice in schema]
if not all(isinstance(choice, str) for choice in choices):
# instead of erroring out we could coerce the enum values to
# strings, but would then need to coerce them back to their
# original type for Enum construction.
raise ValueError(
"Enum schema must only contain string choices. "
"Use StrEnum or ensure all member values are strings."
)
output_parser = EnumOutputParser(enum=schema)
llm = self.bind(nvext={"guided_choice": choices})

elif is_basemodel_subclass(schema):
# PydanticOutputParser does not support streaming. what we do
# instead is ignore all inputs that are incomplete wrt the
# underlying Pydantic schema. if the entire input is invalid,
# we return None.
class ForgivingPydanticOutputParser(PydanticOutputParser):
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Any:
try:
return super().parse_result(result, partial=partial)
except OutputParserException:
pass
return None

output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
if hasattr(schema, "model_json_schema"):
json_schema = schema.model_json_schema()
else:
json_schema = schema.schema()
llm = self.bind(nvext={"guided_json": json_schema})

else:
raise ValueError(
"Schema must be a Pydantic object, a dictionary "
"representing a JSON schema, or an Enum."
)

return llm | output_parser
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import pytest
from langchain_core.messages import HumanMessage

# from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel, Field
from pydantic import BaseModel as BaseModelProper

Expand Down Expand Up @@ -212,17 +214,17 @@ def nested_json(result: Any) -> None:

@pytest.mark.parametrize(
("method", "strict"),
[("function_calling", True), ("json_schema", None), ("json_mode", None)],
[("json_schema", None), ("json_mode", None)],
)
def test_structured_output_json_strict(
structured_model: str,
tool_model: str,
mode: dict,
method: Literal["function_calling", "json_schema", "json_mode"],
method: Literal["function_calling", "json_mode", "json_schema"],
strict: Optional[bool],
) -> None:
"""Test to verify structured output with strict=True."""

llm = ChatNVIDIA(model=structured_model, temperature=0, **mode)
llm = ChatNVIDIA(model=tool_model, temperature=0, **mode)

# Test structured output with a Pydantic class
chat = llm.with_structured_output(Joke, method=method, strict=strict)
Expand All @@ -249,7 +251,10 @@ def test_structured_output_json_strict(
("method", "strict"), [("json_schema", None), ("json_mode", None)]
)
def test_nested_structured_output_json_strict(
tool_model: str, mode: dict, method: Literal["json_schema"], strict: Optional[bool]
tool_model: str,
mode: dict,
method: Literal["function_calling", "json_schema", "json_mode"],
strict: Optional[bool],
) -> None:
"""Test to verify structured output with strict=True for nested object."""

Expand All @@ -274,7 +279,7 @@ def test_nested_structured_output_json_strict(
)
async def test_structured_output_json_strict_async(
tool_model: str,
method: Literal["function_calling", "json_schema"],
method: Literal["function_calling", "json_schema", "json_mode"],
strict: Optional[bool],
) -> None:
"""Test to verify structured output with strict=True (async)."""
Expand Down Expand Up @@ -322,3 +327,25 @@ async def test_nested_structured_output_json_strict_async(
async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
nested_json(chunk)


def test_json_mode_with_dict(tool_model: str) -> None:
    """Verify ``with_structured_output`` accepts a plain dict schema with
    ``method="json_mode"`` and returns a runnable chain."""
    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    }

    # Fix: the constructor keyword is ``model`` (see the other tests in this
    # file); ``tool_model=`` was not a real field, so the fixture's model
    # name was never applied.
    llm = ChatNVIDIA(model=tool_model)
    structured_llm = llm.with_structured_output(schema, method="json_mode")
    # The original test asserted nothing; at minimum confirm a chain is built.
    assert structured_llm is not None


def test_json_schema_with_none_schema(tool_model: str) -> None:
    """Verify ``method="json_schema"`` with ``schema=None`` raises ValueError."""
    # Fix: the constructor keyword is ``model`` (see the other tests in this
    # file); ``tool_model=`` was not a real field, so the fixture's model
    # name was never applied.
    llm = ChatNVIDIA(model=tool_model)

    with pytest.raises(
        ValueError, match="schema must be specified when method is not 'json_mode'"
    ):
        llm.with_structured_output(schema=None, method="json_schema")
Original file line number Diff line number Diff line change
Expand Up @@ -215,5 +215,5 @@ def test_strict_no_warns(strict: Optional[bool]) -> None:

ChatNVIDIA(api_key="BOGUS").with_structured_output(
Joke,
**({"strict": strict} if strict is not None else {}),
strict=strict,
)

0 comments on commit 2305f00

Please sign in to comment.