Skip to content

Commit

Permalink
core, tests: more tolerant _aget_relevant_documents function
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis committed Dec 3, 2024
1 parent 000be1f commit 452671b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
14 changes: 13 additions & 1 deletion libs/core/langchain_core/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Self

from pydantic import ConfigDict
from typing_extensions import TypedDict
Expand Down Expand Up @@ -180,6 +180,18 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
cls._aget_relevant_documents = aswap # type: ignore[assignment]
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
if (
not cls._new_arg_supported
and cls._aget_relevant_documents == BaseRetriever._aget_relevant_documents
):
# we need to tolerate no run_manager in _aget_relevant_documents signature
async def _aget_relevant_documents(
self: Self, query: str
) -> list[Document]:
return await run_in_executor(None, self._get_relevant_documents, query) # type: ignore

cls._aget_relevant_documents = _aget_relevant_documents # type: ignore[assignment]

# If a V1 retriever broke the interface and expects additional arguments
cls._expects_other_args = (
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
Expand Down
Empty file.
5 changes: 1 addition & 4 deletions libs/standard-tests/tests/unit_tests/test_basic_retriever.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Type

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

Expand All @@ -11,9 +10,7 @@ class ParrotRetriever(BaseRetriever):
parrot_name: str
k: int = 3

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> list[Document]:
def _get_relevant_documents(self, query: str, **kwargs: Any) -> list[Document]:
k = kwargs.get("k", self.k)
return [Document(page_content=f"{self.parrot_name} says: {query}")] * k

Expand Down

0 comments on commit 452671b

Please sign in to comment.