Skip to content

Commit

Permalink
use is_basemodel_subclass to detect BaseModel for with_structured_output
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed Aug 6, 2024
1 parent dbad718 commit be3cff2
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model
Expand Down Expand Up @@ -679,24 +680,6 @@ class Choices(enum.Enum):
output_parser: BaseOutputParser = JsonOutputParser()
nvext_param: Dict[str, Any] = {"guided_json": schema}

elif issubclass(schema, BaseModel):
# PydanticOutputParser does not support streaming. what we do
# instead is ignore all inputs that are incomplete wrt the
# underlying Pydantic schema. if the entire input is invalid,
# we return None.
class ForgivingPydanticOutputParser(PydanticOutputParser):
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Any:
try:
return super().parse_result(result, partial=partial)
except OutputParserException:
pass
return None

output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
nvext_param = {"guided_json": schema.schema()}

elif issubclass(schema, enum.Enum):
# langchain's EnumOutputParser is not in langchain_core
# and doesn't support streaming. this is a simple implementation
Expand Down Expand Up @@ -724,6 +707,25 @@ def parse(self, response: str) -> Any:
)
output_parser = EnumOutputParser(enum=schema)
nvext_param = {"guided_choice": choices}

elif is_basemodel_subclass(schema):
# PydanticOutputParser does not support streaming. what we do
# instead is ignore all inputs that are incomplete wrt the
# underlying Pydantic schema. if the entire input is invalid,
# we return None.
class ForgivingPydanticOutputParser(PydanticOutputParser):
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Any:
try:
return super().parse_result(result, partial=partial)
except OutputParserException:
pass
return None

output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
nvext_param = {"guided_json": schema.schema()}

else:
raise ValueError(
"Schema must be a Pydantic object, a dictionary "
Expand Down

0 comments on commit be3cff2

Please sign in to comment.