Skip to content

Commit

Permalink
core[patch]: Compat pydantic 2.10 (#28308)
Browse files Browse the repository at this point in the history
pydantic 2.10 compat for langchain-core
  • Loading branch information
eyurtsev authored Nov 23, 2024
1 parent ed84d48 commit a813d11
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 113 deletions.
2 changes: 1 addition & 1 deletion libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5217,7 +5217,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
kwargs.
"""

config: RunnableConfig = Field(default_factory=dict)
config: RunnableConfig = Field(default_factory=RunnableConfig) # type: ignore
"""The config to bind to the underlying Runnable."""

config_factories: list[Callable[[RunnableConfig], RunnableConfig]] = Field(
Expand Down
13 changes: 12 additions & 1 deletion libs/core/langchain_core/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,18 @@ def get_pydantic_major_version() -> int:
return 0


def _get_pydantic_minor_version() -> int:
"""Get the minor version of Pydantic."""
try:
import pydantic

return int(pydantic.__version__.split(".")[1])
except ImportError:
return 0


PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
PYDANTIC_MINOR_VERSION = _get_pydantic_minor_version()


if PYDANTIC_MAJOR_VERSION == 1:
Expand Down Expand Up @@ -200,7 +211,7 @@ def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
name not in values or values[name] is None
) and not field_info.is_required():
if field_info.default_factory is not None:
values[name] = field_info.default_factory()
values[name] = field_info.default_factory() # type: ignore
else:
values[name] = field_info.default

Expand Down
206 changes: 107 additions & 99 deletions libs/core/poetry.lock

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
'const': 'ai',
'default': 'ai',
'title': 'Type',
'type': 'string',
}),
'usage_metadata': dict({
'anyOf': list([
Expand Down Expand Up @@ -206,6 +207,7 @@
'const': 'AIMessageChunk',
'default': 'AIMessageChunk',
'title': 'Type',
'type': 'string',
}),
'usage_metadata': dict({
'anyOf': list([
Expand Down Expand Up @@ -290,6 +292,7 @@
'const': 'chat',
'default': 'chat',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -364,6 +367,7 @@
'const': 'ChatMessageChunk',
'default': 'ChatMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -435,6 +439,7 @@
'const': 'function',
'default': 'function',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -497,6 +502,7 @@
'const': 'FunctionMessageChunk',
'default': 'FunctionMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -595,6 +601,7 @@
'const': 'human',
'default': 'human',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -669,6 +676,7 @@
'const': 'HumanMessageChunk',
'default': 'HumanMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -767,6 +775,7 @@
'type': dict({
'const': 'invalid_tool_call',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -892,6 +901,7 @@
'const': 'system',
'default': 'system',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -961,6 +971,7 @@
'const': 'SystemMessageChunk',
'default': 'SystemMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1009,6 +1020,7 @@
'type': dict({
'const': 'tool_call',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1087,6 +1099,7 @@
'type': dict({
'const': 'tool_call_chunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1209,6 +1222,7 @@
'const': 'tool',
'default': 'tool',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1290,6 +1304,7 @@
'const': 'ToolMessageChunk',
'default': 'ToolMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1509,6 +1524,7 @@
'const': 'ai',
'default': 'ai',
'title': 'Type',
'type': 'string',
}),
'usage_metadata': dict({
'anyOf': list([
Expand Down Expand Up @@ -1621,6 +1637,7 @@
'const': 'AIMessageChunk',
'default': 'AIMessageChunk',
'title': 'Type',
'type': 'string',
}),
'usage_metadata': dict({
'anyOf': list([
Expand Down Expand Up @@ -1705,6 +1722,7 @@
'const': 'chat',
'default': 'chat',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1779,6 +1797,7 @@
'const': 'ChatMessageChunk',
'default': 'ChatMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1850,6 +1869,7 @@
'const': 'function',
'default': 'function',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1912,6 +1932,7 @@
'const': 'FunctionMessageChunk',
'default': 'FunctionMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2010,6 +2031,7 @@
'const': 'human',
'default': 'human',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2084,6 +2106,7 @@
'const': 'HumanMessageChunk',
'default': 'HumanMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2182,6 +2205,7 @@
'type': dict({
'const': 'invalid_tool_call',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2307,6 +2331,7 @@
'const': 'system',
'default': 'system',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2376,6 +2401,7 @@
'const': 'SystemMessageChunk',
'default': 'SystemMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2424,6 +2450,7 @@
'type': dict({
'const': 'tool_call',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2502,6 +2529,7 @@
'type': dict({
'const': 'tool_call_chunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2624,6 +2652,7 @@
'const': 'tool',
'default': 'tool',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2705,6 +2734,7 @@
'const': 'ToolMessageChunk',
'default': 'ToolMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down
17 changes: 11 additions & 6 deletions libs/core/tests/unit_tests/prompts/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_convert_to_message,
)
from langchain_core.prompts.string import PromptTemplateFormat
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION
from tests.unit_tests.pydantic_utils import _normalize_schema


Expand Down Expand Up @@ -852,18 +853,22 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
assert prompt_all_required.optional_variables == []
with pytest.raises(ValidationError):
prompt_all_required.input_schema(input="")
assert _normalize_schema(prompt_all_required.get_input_jsonschema()) == snapshot(
name="required"
)

if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
assert _normalize_schema(
prompt_all_required.get_input_jsonschema()
) == snapshot(name="required")
prompt_optional = ChatPromptTemplate(
messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")]
)
# input variables only lists required variables
assert set(prompt_optional.input_variables) == {"input"}
prompt_optional.input_schema(input="") # won't raise error
assert _normalize_schema(prompt_optional.get_input_jsonschema()) == snapshot(
name="partial"
)

if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
assert _normalize_schema(prompt_optional.get_input_jsonschema()) == snapshot(
name="partial"
)


def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None:
Expand Down
Loading

0 comments on commit a813d11

Please sign in to comment.