Skip to content

Commit

Permalink
langchain[patch]: Compat with pydantic 2.10 (#28307)
Browse files Browse the repository at this point in the history
pydantic compat 2.10 for langchain
  • Loading branch information
eyurtsev authored Nov 23, 2024
1 parent a813d11 commit 563587e
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 115 deletions.
4 changes: 3 additions & 1 deletion libs/langchain/langchain/chains/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ async def acall_model(state: ChainState, config: RunnableConfig):
api_docs: str
question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private:
limit_to_domains: Optional[Sequence[str]] = Field(default_factory=list)
limit_to_domains: Optional[Sequence[str]] = Field(
default_factory=list # type: ignore
)
"""Use to limit the domains that can be accessed by the API chain.
* For example, to limit to just the domain `https://www.example.com`, set
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class OpenAIModerationChain(Chain):
output_key: str = "output" #: :meta private:
openai_api_key: Optional[str] = None
openai_organization: Optional[str] = None
openai_pre_1_0: bool = Field(default=None)
openai_pre_1_0: bool = Field(default=False)

@model_validator(mode="before")
@classmethod
Expand Down
5 changes: 5 additions & 0 deletions libs/langchain/langchain/memory/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Any, Dict, List, Type

from langchain_core._api import deprecated
from langchain_core.caches import BaseCache as BaseCache # For model_rebuild
from langchain_core.callbacks import Callbacks as Callbacks # For model_rebuild
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
Expand Down Expand Up @@ -131,3 +133,6 @@ def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.buffer = ""


ConversationSummaryMemory.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
previous_history_template: str = DEFAULT_HISTORY_TEMPLATE
split_chunk_size: int = 1000

_memory_retriever: VectorStoreRetrieverMemory = PrivateAttr(default=None)
_memory_retriever: VectorStoreRetrieverMemory = PrivateAttr(default=None) # type: ignore
_timestamps: List[datetime] = PrivateAttr(default_factory=list)

@property
Expand Down
7 changes: 4 additions & 3 deletions libs/langchain/langchain/output_parsers/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ class OutputFixingParser(BaseOutputParser[T]):
def is_lc_serializable(cls) -> bool:
return True

parser: Annotated[BaseOutputParser[T], SkipValidation()]
parser: Annotated[Any, SkipValidation()]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Union[
RunnableSerializable[OutputFixingParserRetryChainInput, str], Any
retry_chain: Annotated[
Union[RunnableSerializable[OutputFixingParserRetryChainInput, str], Any],
SkipValidation(),
]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
Expand Down
12 changes: 9 additions & 3 deletions libs/langchain/langchain/output_parsers/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ class RetryOutputParser(BaseOutputParser[T]):
parser: Annotated[BaseOutputParser[T], SkipValidation()]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Union[RunnableSerializable[RetryOutputParserRetryChainInput, str], Any]
retry_chain: Annotated[
Union[RunnableSerializable[RetryOutputParserRetryChainInput, str], Any],
SkipValidation(),
]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
Expand Down Expand Up @@ -187,8 +190,11 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
parser: Annotated[BaseOutputParser[T], SkipValidation()]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Union[
RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str], Any
retry_chain: Annotated[
Union[
RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str], Any
],
SkipValidation(),
]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
Expand Down
253 changes: 154 additions & 99 deletions libs/langchain/poetry.lock

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions libs/langchain/tests/unit_tests/output_parsers/test_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_output_fixing_parser_parse(
base_parser.attemp_count_before_success
) # Success on the (n+1)-th attempt # noqa
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = OutputFixingParser(
parser = OutputFixingParser[str](
parser=base_parser,
max_retries=n, # n times to retry, that is, (n+1) times call
retry_chain=RunnablePassthrough(),
Expand Down Expand Up @@ -94,7 +94,7 @@ async def test_output_fixing_parser_aparse(
base_parser.attemp_count_before_success
) # Success on the (n+1)-th attempt # noqa
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = OutputFixingParser(
parser = OutputFixingParser[str](
parser=base_parser,
max_retries=n, # n times to retry, that is, (n+1) times call
retry_chain=RunnablePassthrough(),
Expand All @@ -108,7 +108,7 @@ async def test_output_fixing_parser_aparse(
def test_output_fixing_parser_parse_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = OutputFixingParser(
parser = OutputFixingParser[str](
parser=base_parser,
max_retries=n - 1, # n-1 times to retry, that is, n times call
retry_chain=RunnablePassthrough(),
Expand All @@ -122,7 +122,7 @@ def test_output_fixing_parser_parse_fail() -> None:
async def test_output_fixing_parser_aparse_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = OutputFixingParser(
parser = OutputFixingParser[str](
parser=base_parser,
max_retries=n - 1, # n-1 times to retry, that is, n times call
retry_chain=RunnablePassthrough(),
Expand All @@ -143,7 +143,9 @@ async def test_output_fixing_parser_aparse_fail() -> None:
def test_output_fixing_parser_output_type(
base_parser: BaseOutputParser,
) -> None:
parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough())
parser = OutputFixingParser[str](
parser=base_parser, retry_chain=RunnablePassthrough()
)
assert parser.OutputType is base_parser.OutputType


Expand Down Expand Up @@ -176,7 +178,7 @@ def test_output_fixing_parser_parse_with_retry_chain(
instructions = base_parser.get_format_instructions()
object.__setattr__(base_parser, "get_format_instructions", lambda: instructions)
# test
parser = OutputFixingParser(
parser = OutputFixingParser[str](
parser=base_parser,
retry_chain=retry_chain,
legacy=False,
Expand Down Expand Up @@ -212,7 +214,7 @@ async def test_output_fixing_parser_aparse_with_retry_chain(
instructions = base_parser.get_format_instructions()
object.__setattr__(base_parser, "get_format_instructions", lambda: instructions)
# test
parser = OutputFixingParser(
parser = OutputFixingParser[str](
parser=base_parser,
retry_chain=retry_chain,
legacy=False,
Expand Down

0 comments on commit 563587e

Please sign in to comment.