Skip to content

Commit

Permalink
langchain[minor]: add EnsembleRetriever (#5556)
Browse files Browse the repository at this point in the history
* langchain[patch]: add support to merge retrievers

* Format

* Parallelize, lint, format, small fixes

* Add entrypoint

* Fix import

* Add docs and fix build artifacts

---------

Co-authored-by: jeasonnow <[email protected]>
Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
3 people authored Jun 4, 2024
1 parent d35d12d commit 294f600
Show file tree
Hide file tree
Showing 17 changed files with 349 additions and 2 deletions.
5 changes: 3 additions & 2 deletions docs/core_docs/docs/concepts.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -672,8 +672,9 @@ LangChain provides several advanced retrieval types. A full list is below, along
| [Multi Vector](/docs/how_to/multi_vector/) | Vectorstore + Document Store | Sometimes during indexing | If you are able to extract information from documents that you think is more relevant to index than the text itself. | This involves creating multiple vectors for each document. Each vector could be created in a myriad of ways - examples include summaries of the text and hypothetical questions. |
| [Self Query](/docs/how_to/self_query/) | Vectorstore | Yes | If users are asking questions that are better answered by fetching documents based on metadata rather than similarity with the text. | This uses an LLM to transform user input into two things: (1) a string to look up semantically, (2) a metadata filer to go along with it. This is useful because oftentimes questions are about the METADATA of documents (not the content itself). |
| [Contextual Compression](/docs/how_to/contextual_compression/) | Any | Sometimes | If you are finding that your retrieved documents contain too much irrelevant information and are distracting the LLM. | This puts a post-processing step on top of another retriever and extracts only the most relevant information from retrieved documents. This can be done with embeddings or an LLM. |
| [Time-Weighted Vectorstore](/docs/how_to/time_weighted_vectorstore/) | Vectorstore | No | If you have timestamps associated with your documents, and you want to retrieve the most recent ones | This fetches documents based on a combination of semantic similarity (as in normal vector retrieval) and recency (looking at timestamps of indexed documents) |
| [Multi-Query Retriever](/docs/how_to/multiple_queries/) | Any | Yes | If users are asking questions that are complex and require multiple pieces of distinct information to respond | This uses an LLM to generate multiple queries from the original one. This is useful when the original query needs pieces of information about multiple topics to be properly answered. By generating multiple queries, we can then fetch documents for each of them. |
| [Time-Weighted Vectorstore](/docs/how_to/time_weighted_vectorstore/) | Vectorstore | No | If you have timestamps associated with your documents, and you want to retrieve the most recent ones. | This fetches documents based on a combination of semantic similarity (as in normal vector retrieval) and recency (looking at timestamps of indexed documents) |
| [Multi-Query Retriever](/docs/how_to/multiple_queries/) | Any | Yes | If users are asking questions that are complex and require multiple pieces of distinct information to respond. | This uses an LLM to generate multiple queries from the original one. This is useful when the original query needs pieces of information about multiple topics to be properly answered. By generating multiple queries, we can then fetch documents for each of them. |
| [Ensemble](/docs/how_to/ensemble_retriever) | Any | No | If you have multiple retrieval methods and want to try combining them. | This fetches documents from multiple retrievers and then combines them. |

### Text splitting

Expand Down
29 changes: 29 additions & 0 deletions docs/core_docs/docs/how_to/ensemble_retriever.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# How to combine results from multiple retrievers

:::info Prerequisites

This guide assumes familiarity with the following concepts:

- [Documents](/docs/concepts#document)
- [Retrievers](/docs/concepts#retrievers)

:::

The [EnsembleRetriever](https://api.js.langchain.com/classes/langchain_retrievers_ensemble.EnsembleRetriever.html) supports ensembling of results from multiple retrievers. It is initialized with a list of [BaseRetriever](https://api.js.langchain.com/classes/langchain_core_retrievers.BaseRetriever.html) objects. EnsembleRetrievers rerank the results of the constituent retrievers based on the [Reciprocal Rank Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) algorithm.

By leveraging the strengths of different algorithms, the `EnsembleRetriever` can achieve better performance than any single algorithm.

One useful pattern is to combine a keyword matching retriever with a dense retriever (like embedding similarity), because their strengths are complementary. This can be considered a form of "hybrid search". The sparse retriever is good at finding relevant documents based on keywords, while the dense retriever is good at finding relevant documents based on semantic similarity.

Below we demonstrate ensembling of a [simple custom retriever](/docs/how_to/custom_retriever/) that simply returns documents that directly contain the input query with a retriever derived from a [demo, in-memory, vector store](https://api.js.langchain.com/classes/langchain_vectorstores_memory.MemoryVectorStore.html).

import CodeBlock from "@theme/CodeBlock";
import Example from "@examples/retrievers/ensemble_retriever.ts";

<CodeBlock language="typescript">{Example}</CodeBlock>

## Next steps

You've now learned how to combine results from multiple retrievers.
Next, check out some other retrieval how-to guides, such as how to [improve results using multiple embeddings per document](/docs/how_to/multi_vector)
or how to [create your own custom retriever](/docs/how_to/custom_retriever).
1 change: 1 addition & 0 deletions docs/core_docs/docs/how_to/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ Retrievers are responsible for taking a query and returning relevant documents.
- [How to: generate multiple queries to retrieve data for](/docs/how_to/multiple_queries)
- [How to: use contextual compression to compress the data retrieved](/docs/how_to/contextual_compression)
- [How to: write a custom retriever class](/docs/how_to/custom_retriever)
- [How to: combine the results from multiple retrievers](/docs/how_to/ensemble_retriever)
- [How to: generate multiple embeddings per document](/docs/how_to/multi_vector)
- [How to: retrieve the whole document for a chunk](/docs/how_to/parent_document_retriever)
- [How to: generate metadata filters](/docs/how_to/self_query)
Expand Down
1 change: 1 addition & 0 deletions environment_tests/test-exports-bun/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export * from "langchain/callbacks";
export * from "langchain/output_parsers";
export * from "langchain/retrievers/contextual_compression";
export * from "langchain/retrievers/document_compressors";
export * from "langchain/retrievers/ensemble";
export * from "langchain/retrievers/multi_query";
export * from "langchain/retrievers/multi_vector";
export * from "langchain/retrievers/parent_document";
Expand Down
1 change: 1 addition & 0 deletions environment_tests/test-exports-cf/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export * from "langchain/callbacks";
export * from "langchain/output_parsers";
export * from "langchain/retrievers/contextual_compression";
export * from "langchain/retrievers/document_compressors";
export * from "langchain/retrievers/ensemble";
export * from "langchain/retrievers/multi_query";
export * from "langchain/retrievers/multi_vector";
export * from "langchain/retrievers/parent_document";
Expand Down
1 change: 1 addition & 0 deletions environment_tests/test-exports-cjs/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const callbacks = require("langchain/callbacks");
const output_parsers = require("langchain/output_parsers");
const retrievers_contextual_compression = require("langchain/retrievers/contextual_compression");
const retrievers_document_compressors = require("langchain/retrievers/document_compressors");
const retrievers_ensemble = require("langchain/retrievers/ensemble");
const retrievers_multi_query = require("langchain/retrievers/multi_query");
const retrievers_multi_vector = require("langchain/retrievers/multi_vector");
const retrievers_parent_document = require("langchain/retrievers/parent_document");
Expand Down
1 change: 1 addition & 0 deletions environment_tests/test-exports-esbuild/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import * as callbacks from "langchain/callbacks";
import * as output_parsers from "langchain/output_parsers";
import * as retrievers_contextual_compression from "langchain/retrievers/contextual_compression";
import * as retrievers_document_compressors from "langchain/retrievers/document_compressors";
import * as retrievers_ensemble from "langchain/retrievers/ensemble";
import * as retrievers_multi_query from "langchain/retrievers/multi_query";
import * as retrievers_multi_vector from "langchain/retrievers/multi_vector";
import * as retrievers_parent_document from "langchain/retrievers/parent_document";
Expand Down
1 change: 1 addition & 0 deletions environment_tests/test-exports-esm/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import * as callbacks from "langchain/callbacks";
import * as output_parsers from "langchain/output_parsers";
import * as retrievers_contextual_compression from "langchain/retrievers/contextual_compression";
import * as retrievers_document_compressors from "langchain/retrievers/document_compressors";
import * as retrievers_ensemble from "langchain/retrievers/ensemble";
import * as retrievers_multi_query from "langchain/retrievers/multi_query";
import * as retrievers_multi_vector from "langchain/retrievers/multi_vector";
import * as retrievers_parent_document from "langchain/retrievers/parent_document";
Expand Down
1 change: 1 addition & 0 deletions environment_tests/test-exports-vercel/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export * from "langchain/callbacks";
export * from "langchain/output_parsers";
export * from "langchain/retrievers/contextual_compression";
export * from "langchain/retrievers/document_compressors";
export * from "langchain/retrievers/ensemble";
export * from "langchain/retrievers/multi_query";
export * from "langchain/retrievers/multi_vector";
export * from "langchain/retrievers/parent_document";
Expand Down
1 change: 1 addition & 0 deletions environment_tests/test-exports-vite/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export * from "langchain/callbacks";
export * from "langchain/output_parsers";
export * from "langchain/retrievers/contextual_compression";
export * from "langchain/retrievers/document_compressors";
export * from "langchain/retrievers/ensemble";
export * from "langchain/retrievers/multi_query";
export * from "langchain/retrievers/multi_vector";
export * from "langchain/retrievers/parent_document";
Expand Down
67 changes: 67 additions & 0 deletions examples/src/retrievers/ensemble_retriever.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import { EnsembleRetriever } from "langchain/retrievers/ensemble";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { OpenAIEmbeddings } from "@langchain/openai";
import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers";
import { Document } from "@langchain/core/documents";

class SimpleCustomRetriever extends BaseRetriever {
lc_namespace = [];

documents: Document[];

constructor(fields: { documents: Document[] } & BaseRetrieverInput) {
super(fields);
this.documents = fields.documents;
}

async _getRelevantDocuments(query: string): Promise<Document[]> {
return this.documents.filter((document) =>
document.pageContent.includes(query)
);
}
}

const docs1 = [
new Document({ pageContent: "I like apples", metadata: { source: 1 } }),
new Document({ pageContent: "I like oranges", metadata: { source: 1 } }),
new Document({
pageContent: "apples and oranges are fruits",
metadata: { source: 1 },
}),
];

const keywordRetriever = new SimpleCustomRetriever({ documents: docs1 });

const docs2 = [
new Document({ pageContent: "You like apples", metadata: { source: 2 } }),
new Document({ pageContent: "You like oranges", metadata: { source: 2 } }),
];

const vectorstore = await MemoryVectorStore.fromDocuments(
docs2,
new OpenAIEmbeddings()
);

const vectorstoreRetriever = vectorstore.asRetriever();

const retriever = new EnsembleRetriever({
retrievers: [vectorstoreRetriever, keywordRetriever],
weights: [0.5, 0.5],
});

const query = "apples";
const retrievedDocs = await retriever.invoke(query);

console.log(retrievedDocs);

/*
[
Document { pageContent: 'You like apples', metadata: { source: 2 } },
Document { pageContent: 'I like apples', metadata: { source: 1 } },
Document { pageContent: 'You like oranges', metadata: { source: 2 } },
Document {
pageContent: 'apples and oranges are fruits',
metadata: { source: 1 }
}
]
*/
4 changes: 4 additions & 0 deletions langchain/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ retrievers/document_compressors.cjs
retrievers/document_compressors.js
retrievers/document_compressors.d.ts
retrievers/document_compressors.d.cts
retrievers/ensemble.cjs
retrievers/ensemble.js
retrievers/ensemble.d.ts
retrievers/ensemble.d.cts
retrievers/multi_query.cjs
retrievers/multi_query.js
retrievers/multi_query.d.ts
Expand Down
1 change: 1 addition & 0 deletions langchain/langchain.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ export const config = {
// retrievers
"retrievers/contextual_compression": "retrievers/contextual_compression",
"retrievers/document_compressors": "retrievers/document_compressors/index",
"retrievers/ensemble": "retrievers/ensemble",
"retrievers/multi_query": "retrievers/multi_query",
"retrievers/multi_vector": "retrievers/multi_vector",
"retrievers/parent_document": "retrievers/parent_document",
Expand Down
13 changes: 13 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,10 @@
"retrievers/document_compressors.js",
"retrievers/document_compressors.d.ts",
"retrievers/document_compressors.d.cts",
"retrievers/ensemble.cjs",
"retrievers/ensemble.js",
"retrievers/ensemble.d.ts",
"retrievers/ensemble.d.cts",
"retrievers/multi_query.cjs",
"retrievers/multi_query.js",
"retrievers/multi_query.d.ts",
Expand Down Expand Up @@ -1725,6 +1729,15 @@
"import": "./retrievers/document_compressors.js",
"require": "./retrievers/document_compressors.cjs"
},
"./retrievers/ensemble": {
"types": {
"import": "./retrievers/ensemble.d.ts",
"require": "./retrievers/ensemble.d.cts",
"default": "./retrievers/ensemble.d.ts"
},
"import": "./retrievers/ensemble.js",
"require": "./retrievers/ensemble.cjs"
},
"./retrievers/multi_query": {
"types": {
"import": "./retrievers/multi_query.d.ts",
Expand Down
1 change: 1 addition & 0 deletions langchain/src/load/import_map.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export * as callbacks from "../callbacks/index.js";
export * as output_parsers from "../output_parsers/index.js";
export * as retrievers__contextual_compression from "../retrievers/contextual_compression.js";
export * as retrievers__document_compressors from "../retrievers/document_compressors/index.js";
export * as retrievers__ensemble from "../retrievers/ensemble.js";
export * as retrievers__multi_query from "../retrievers/multi_query.js";
export * as retrievers__multi_vector from "../retrievers/multi_vector.js";
export * as retrievers__parent_document from "../retrievers/parent_document.js";
Expand Down
119 changes: 119 additions & 0 deletions langchain/src/retrievers/ensemble.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers";
import { Document, DocumentInterface } from "@langchain/core/documents";
import { CallbackManagerForRetrieverRun } from "@langchain/core/callbacks/manager";

export interface EnsembleRetrieverInput extends BaseRetrieverInput {
/** A list of retrievers to ensemble. */
retrievers: BaseRetriever[];
/**
* A list of weights corresponding to the retrievers. Defaults to equal
* weighting for all retrievers.
*/
weights?: number[];
/**
* A constant added to the rank, controlling the balance between the importance
* of high-ranked items and the consideration given to lower-ranked items.
* Default is 60.
*/
c?: number;
}

/**
* Ensemble retriever that aggregates and orders the results of
* multiple retrievers by using weighted Reciprocal Rank Fusion.
*/
export class EnsembleRetriever extends BaseRetriever {
static lc_name() {
return "EnsembleRetriever";
}

lc_namespace = ["langchain", "retrievers", "ensemble_retriever"];

retrievers: BaseRetriever[];

weights: number[];

c = 60;

constructor(args: EnsembleRetrieverInput) {
super(args);
this.retrievers = args.retrievers;
this.weights =
args.weights ||
new Array(args.retrievers.length).fill(1 / args.retrievers.length);
this.c = args.c || 60;
}

async _getRelevantDocuments(
query: string,
runManager?: CallbackManagerForRetrieverRun
) {
return this._rankFusion(query, runManager);
}

async _rankFusion(
query: string,
runManager?: CallbackManagerForRetrieverRun
) {
const retrieverDocs = await Promise.all(
this.retrievers.map((retriever, i) =>
retriever.invoke(query, {
callbacks: runManager?.getChild(`retriever_${i + 1}`),
})
)
);

const fusedDocs = await this._weightedReciprocalRank(retrieverDocs);
return fusedDocs;
}

async _weightedReciprocalRank(docList: DocumentInterface[][]) {
if (docList.length !== this.weights.length) {
throw new Error(
"Number of retrieved document lists must be equal to the number of weights."
);
}

const rrfScoreDict = docList.reduce(
(rffScore: Record<string, number>, retrieverDoc, idx) => {
let rank = 1;
const weight = this.weights[idx];
while (rank <= retrieverDoc.length) {
const { pageContent } = retrieverDoc[rank - 1];
if (!rffScore[pageContent]) {
// eslint-disable-next-line no-param-reassign
rffScore[pageContent] = 0;
}
// eslint-disable-next-line no-param-reassign
rffScore[pageContent] += weight / (rank + this.c);
rank += 1;
}

return rffScore;
},
{}
);

const uniqueDocs = this._uniqueUnion(docList.flat());
const sortedDocs = Array.from(uniqueDocs).sort(
(a, b) => rrfScoreDict[b.pageContent] - rrfScoreDict[a.pageContent]
);

return sortedDocs;
}

private _uniqueUnion(documents: Document[]): Document[] {
const documentSet = new Set();
const result = [];

for (const doc of documents) {
const key = doc.pageContent;
if (!documentSet.has(key)) {
documentSet.add(key);
result.push(doc);
}
}

return result;
}
}
Loading

0 comments on commit 294f600

Please sign in to comment.