Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[patch]: Compat pydantic 2.10 #28308

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5217,7 +5217,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
kwargs.
"""

config: RunnableConfig = Field(default_factory=dict)
config: RunnableConfig = Field(default_factory=RunnableConfig) # type: ignore
"""The config to bind to the underlying Runnable."""

config_factories: list[Callable[[RunnableConfig], RunnableConfig]] = Field(
Expand Down
13 changes: 12 additions & 1 deletion libs/core/langchain_core/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,18 @@ def get_pydantic_major_version() -> int:
return 0


def _get_pydantic_minor_version() -> int:
"""Get the minor version of Pydantic."""
try:
import pydantic

return int(pydantic.__version__.split(".")[1])
except ImportError:
return 0


PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
PYDANTIC_MINOR_VERSION = _get_pydantic_minor_version()


if PYDANTIC_MAJOR_VERSION == 1:
Expand Down Expand Up @@ -200,7 +211,7 @@ def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
name not in values or values[name] is None
) and not field_info.is_required():
if field_info.default_factory is not None:
values[name] = field_info.default_factory()
values[name] = field_info.default_factory() # type: ignore
else:
values[name] = field_info.default

Expand Down
206 changes: 107 additions & 99 deletions libs/core/poetry.lock

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
'const': 'ai',
'default': 'ai',
'title': 'Type',
'type': 'string',
}),
'usage_metadata': dict({
'anyOf': list([
Expand Down Expand Up @@ -206,6 +207,7 @@
'const': 'AIMessageChunk',
'default': 'AIMessageChunk',
'title': 'Type',
'type': 'string',
}),
'usage_metadata': dict({
'anyOf': list([
Expand Down Expand Up @@ -290,6 +292,7 @@
'const': 'chat',
'default': 'chat',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -364,6 +367,7 @@
'const': 'ChatMessageChunk',
'default': 'ChatMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -435,6 +439,7 @@
'const': 'function',
'default': 'function',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -497,6 +502,7 @@
'const': 'FunctionMessageChunk',
'default': 'FunctionMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -595,6 +601,7 @@
'const': 'human',
'default': 'human',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -669,6 +676,7 @@
'const': 'HumanMessageChunk',
'default': 'HumanMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -767,6 +775,7 @@
'type': dict({
'const': 'invalid_tool_call',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -892,6 +901,7 @@
'const': 'system',
'default': 'system',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -961,6 +971,7 @@
'const': 'SystemMessageChunk',
'default': 'SystemMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1009,6 +1020,7 @@
'type': dict({
'const': 'tool_call',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1087,6 +1099,7 @@
'type': dict({
'const': 'tool_call_chunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1209,6 +1222,7 @@
'const': 'tool',
'default': 'tool',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1290,6 +1304,7 @@
'const': 'ToolMessageChunk',
'default': 'ToolMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1509,6 +1524,7 @@
'const': 'ai',
'default': 'ai',
'title': 'Type',
'type': 'string',
}),
'usage_metadata': dict({
'anyOf': list([
Expand Down Expand Up @@ -1621,6 +1637,7 @@
'const': 'AIMessageChunk',
'default': 'AIMessageChunk',
'title': 'Type',
'type': 'string',
}),
'usage_metadata': dict({
'anyOf': list([
Expand Down Expand Up @@ -1705,6 +1722,7 @@
'const': 'chat',
'default': 'chat',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1779,6 +1797,7 @@
'const': 'ChatMessageChunk',
'default': 'ChatMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1850,6 +1869,7 @@
'const': 'function',
'default': 'function',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -1912,6 +1932,7 @@
'const': 'FunctionMessageChunk',
'default': 'FunctionMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2010,6 +2031,7 @@
'const': 'human',
'default': 'human',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2084,6 +2106,7 @@
'const': 'HumanMessageChunk',
'default': 'HumanMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2182,6 +2205,7 @@
'type': dict({
'const': 'invalid_tool_call',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2307,6 +2331,7 @@
'const': 'system',
'default': 'system',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2376,6 +2401,7 @@
'const': 'SystemMessageChunk',
'default': 'SystemMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2424,6 +2450,7 @@
'type': dict({
'const': 'tool_call',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2502,6 +2529,7 @@
'type': dict({
'const': 'tool_call_chunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2624,6 +2652,7 @@
'const': 'tool',
'default': 'tool',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down Expand Up @@ -2705,6 +2734,7 @@
'const': 'ToolMessageChunk',
'default': 'ToolMessageChunk',
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
Expand Down
17 changes: 11 additions & 6 deletions libs/core/tests/unit_tests/prompts/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_convert_to_message,
)
from langchain_core.prompts.string import PromptTemplateFormat
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION
from tests.unit_tests.pydantic_utils import _normalize_schema


Expand Down Expand Up @@ -852,18 +853,22 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
assert prompt_all_required.optional_variables == []
with pytest.raises(ValidationError):
prompt_all_required.input_schema(input="")
assert _normalize_schema(prompt_all_required.get_input_jsonschema()) == snapshot(
name="required"
)

if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
assert _normalize_schema(
prompt_all_required.get_input_jsonschema()
) == snapshot(name="required")
prompt_optional = ChatPromptTemplate(
messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")]
)
# input variables only lists required variables
assert set(prompt_optional.input_variables) == {"input"}
prompt_optional.input_schema(input="") # won't raise error
assert _normalize_schema(prompt_optional.get_input_jsonschema()) == snapshot(
name="partial"
)

if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 10):
assert _normalize_schema(prompt_optional.get_input_jsonschema()) == snapshot(
name="partial"
)


def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None:
Expand Down
Loading
Loading