diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index 3a0cd996..ab412794 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -48,8 +48,11 @@ ) from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool -from langchain_core.utils.function_calling import convert_to_openai_tool -from langchain_core.utils.pydantic import is_basemodel_subclass +from langchain_core.utils.function_calling import ( + convert_to_openai_function, + convert_to_openai_tool, +) +from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass from pydantic import BaseModel, Field, PrivateAttr from langchain_nvidia_ai_endpoints._common import _NVIDIAClient @@ -237,6 +240,28 @@ def _process_for_vlm( return inputs, extra_headers +def _convert_to_openai_response_format( + schema: Union[Dict[str, Any], Type], +) -> Union[Dict, TypeBaseModel]: + if isinstance(schema, type) and is_basemodel_subclass(schema): + return schema + + if ( + isinstance(schema, dict) + and "json_schema" in schema + and schema.get("type") == "json_schema" + ): + response_format = schema + elif isinstance(schema, dict) and "name" in schema and "schema" in schema: + response_format = {"type": "json_schema", "json_schema": schema} + else: + function = convert_to_openai_function(schema) + function["schema"] = function.pop("parameters") + response_format = {"type": "json_schema", "json_schema": function} + + return response_format + + _DEFAULT_MODEL_NAME: str = "meta/llama3-8b-instruct" @@ -647,8 +672,11 @@ def bind_functions( # as a result need to type ignore for the schema parameter and return type. def with_structured_output( # type: ignore self, - schema: Union[Dict, Type], + schema: Optional[Union[Dict, Type]] = None, *, + method: Literal[ + "function_calling", "json_mode", "json_schema" + ] = "function_calling", include_raw: bool = False, **kwargs: Any, ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: @@ -789,6 +817,8 @@ class Choices(enum.Enum): "being None when the LLM produces an incomplete response." ) + is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema) + # check if the model supports structured output, warn if it does not known_good = False # todo: we need to store model: Model in this class @@ -807,65 +837,94 @@ class Choices(enum.Enum): f"Model '{self.model}' is not known to support structured output. " "Your output may fail at inference time." ) - - if isinstance(schema, dict): - output_parser: BaseOutputParser = JsonOutputParser() - nvext_param: Dict[str, Any] = {"guided_json": schema} - - elif issubclass(schema, enum.Enum): - # langchain's EnumOutputParser is not in langchain_core - # and doesn't support streaming. this is a simple implementation - # that supports streaming with our semantics of returning None - # if no complete object can be constructed. - class EnumOutputParser(BaseOutputParser): - enum: Type[enum.Enum] - - def parse(self, response: str) -> Any: - try: - return self.enum(response.strip()) - except ValueError: - pass - return None - - # guided_choice only supports string choices - choices = [choice.value for choice in schema] - if not all(isinstance(choice, str) for choice in choices): - # instead of erroring out we could coerce the enum values to - # strings, but would then need to coerce them back to their - # original type for Enum construction. 
+        output_parser: BaseOutputParser
+
+        if method == "json_mode":
+            llm = self.bind(response_format={"type": "json_object"})
+            output_parser = (
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
+        elif method == "json_schema":
+            if schema is None:
                 raise ValueError(
-                    "Enum schema must only contain string choices. "
-                    "Use StrEnum or ensure all member values are strings."
+                    "schema must be specified when method is not 'json_mode'. "
+                    "Received None."
                 )
-            output_parser = EnumOutputParser(enum=schema)
-            nvext_param = {"guided_choice": choices}
-
-        elif is_basemodel_subclass(schema):
-            # PydanticOutputParser does not support streaming. what we do
-            # instead is ignore all inputs that are incomplete wrt the
-            # underlying Pydantic schema. if the entire input is invalid,
-            # we return None.
-            class ForgivingPydanticOutputParser(PydanticOutputParser):
-                def parse_result(
-                    self, result: List[Generation], *, partial: bool = False
-                ) -> Any:
-                    try:
-                        return super().parse_result(result, partial=partial)
-                    except OutputParserException:
-                        pass
-                    return None
-
-            output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
-            if hasattr(schema, "model_json_schema"):
-                json_schema = schema.model_json_schema()
-            else:
-                json_schema = schema.schema()
-            nvext_param = {"guided_json": json_schema}
-
+            response_format = _convert_to_openai_response_format(schema)
+            llm = self.bind(response_format=response_format)
+            output_parser = (
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
+        elif method == "function_calling":
+            # guided decoding (nvext) for function_calling is bound below
+            pass
         else:
             raise ValueError(
-                "Schema must be a Pydantic object, a dictionary "
-                "representing a JSON schema, or an Enum."
+                f"Unrecognized method argument. Expected one of "
+                f"'function_calling', 'json_mode', or 'json_schema'. "
+                f"Received: '{method}'"
             )
-        return super().bind(nvext=nvext_param) | output_parser
+        if method == "function_calling":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is "
+                    "'function_calling'. Received None."
+                )
+            if isinstance(schema, dict):
+                output_parser = JsonOutputParser()
+                llm = self.bind(nvext={"guided_json": schema})
+            elif issubclass(schema, enum.Enum):
+                # langchain's EnumOutputParser is not in langchain_core
+                # and doesn't support streaming. this is a simple implementation
+                # that supports streaming with our semantics of returning None
+                # if no complete object can be constructed.
+                class EnumOutputParser(BaseOutputParser):
+                    enum: Type[enum.Enum]
+
+                    def parse(self, response: str) -> Any:
+                        try:
+                            return self.enum(response.strip())
+                        except ValueError:
+                            pass
+                        return None
+
+                # guided_choice only supports string choices
+                choices = [choice.value for choice in schema]
+                if not all(isinstance(choice, str) for choice in choices):
+                    # instead of erroring out we could coerce the enum values to
+                    # strings, but would then need to coerce them back to their
+                    # original type for Enum construction.
+                    raise ValueError(
+                        "Enum schema must only contain string choices. "
+                        "Use StrEnum or ensure all member values are strings."
+                    )
+                output_parser = EnumOutputParser(enum=schema)
+                llm = self.bind(nvext={"guided_choice": choices})
+
+            elif is_basemodel_subclass(schema):
+                # PydanticOutputParser does not support streaming. what we do
+                # instead is ignore all inputs that are incomplete wrt the
+                # underlying Pydantic schema. if the entire input is invalid,
+                # we return None.
+                class ForgivingPydanticOutputParser(PydanticOutputParser):
+                    def parse_result(
+                        self, result: List[Generation], *, partial: bool = False
+                    ) -> Any:
+                        try:
+                            return super().parse_result(result, partial=partial)
+                        except OutputParserException:
+                            pass
+                        return None
+
+                output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
+                if hasattr(schema, "model_json_schema"):
+                    json_schema = schema.model_json_schema()
+                else:
+                    json_schema = schema.schema()
+                llm = self.bind(nvext={"guided_json": json_schema})
+
+            else:
+                raise ValueError(
+                    "Schema must be a Pydantic object, a dictionary "
+                    "representing a JSON schema, or an Enum."
+                )
+
+        return llm | output_parser
diff --git a/libs/ai-endpoints/tests/integration_tests/test_structured_output.py b/libs/ai-endpoints/tests/integration_tests/test_structured_output.py
index d143b47e..8feb795e 100644
--- a/libs/ai-endpoints/tests/integration_tests/test_structured_output.py
+++ b/libs/ai-endpoints/tests/integration_tests/test_structured_output.py
@@ -3,6 +3,8 @@
 import pytest
 from langchain_core.messages import HumanMessage
+
+from langchain_core.output_parsers import JsonOutputParser
 from pydantic import BaseModel, Field
 from pydantic import BaseModel as BaseModelProper
@@ -212,17 +214,17 @@ def nested_json(result: Any) -> None:

 @pytest.mark.parametrize(
     ("method", "strict"),
-    [("function_calling", True), ("json_schema", None), ("json_mode", None)],
+    [("json_schema", None), ("json_mode", None)],
 )
 def test_structured_output_json_strict(
-    structured_model: str,
+    tool_model: str,
     mode: dict,
-    method: Literal["function_calling", "json_schema", "json_mode"],
+    method: Literal["function_calling", "json_mode", "json_schema"],
     strict: Optional[bool],
 ) -> None:
     """Test to verify structured output with strict=True."""

-    llm = ChatNVIDIA(model=structured_model, temperature=0, **mode)
+    llm = ChatNVIDIA(model=tool_model, temperature=0, **mode)

     # Test structured output with a Pydantic class
     chat = llm.with_structured_output(Joke, method=method, strict=strict)
@@ -249,7 +251,10 @@ def test_structured_output_json_strict(
     ("method", "strict"), [("json_schema", None), ("json_mode", None)]
 )
 def test_nested_structured_output_json_strict(
-    tool_model: str, mode: dict, method: Literal["json_schema"], strict: Optional[bool]
+    tool_model: str,
+    mode: dict,
+    method: Literal["function_calling", "json_schema", "json_mode"],
+    strict: Optional[bool],
 ) -> None:
     """Test to verify structured output with strict=True for nested object."""
@@ -274,7 +279,7 @@ def test_nested_structured_output_json_strict(
 )
 async def test_structured_output_json_strict_async(
     tool_model: str,
-    method: Literal["function_calling", "json_schema"],
+    method: Literal["function_calling", "json_schema", "json_mode"],
     strict: Optional[bool],
 ) -> None:
     """Test to verify structured output with strict=True (async)."""
@@ -322,3 +327,25 @@ async def test_nested_structured_output_json_strict_async(
     async for chunk in chat.astream("Tell me a joke about cats."):
         assert isinstance(chunk, dict)
         nested_json(chunk)
+
+
+def test_json_mode_with_dict(tool_model: str) -> None:
+    """Test json_mode with a dictionary schema."""
+    schema = {
+        "type": "object",
+        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+    }
+
+    llm = ChatNVIDIA(model=tool_model)
+    structured_llm = llm.with_structured_output(schema, method="json_mode")
+    assert isinstance(structured_llm.steps[-1], JsonOutputParser)
+
+
+def test_json_schema_with_none_schema(tool_model: str) -> None:
+    """Test that json_schema raises an error when schema is None."""
+    llm = ChatNVIDIA(model=tool_model)
+
+    with pytest.raises(
+        ValueError, match="schema must be specified when method is not 'json_mode'"
+    ):
+        llm.with_structured_output(schema=None, method="json_schema")
diff --git a/libs/ai-endpoints/tests/unit_tests/test_structured_output.py b/libs/ai-endpoints/tests/unit_tests/test_structured_output.py
index 114848e0..818b028a 100644
--- a/libs/ai-endpoints/tests/unit_tests/test_structured_output.py
+++ b/libs/ai-endpoints/tests/unit_tests/test_structured_output.py
@@ -215,5 +215,5 @@ def test_strict_no_warns(strict: Optional[bool]) -> None:
     ChatNVIDIA(api_key="BOGUS").with_structured_output(
         Joke,
-        **({"strict": strict} if strict is not None else {}),
+        strict=strict,
     )
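
Reviewer note: a minimal usage sketch of the new `method` parameter, for context.
This is illustrative only and not part of the patch; the Joke and Choices schemas
are assumptions, and the model name is the default from the diff.

import enum

from langchain_nvidia_ai_endpoints import ChatNVIDIA
from pydantic import BaseModel, Field


class Joke(BaseModel):
    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline of the joke")


class Choices(enum.Enum):
    YES = "yes"
    NO = "no"


# model name is an assumption; use any served model that supports structured output
llm = ChatNVIDIA(model="meta/llama3-8b-instruct", temperature=0)

# json_schema: the schema is converted by _convert_to_openai_response_format
# and bound as response_format={"type": "json_schema", ...}
chat = llm.with_structured_output(Joke, method="json_schema")
print(chat.invoke("Tell me a joke about cats."))

# json_mode: schema is optional; output is parsed with JsonOutputParser
# (or PydanticOutputParser when a Pydantic class is supplied)
chat = llm.with_structured_output(method="json_mode")
print(chat.invoke("Describe a cat as a JSON object."))

# function_calling (the default): the existing guided-decoding path via nvext,
# e.g. guided_choice for string enums
chat = llm.with_structured_output(Choices)
print(chat.invoke("Will it rain tomorrow? Answer yes or no."))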
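
Likewise, a sketch of the input shapes `_convert_to_openai_response_format` accepts.
It is a private helper, so this is illustrative only; the expectations below are
read directly off the branches in the patch.

from langchain_nvidia_ai_endpoints.chat_models import (
    _convert_to_openai_response_format,
)
from pydantic import BaseModel


class Person(BaseModel):
    name: str


# Pydantic classes pass through unchanged
assert _convert_to_openai_response_format(Person) is Person

# an already-wrapped response format is returned as-is
wrapped = {"type": "json_schema", "json_schema": {"name": "person", "schema": {}}}
assert _convert_to_openai_response_format(wrapped) == wrapped

# a {"name": ..., "schema": ...} dict is wrapped into a response format
named = {"name": "person", "schema": {"type": "object"}}
assert _convert_to_openai_response_format(named) == {
    "type": "json_schema",
    "json_schema": named,
}

# anything else goes through convert_to_openai_function, with its
# "parameters" key renamed to "schema"
out = _convert_to_openai_response_format(
    {"title": "person", "type": "object", "properties": {"name": {"type": "string"}}}
)
assert out["type"] == "json_schema"
assert "schema" in out["json_schema"]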