diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py index 4828229d632f0..d6718030eb40a 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/base.py +++ b/libs/langchain/langchain/chains/retrieval_qa/base.py @@ -31,6 +31,7 @@ class BaseRetrievalQA(Chain): """Chain to use to combine the documents.""" input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: + documents_key: str = "from_documents" #: :meta private: return_source_documents: bool = False """Return the source documents or not.""" @@ -121,6 +122,9 @@ def _call( If chain has 'return_source_documents' as 'True', returns the retrieved documents as well under the key 'source_documents'. + If 'from_documents' is passed in the inputs, chain will use those + instead of calling _get_docs(). + Example: .. code-block:: python @@ -132,7 +136,13 @@ def _call( accepts_run_manager = ( "run_manager" in inspect.signature(self._get_docs).parameters ) - if accepts_run_manager: + from_documents = inputs.get(self.documents_key) + if from_documents: + docs = from_documents + for doc in docs: + if not isinstance(doc, Document): + raise TypeError(f"{doc} is not a Document") + elif accepts_run_manager: docs = self._get_docs(question, run_manager=_run_manager) else: docs = self._get_docs(question) # type: ignore[call-arg] @@ -164,6 +174,9 @@ async def _acall( If chain has 'return_source_documents' as 'True', returns the retrieved documents as well under the key 'source_documents'. + If 'from_documents' is passed in the inputs, chain will use those + instead of calling _get_docs(). + Example: .. code-block:: python @@ -175,7 +188,13 @@ async def _acall( accepts_run_manager = ( "run_manager" in inspect.signature(self._aget_docs).parameters ) - if accepts_run_manager: + from_documents = inputs.get(self.documents_key) + if from_documents: + docs = from_documents + for doc in docs: + if not isinstance(doc, Document): + raise TypeError(f"{doc} is not a Document") + elif accepts_run_manager: docs = await self._aget_docs(question, run_manager=_run_manager) else: docs = await self._aget_docs(question) # type: ignore[call-arg]