From 452671b464f6e95706fcdbd34569aa49b6874c4d Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Mon, 2 Dec 2024 16:02:04 -0800 Subject: [PATCH] core, tests: more tolerant _aget_relevant_documents function --- libs/core/langchain_core/retrievers.py | 14 +++++++++++++- libs/core/tests/unit_tests/test_retrievers.py | 0 .../tests/unit_tests/test_basic_retriever.py | 5 +---- 3 files changed, 14 insertions(+), 5 deletions(-) create mode 100644 libs/core/tests/unit_tests/test_retrievers.py diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 7462569ddd30d..aaa066e82f0e5 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -24,7 +24,7 @@ import warnings from abc import ABC, abstractmethod from inspect import signature -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Self from pydantic import ConfigDict from typing_extensions import TypedDict @@ -180,6 +180,18 @@ def __init_subclass__(cls, **kwargs: Any) -> None: cls._aget_relevant_documents = aswap # type: ignore[assignment] parameters = signature(cls._get_relevant_documents).parameters cls._new_arg_supported = parameters.get("run_manager") is not None + if ( + not cls._new_arg_supported + and cls._aget_relevant_documents == BaseRetriever._aget_relevant_documents + ): + # we need to tolerate no run_manager in _aget_relevant_documents signature + async def _aget_relevant_documents( + self: Self, query: str + ) -> list[Document]: + return await run_in_executor(None, self._get_relevant_documents, query) # type: ignore + + cls._aget_relevant_documents = _aget_relevant_documents # type: ignore[assignment] + # If a V1 retriever broke the interface and expects additional arguments cls._expects_other_args = ( len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0 diff --git a/libs/core/tests/unit_tests/test_retrievers.py b/libs/core/tests/unit_tests/test_retrievers.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/standard-tests/tests/unit_tests/test_basic_retriever.py b/libs/standard-tests/tests/unit_tests/test_basic_retriever.py index af5d598c722f2..fb7999a09fb5c 100644 --- a/libs/standard-tests/tests/unit_tests/test_basic_retriever.py +++ b/libs/standard-tests/tests/unit_tests/test_basic_retriever.py @@ -1,6 +1,5 @@ from typing import Any, Type -from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever @@ -11,9 +10,7 @@ class ParrotRetriever(BaseRetriever): parrot_name: str k: int = 3 - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any - ) -> list[Document]: + def _get_relevant_documents(self, query: str, **kwargs: Any) -> list[Document]: k = kwargs.get("k", self.k) return [Document(page_content=f"{self.parrot_name} says: {query}")] * k