
langchain[patch]: Compat with pydantic 2.10 #28307

Merged 18 commits on Nov 23, 2024
4 changes: 3 additions & 1 deletion libs/langchain/langchain/chains/api/base.py
@@ -198,7 +198,9 @@ async def acall_model(state: ChainState, config: RunnableConfig):
api_docs: str
question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private:
-limit_to_domains: Optional[Sequence[str]] = Field(default_factory=list)
+limit_to_domains: Optional[Sequence[str]] = Field(
+    default_factory=list  # type: ignore
+)
"""Use to limit the domains that can be accessed by the API chain.

* For example, to limit to just the domain `https://www.example.com`, set
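
For context on this hunk: pydantic 2.10 appears to tighten the static typing around `Field()` and `default_factory`, so a factory whose return type (`list`) does not exactly match the annotation (`Optional[Sequence[str]]`) now gets flagged, hence the added `# type: ignore`. A minimal sketch of the pattern, assuming pydantic v2; `ExampleChainConfig` is a hypothetical stand-in for `APIChain`:

```python
from typing import Optional, Sequence

from pydantic import BaseModel, Field


class ExampleChainConfig(BaseModel):
    """Hypothetical stand-in for APIChain's pydantic fields."""

    # The factory returns `list`, while the annotation allows `Sequence[str] | None`;
    # newer pydantic/type-checker stubs flag that mismatch, hence the ignore.
    limit_to_domains: Optional[Sequence[str]] = Field(
        default_factory=list  # type: ignore
    )


print(ExampleChainConfig().limit_to_domains)  # []
```
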
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/moderation.py
@@ -38,7 +38,7 @@ class OpenAIModerationChain(Chain):
output_key: str = "output" #: :meta private:
openai_api_key: Optional[str] = None
openai_organization: Optional[str] = None
-openai_pre_1_0: bool = Field(default=None)
+openai_pre_1_0: bool = Field(default=False)

@model_validator(mode="before")
@classmethod
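
The moderation.py fix replaces a `None` default on a field annotated `bool` with `False`. Pydantic does not validate defaults unless `validate_default=True`, so the old value worked at runtime but contradicted the declared type, and the stricter checks that come with the 2.10 upgrade surface exactly this kind of mismatch. A minimal sketch of the corrected pattern, assuming pydantic v2 (`ModerationConfig` is a hypothetical stand-in; the surrounding `model_validator` is elided):

```python
from pydantic import BaseModel, Field


class ModerationConfig(BaseModel):
    """Hypothetical stand-in for OpenAIModerationChain's fields."""

    # default=False agrees with the declared bool type; default=None did not.
    openai_pre_1_0: bool = Field(default=False)


print(ModerationConfig().openai_pre_1_0)  # False
```
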
5 changes: 5 additions & 0 deletions libs/langchain/langchain/memory/summary.py
@@ -3,6 +3,8 @@
from typing import Any, Dict, List, Type

from langchain_core._api import deprecated
+from langchain_core.caches import BaseCache as BaseCache  # For model_rebuild
+from langchain_core.callbacks import Callbacks as Callbacks  # For model_rebuild
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
@@ -131,3 +133,6 @@ def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.buffer = ""


+ConversationSummaryMemory.model_rebuild()
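
These additions follow pydantic 2's standard recipe for forward references: make sure the referenced names (`BaseCache`, `Callbacks`) exist in the module namespace, then call `model_rebuild()` so the deferred annotations can be resolved. A rough sketch of the mechanism on a toy model, assuming pydantic v2 (not the actual LangChain classes):

```python
from typing import Optional

from pydantic import BaseModel


class Summary(BaseModel):
    # "Cache" is a forward reference that is not defined yet, so pydantic
    # leaves this model only partially built at class-creation time.
    cache: Optional["Cache"] = None


class Cache(BaseModel):
    entries: dict = {}


# Now that Cache exists in this module, finish building the schema. This is the
# same reason summary.py imports BaseCache/Callbacks and then calls
# ConversationSummaryMemory.model_rebuild().
Summary.model_rebuild()

print(Summary(cache=Cache(entries={"k": "v"})))
```
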
@@ -109,7 +109,7 @@ class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
previous_history_template: str = DEFAULT_HISTORY_TEMPLATE
split_chunk_size: int = 1000

-_memory_retriever: VectorStoreRetrieverMemory = PrivateAttr(default=None)
+_memory_retriever: VectorStoreRetrieverMemory = PrivateAttr(default=None)  # type: ignore
_timestamps: List[datetime] = PrivateAttr(default_factory=list)

@property
7 changes: 4 additions & 3 deletions libs/langchain/langchain/output_parsers/fix.py
@@ -27,11 +27,12 @@ class OutputFixingParser(BaseOutputParser[T]):
def is_lc_serializable(cls) -> bool:
return True

-parser: Annotated[BaseOutputParser[T], SkipValidation()]
+parser: Annotated[Any, SkipValidation()]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
-retry_chain: Union[
-    RunnableSerializable[OutputFixingParserRetryChainInput, str], Any
+retry_chain: Annotated[
+    Union[RunnableSerializable[OutputFixingParserRetryChainInput, str], Any],
+    SkipValidation(),
]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
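
The common thread in fix.py and retry.py (below) is wrapping the chain-typed fields in `Annotated[..., SkipValidation()]`. `SkipValidation` tells pydantic to store the value as-is instead of validating it against the annotation, which keeps the annotation useful for type checkers while presumably sidestepping pydantic 2.10's stricter handling of these runnable-typed fields. A minimal sketch of what the marker does, assuming pydantic v2 (a toy field, not the real parser classes):

```python
from typing import Annotated

from pydantic import BaseModel, SkipValidation


class Wrapper(BaseModel):
    # pydantic stores the value unvalidated, so even a non-int is accepted;
    # without SkipValidation this assignment would raise a ValidationError.
    value: Annotated[int, SkipValidation()]


print(Wrapper(value="not-an-int").value)  # "not-an-int", stored as-is
```
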
12 changes: 9 additions & 3 deletions libs/langchain/langchain/output_parsers/retry.py
@@ -57,7 +57,10 @@ class RetryOutputParser(BaseOutputParser[T]):
parser: Annotated[BaseOutputParser[T], SkipValidation()]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
-retry_chain: Union[RunnableSerializable[RetryOutputParserRetryChainInput, str], Any]
+retry_chain: Annotated[
+    Union[RunnableSerializable[RetryOutputParserRetryChainInput, str], Any],
+    SkipValidation(),
+]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
@@ -187,8 +190,11 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
parser: Annotated[BaseOutputParser[T], SkipValidation()]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
-retry_chain: Union[
-    RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str], Any
+retry_chain: Annotated[
+    Union[
+        RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str], Any
+    ],
+    SkipValidation(),
]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
253 changes: 154 additions & 99 deletions libs/langchain/poetry.lock

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions libs/langchain/tests/unit_tests/output_parsers/test_fix.py
@@ -50,7 +50,7 @@ def test_output_fixing_parser_parse(
base_parser.attemp_count_before_success
) # Success on the (n+1)-th attempt # noqa
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-parser = OutputFixingParser(
+parser = OutputFixingParser[str](
parser=base_parser,
max_retries=n, # n times to retry, that is, (n+1) times call
retry_chain=RunnablePassthrough(),
@@ -94,7 +94,7 @@ async def test_output_fixing_parser_aparse(
base_parser.attemp_count_before_success
) # Success on the (n+1)-th attempt # noqa
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-parser = OutputFixingParser(
+parser = OutputFixingParser[str](
parser=base_parser,
max_retries=n, # n times to retry, that is, (n+1) times call
retry_chain=RunnablePassthrough(),
@@ -108,7 +108,7 @@ async def test_output_fixing_parser_aparse(
def test_output_fixing_parser_parse_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-parser = OutputFixingParser(
+parser = OutputFixingParser[str](
parser=base_parser,
max_retries=n - 1, # n-1 times to retry, that is, n times call
retry_chain=RunnablePassthrough(),
@@ -122,7 +122,7 @@ def test_output_fixing_parser_parse_fail() -> None:
async def test_output_fixing_parser_aparse_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
-parser = OutputFixingParser(
+parser = OutputFixingParser[str](
parser=base_parser,
max_retries=n - 1, # n-1 times to retry, that is, n times call
retry_chain=RunnablePassthrough(),
@@ -143,7 +143,9 @@ async def test_output_fixing_parser_aparse_fail() -> None:
def test_output_fixing_parser_output_type(
base_parser: BaseOutputParser,
) -> None:
-parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough())
+parser = OutputFixingParser[str](
+    parser=base_parser, retry_chain=RunnablePassthrough()
+)
assert parser.OutputType is base_parser.OutputType


@@ -176,7 +178,7 @@ def test_output_fixing_parser_parse_with_retry_chain(
instructions = base_parser.get_format_instructions()
object.__setattr__(base_parser, "get_format_instructions", lambda: instructions)
# test
-parser = OutputFixingParser(
+parser = OutputFixingParser[str](
parser=base_parser,
retry_chain=retry_chain,
legacy=False,
@@ -212,7 +214,7 @@ async def test_output_fixing_parser_aparse_with_retry_chain(
instructions = base_parser.get_format_instructions()
object.__setattr__(base_parser, "get_format_instructions", lambda: instructions)
# test
-parser = OutputFixingParser(
+parser = OutputFixingParser[str](
parser=base_parser,
retry_chain=retry_chain,
legacy=False,
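
The test changes parametrize the generic explicitly, `OutputFixingParser[str](...)` rather than `OutputFixingParser(...)`, presumably because `parser` is now typed `Any` and the type argument can no longer be inferred from the constructor arguments. A rough sketch of explicit parametrization of a pydantic v2 generic model (a toy class, not the real parser):

```python
from typing import Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T")


class Box(BaseModel, Generic[T]):
    """Toy generic pydantic model, standing in for OutputFixingParser[T]."""

    value: T


# Explicit parametrization, mirroring OutputFixingParser[str](...) in the tests.
box = Box[str](value="hello")
print(type(box).__name__)  # Box[str]
```
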