langchain[minor]: add EnsembleRetriever (#5556)
* langchain[patch]: add support to merge retrievers
* Format
* Parallelize, lint, format, small fixes
* Add entrypoint
* Fix import
* Add docs and fix build artifacts

---------

Co-authored-by: jeasonnow <[email protected]>
Co-authored-by: jacoblee93 <[email protected]>
1 parent d35d12d, commit 294f600
Showing 17 changed files with 349 additions and 2 deletions.
@@ -0,0 +1,29 @@
# How to combine results from multiple retrievers

:::info Prerequisites

This guide assumes familiarity with the following concepts:

- [Documents](/docs/concepts#document)
- [Retrievers](/docs/concepts#retrievers)

:::

The [EnsembleRetriever](https://api.js.langchain.com/classes/langchain_retrievers_ensemble.EnsembleRetriever.html) combines the results of multiple retrievers. It is initialized with a list of [BaseRetriever](https://api.js.langchain.com/classes/langchain_core_retrievers.BaseRetriever.html) objects and reranks their combined results using the [Reciprocal Rank Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) algorithm.

By drawing on the strengths of different retrieval algorithms, the `EnsembleRetriever` can often return better results than any single retriever on its own.

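As a rough illustration of how weighted Reciprocal Rank Fusion scores results, the sketch below fuses ranked lists of plain strings. The function name `fuseRankings` and the sample lists are hypothetical and not part of LangChain. Each item's fused score is the sum, over the lists it appears in, of `weight / (rank + c)`, where `rank` is 1-based and `c` is assumed to default to 60.

```typescript
// Hypothetical helper for illustration only; not a LangChain API.
// Fuses several ranked lists using weighted Reciprocal Rank Fusion (RRF).
function fuseRankings(
  rankings: string[][],
  weights: number[],
  c = 60
): string[] {
  if (rankings.length !== weights.length) {
    throw new Error("Expected one weight per ranking.");
  }
  const scores = new Map<string, number>();
  rankings.forEach((ranking, i) => {
    ranking.forEach((item, index) => {
      const rank = index + 1; // ranks are 1-based
      scores.set(item, (scores.get(item) ?? 0) + weights[i] / (rank + c));
    });
  });
  // Highest fused score first. A Map iterates in first-insertion order and
  // Array.prototype.sort is stable, so tied items keep first-seen order.
  return Array.from(scores.entries())
    .sort((a, b) => b[1] - a[1])
    .map(([item]) => item);
}

// Sample rankings (made up for this sketch).
const keywordRanking = ["doc A", "doc B", "doc C"];
const denseRanking = ["doc B", "doc D"];
const fused = fuseRankings([keywordRanking, denseRanking], [0.5, 0.5]);
console.log(fused); // [ 'doc B', 'doc A', 'doc D', 'doc C' ]
```

Here "doc B" ranks first because it appears in both lists and so accumulates two contributions, while the remaining items are ordered by their single contribution.
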
One useful pattern is to combine a sparse retriever (such as keyword matching) with a dense retriever (such as embedding similarity), because their strengths are complementary. This can be considered a form of "hybrid search". The sparse retriever is good at finding relevant documents based on keywords, while the dense retriever is good at finding relevant documents based on semantic similarity.

Below we demonstrate ensembling a [simple custom retriever](/docs/how_to/custom_retriever/), which returns documents that directly contain the input query, with a retriever derived from a [demo in-memory vector store](https://api.js.langchain.com/classes/langchain_vectorstores_memory.MemoryVectorStore.html).

import CodeBlock from "@theme/CodeBlock";
import Example from "@examples/retrievers/ensemble_retriever.ts";

<CodeBlock language="typescript">{Example}</CodeBlock>

## Next steps

You've now learned how to combine results from multiple retrievers.
Next, check out some other retrieval how-to guides, such as how to [improve results using multiple embeddings per document](/docs/how_to/multi_vector)
or how to [create your own custom retriever](/docs/how_to/custom_retriever).
@@ -0,0 +1,67 @@
import { EnsembleRetriever } from "langchain/retrievers/ensemble";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { OpenAIEmbeddings } from "@langchain/openai";
import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers";
import { Document } from "@langchain/core/documents";

class SimpleCustomRetriever extends BaseRetriever {
  lc_namespace = [];

  documents: Document[];

  constructor(fields: { documents: Document[] } & BaseRetrieverInput) {
    super(fields);
    this.documents = fields.documents;
  }

  async _getRelevantDocuments(query: string): Promise<Document[]> {
    return this.documents.filter((document) =>
      document.pageContent.includes(query)
    );
  }
}

const docs1 = [
  new Document({ pageContent: "I like apples", metadata: { source: 1 } }),
  new Document({ pageContent: "I like oranges", metadata: { source: 1 } }),
  new Document({
    pageContent: "apples and oranges are fruits",
    metadata: { source: 1 },
  }),
];

const keywordRetriever = new SimpleCustomRetriever({ documents: docs1 });

const docs2 = [
  new Document({ pageContent: "You like apples", metadata: { source: 2 } }),
  new Document({ pageContent: "You like oranges", metadata: { source: 2 } }),
];

const vectorstore = await MemoryVectorStore.fromDocuments(
  docs2,
  new OpenAIEmbeddings()
);

const vectorstoreRetriever = vectorstore.asRetriever();

const retriever = new EnsembleRetriever({
  retrievers: [vectorstoreRetriever, keywordRetriever],
  weights: [0.5, 0.5],
});

const query = "apples";
const retrievedDocs = await retriever.invoke(query);

console.log(retrievedDocs);

/*
  [
    Document { pageContent: 'You like apples', metadata: { source: 2 } },
    Document { pageContent: 'I like apples', metadata: { source: 1 } },
    Document { pageContent: 'You like oranges', metadata: { source: 2 } },
    Document {
      pageContent: 'apples and oranges are fruits',
      metadata: { source: 1 }
    }
  ]
*/
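
The ordering in the sample output follows from the fusion arithmetic. The sketch below replays that arithmetic on plain strings without calling any embedding API. It assumes (the source does not confirm this) that the vector store ranks "You like apples" ahead of "You like oranges" for the query "apples"; the variable names are illustrative only.

```typescript
// Assumed ranked results from each retriever for the query "apples".
// The vector store ordering is an assumption made for this sketch; the keyword
// ordering follows from filtering docs1 for the substring "apples".
const vectorResults = ["You like apples", "You like oranges"];
const keywordResults = ["I like apples", "apples and oranges are fruits"];

const c = 60; // rank constant
const weights = [0.5, 0.5]; // one weight per retriever
const scores: Record<string, number> = {};

[vectorResults, keywordResults].forEach((results, i) => {
  results.forEach((content, index) => {
    const rank = index + 1; // ranks are 1-based
    scores[content] = (scores[content] ?? 0) + weights[i] / (rank + c);
  });
});

// Deduplicate in first-seen order, then sort by descending fused score.
const order = Array.from(new Set([...vectorResults, ...keywordResults])).sort(
  (a, b) => scores[b] - scores[a]
);
console.log(order);
```

Under these assumptions both rank-1 documents score 0.5 / 61 and both rank-2 documents score 0.5 / 62. Because `Array.prototype.sort` is stable, tied documents keep their first-seen order, which puts the vector store's result ahead of the keyword result within each tie.
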
@@ -0,0 +1,119 @@
import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers";
import { Document, DocumentInterface } from "@langchain/core/documents";
import { CallbackManagerForRetrieverRun } from "@langchain/core/callbacks/manager";

export interface EnsembleRetrieverInput extends BaseRetrieverInput {
  /** A list of retrievers to ensemble. */
  retrievers: BaseRetriever[];
  /**
   * A list of weights corresponding to the retrievers. Defaults to equal
   * weighting for all retrievers.
   */
  weights?: number[];
  /**
   * A constant added to the rank, controlling the balance between the importance
   * of high-ranked items and the consideration given to lower-ranked items.
   * Default is 60.
   */
  c?: number;
}

/**
 * Ensemble retriever that aggregates and orders the results of
 * multiple retrievers by using weighted Reciprocal Rank Fusion.
 */
export class EnsembleRetriever extends BaseRetriever {
  static lc_name() {
    return "EnsembleRetriever";
  }

  lc_namespace = ["langchain", "retrievers", "ensemble_retriever"];

  retrievers: BaseRetriever[];

  weights: number[];

  c = 60;

  constructor(args: EnsembleRetrieverInput) {
    super(args);
    this.retrievers = args.retrievers;
    this.weights =
      args.weights ||
      new Array(args.retrievers.length).fill(1 / args.retrievers.length);
    this.c = args.c || 60;
  }

  async _getRelevantDocuments(
    query: string,
    runManager?: CallbackManagerForRetrieverRun
  ) {
    return this._rankFusion(query, runManager);
  }

  async _rankFusion(
    query: string,
    runManager?: CallbackManagerForRetrieverRun
  ) {
    const retrieverDocs = await Promise.all(
      this.retrievers.map((retriever, i) =>
        retriever.invoke(query, {
          callbacks: runManager?.getChild(`retriever_${i + 1}`),
        })
      )
    );

    const fusedDocs = await this._weightedReciprocalRank(retrieverDocs);
    return fusedDocs;
  }

  async _weightedReciprocalRank(docList: DocumentInterface[][]) {
    if (docList.length !== this.weights.length) {
      throw new Error(
        "Number of retrieved document lists must be equal to the number of weights."
      );
    }

    const rrfScoreDict = docList.reduce(
      (rffScore: Record<string, number>, retrieverDoc, idx) => {
        let rank = 1;
        const weight = this.weights[idx];
        while (rank <= retrieverDoc.length) {
          const { pageContent } = retrieverDoc[rank - 1];
          if (!rffScore[pageContent]) {
            // eslint-disable-next-line no-param-reassign
            rffScore[pageContent] = 0;
          }
          // eslint-disable-next-line no-param-reassign
          rffScore[pageContent] += weight / (rank + this.c);
          rank += 1;
        }

        return rffScore;
      },
      {}
    );

    const uniqueDocs = this._uniqueUnion(docList.flat());
    const sortedDocs = Array.from(uniqueDocs).sort(
      (a, b) => rrfScoreDict[b.pageContent] - rrfScoreDict[a.pageContent]
    );

    return sortedDocs;
  }

  private _uniqueUnion(documents: Document[]): Document[] {
    const documentSet = new Set();
    const result = [];

    for (const doc of documents) {
      const key = doc.pageContent;
      if (!documentSet.has(key)) {
        documentSet.add(key);
        result.push(doc);
      }
    }

    return result;
  }
}
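
One detail of the constructor may be worth keeping in mind: because `c` is resolved with `args.c || 60`, a caller who passes `c: 0` gets the default of 60, since 0 is falsy in JavaScript. The sketch below contrasts that pattern with nullish coalescing. `resolveWithOr` and `resolveWithNullish` are hypothetical helpers written for this comparison, and whether `c = 0` ought to be accepted is a design question the source does not address.

```typescript
// Hypothetical helpers contrasting two ways of applying a default of 60.
function resolveWithOr(c?: number): number {
  return c || 60; // mirrors the `args.c || 60` pattern: 0 falls back to 60
}

function resolveWithNullish(c?: number): number {
  return c ?? 60; // only undefined (or null) falls back to 60
}

console.log(resolveWithOr(0), resolveWithNullish(0)); // 60 0
console.log(resolveWithOr(undefined), resolveWithNullish(undefined)); // 60 60
console.log(resolveWithOr(10), resolveWithNullish(10)); // 10 10
```

For any positive `c` the two forms agree, so the difference only matters if a caller deliberately sets the constant to zero.
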