Merge pull request #84 from langchain-ai/mattf/align-pydantic-support
support pydantic v2 and v1 for with_structured_output
mattf authored Aug 7, 2024
2 parents e07839b + be3cff2 commit e13519e
Showing 3 changed files with 77 additions and 22 deletions.
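The upshot of this change: with_structured_output now accepts schema classes from langchain_core.pydantic_v1, pydantic.v1, and plain pydantic v2. A minimal usage sketch, assuming a configured NVIDIA API key and the package's default model; the Person model mirrors the new unit test, while the prompt is illustrative:

from langchain_nvidia_ai_endpoints import ChatNVIDIA
from pydantic import BaseModel  # plain pydantic v2, newly supported by this PR

class Person(BaseModel):
    name: str

llm = ChatNVIDIA().with_structured_output(Person)
person = llm.invoke("Extract the name: Sam Doe is a data scientist.")
assert isinstance(person, Person)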
libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py (20 additions, 18 deletions)
@@ -50,6 +50,7 @@
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain_core.utils.pydantic import is_basemodel_subclass

 from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
 from langchain_nvidia_ai_endpoints._statics import Model
@@ -679,24 +680,6 @@ class Choices(enum.Enum):
             output_parser: BaseOutputParser = JsonOutputParser()
             nvext_param: Dict[str, Any] = {"guided_json": schema}

-        elif issubclass(schema, BaseModel):
-            # PydanticOutputParser does not support streaming. what we do
-            # instead is ignore all inputs that are incomplete wrt the
-            # underlying Pydantic schema. if the entire input is invalid,
-            # we return None.
-            class ForgivingPydanticOutputParser(PydanticOutputParser):
-                def parse_result(
-                    self, result: List[Generation], *, partial: bool = False
-                ) -> Any:
-                    try:
-                        return super().parse_result(result, partial=partial)
-                    except OutputParserException:
-                        pass
-                    return None
-
-            output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
-            nvext_param = {"guided_json": schema.schema()}
-
         elif issubclass(schema, enum.Enum):
             # langchain's EnumOutputParser is not in langchain_core
             # and doesn't support streaming. this is a simple implementation
@@ -724,6 +707,25 @@ def parse(self, response: str) -> Any:
             )
             output_parser = EnumOutputParser(enum=schema)
             nvext_param = {"guided_choice": choices}
+
+        elif is_basemodel_subclass(schema):
+            # PydanticOutputParser does not support streaming. what we do
+            # instead is ignore all inputs that are incomplete wrt the
+            # underlying Pydantic schema. if the entire input is invalid,
+            # we return None.
+            class ForgivingPydanticOutputParser(PydanticOutputParser):
+                def parse_result(
+                    self, result: List[Generation], *, partial: bool = False
+                ) -> Any:
+                    try:
+                        return super().parse_result(result, partial=partial)
+                    except OutputParserException:
+                        pass
+                    return None
+
+            output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
+            nvext_param = {"guided_json": schema.schema()}
+
         else:
             raise ValueError(
                 "Schema must be a Pydantic object, a dictionary "
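Note the branch reordering: the Pydantic case now runs after the enum case and matches via is_basemodel_subclass, which recognizes both pydantic v1 and v2 model classes. Because ForgivingPydanticOutputParser swallows OutputParserException, streaming emits None for accumulated output that does not yet validate against the schema instead of raising mid-stream. A hedged sketch of the caller-side behavior, assuming a configured API key (model and prompt are illustrative):

from langchain_nvidia_ai_endpoints import ChatNVIDIA
from pydantic import BaseModel

class Joke(BaseModel):
    setup: str
    punchline: str

structured_llm = ChatNVIDIA().with_structured_output(Joke)
for chunk in structured_llm.stream("Tell me a joke about GPUs"):
    # partial output that does not yet validate as a Joke parses to None;
    # once all required fields are present, chunks are Joke instances
    print(chunk)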
libs/ai-endpoints/scripts/check_pydantic.sh (1 addition, 1 deletion)
@@ -14,7 +14,7 @@ fi
 repository_path="$1"

 # Search for lines matching the pattern within the specified repository
-result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
+result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic' | grep -v "# ignore: check_pydantic")

 # Check if any matching lines were found
 if [ -n "$result" ]; then
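The appended grep -v gives individual lines an explicit opt-out: any direct pydantic import carrying the marker comment is filtered back out before the check fails. The new unit tests (below) use exactly this escape hatch:

from pydantic import BaseModel as pydanticV2BaseModel  # ignore: check_pydantic
from pydantic.v1 import BaseModel as pydanticV1BaseModel  # ignore: check_pydantic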
libs/ai-endpoints/tests/unit_tests/test_structured_output.py (56 additions, 3 deletions)
@@ -1,14 +1,18 @@
 import enum
 import warnings
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Type

 import pytest
-from langchain_core.pydantic_v1 import BaseModel, Field
+import requests_mock
+from langchain_core.pydantic_v1 import BaseModel as lc_pydanticV1BaseModel
+from langchain_core.pydantic_v1 import Field
+from pydantic import BaseModel as pydanticV2BaseModel  # ignore: check_pydantic
+from pydantic.v1 import BaseModel as pydanticV1BaseModel  # ignore: check_pydantic

 from langchain_nvidia_ai_endpoints import ChatNVIDIA


-class Joke(BaseModel):
+class Joke(lc_pydanticV1BaseModel):
     """Joke to tell user."""

     setup: str = Field(description="The setup of the joke")
@@ -136,3 +140,52 @@ def test_stream_enum_incomplete(
     for chunk in structured_llm.stream("This is ignored."):
         response = chunk
     assert response is None
+
+
+@pytest.mark.parametrize(
+    "pydanticBaseModel",
+    [
+        lc_pydanticV1BaseModel,
+        pydanticV1BaseModel,
+        pydanticV2BaseModel,
+    ],
+    ids=["lc-pydantic-v1", "pydantic-v1", "pydantic-v2"],
+)
+def test_pydantic_version(
+    requests_mock: requests_mock.Mocker,
+    pydanticBaseModel: Type,
+) -> None:
+    requests_mock.post(
+        "https://integrate.api.nvidia.com/v1/chat/completions",
+        json={
+            "id": "chatcmpl-ID",
+            "object": "chat.completion",
+            "created": 1234567890,
+            "model": "BOGUS",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": '{"name": "Sam Doe"}',
+                    },
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {
+                "prompt_tokens": 22,
+                "completion_tokens": 20,
+                "total_tokens": 42,
+            },
+            "system_fingerprint": None,
+        },
+    )
+
+    class Person(pydanticBaseModel):  # type: ignore
+        name: str
+
+    llm = ChatNVIDIA(api_key="BOGUS").with_structured_output(Person)
+    response = llm.invoke("This is ignored.")
+    assert isinstance(response, Person)
+    assert response.name == "Sam Doe"
