From 294f60096a31df513933c6452d340f451449cfa0 Mon Sep 17 00:00:00 2001 From: santree Date: Wed, 5 Jun 2024 06:01:19 +0800 Subject: [PATCH] langchain[minor]: add EnsembleRetriever (#5556) * langchain[patch]: add support to merge retrievers * Format * Parallelize, lint, format, small fixes * Add entrypoint * Fix import * Add docs and fix build artifacts --------- Co-authored-by: jeasonnow Co-authored-by: jacoblee93 --- docs/core_docs/docs/concepts.mdx | 5 +- .../docs/how_to/ensemble_retriever.mdx | 29 +++++ docs/core_docs/docs/how_to/index.mdx | 1 + .../test-exports-bun/src/entrypoints.js | 1 + .../test-exports-cf/src/entrypoints.js | 1 + .../test-exports-cjs/src/entrypoints.js | 1 + .../test-exports-esbuild/src/entrypoints.js | 1 + .../test-exports-esm/src/entrypoints.js | 1 + .../test-exports-vercel/src/entrypoints.js | 1 + .../test-exports-vite/src/entrypoints.js | 1 + examples/src/retrievers/ensemble_retriever.ts | 67 ++++++++++ langchain/.gitignore | 4 + langchain/langchain.config.js | 1 + langchain/package.json | 13 ++ langchain/src/load/import_map.ts | 1 + langchain/src/retrievers/ensemble.ts | 119 ++++++++++++++++++ .../tests/ensemble_retriever.int.test.ts | 104 +++++++++++++++ 17 files changed, 349 insertions(+), 2 deletions(-) create mode 100644 docs/core_docs/docs/how_to/ensemble_retriever.mdx create mode 100644 examples/src/retrievers/ensemble_retriever.ts create mode 100644 langchain/src/retrievers/ensemble.ts create mode 100644 langchain/src/retrievers/tests/ensemble_retriever.int.test.ts diff --git a/docs/core_docs/docs/concepts.mdx b/docs/core_docs/docs/concepts.mdx index 1114e0c63e89..fd02e5f694d5 100644 --- a/docs/core_docs/docs/concepts.mdx +++ b/docs/core_docs/docs/concepts.mdx @@ -672,8 +672,9 @@ LangChain provides several advanced retrieval types. A full list is below, along | [Multi Vector](/docs/how_to/multi_vector/) | Vectorstore + Document Store | Sometimes during indexing | If you are able to extract information from documents that you think is more relevant to index than the text itself. | This involves creating multiple vectors for each document. Each vector could be created in a myriad of ways - examples include summaries of the text and hypothetical questions. | | [Self Query](/docs/how_to/self_query/) | Vectorstore | Yes | If users are asking questions that are better answered by fetching documents based on metadata rather than similarity with the text. | This uses an LLM to transform user input into two things: (1) a string to look up semantically, (2) a metadata filer to go along with it. This is useful because oftentimes questions are about the METADATA of documents (not the content itself). | | [Contextual Compression](/docs/how_to/contextual_compression/) | Any | Sometimes | If you are finding that your retrieved documents contain too much irrelevant information and are distracting the LLM. | This puts a post-processing step on top of another retriever and extracts only the most relevant information from retrieved documents. This can be done with embeddings or an LLM. 
| -| [Time-Weighted Vectorstore](/docs/how_to/time_weighted_vectorstore/) | Vectorstore | No | If you have timestamps associated with your documents, and you want to retrieve the most recent ones | This fetches documents based on a combination of semantic similarity (as in normal vector retrieval) and recency (looking at timestamps of indexed documents) | -| [Multi-Query Retriever](/docs/how_to/multiple_queries/) | Any | Yes | If users are asking questions that are complex and require multiple pieces of distinct information to respond | This uses an LLM to generate multiple queries from the original one. This is useful when the original query needs pieces of information about multiple topics to be properly answered. By generating multiple queries, we can then fetch documents for each of them. | +| [Time-Weighted Vectorstore](/docs/how_to/time_weighted_vectorstore/) | Vectorstore | No | If you have timestamps associated with your documents, and you want to retrieve the most recent ones. | This fetches documents based on a combination of semantic similarity (as in normal vector retrieval) and recency (looking at timestamps of indexed documents) | +| [Multi-Query Retriever](/docs/how_to/multiple_queries/) | Any | Yes | If users are asking questions that are complex and require multiple pieces of distinct information to respond. | This uses an LLM to generate multiple queries from the original one. This is useful when the original query needs pieces of information about multiple topics to be properly answered. By generating multiple queries, we can then fetch documents for each of them. | +| [Ensemble](/docs/how_to/ensemble_retriever) | Any | No | If you have multiple retrieval methods and want to try combining them. | This fetches documents from multiple retrievers and then combines them. | ### Text splitting diff --git a/docs/core_docs/docs/how_to/ensemble_retriever.mdx b/docs/core_docs/docs/how_to/ensemble_retriever.mdx new file mode 100644 index 000000000000..6eaf23871a33 --- /dev/null +++ b/docs/core_docs/docs/how_to/ensemble_retriever.mdx @@ -0,0 +1,29 @@ +# How to combine results from multiple retrievers + +:::info Prerequisites + +This guide assumes familiarity with the following concepts: + +- [Documents](/docs/concepts#document) +- [Retrievers](/docs/concepts#retrievers) + +::: + +The [EnsembleRetriever](https://api.js.langchain.com/classes/langchain_retrievers_ensemble.EnsembleRetriever.html) supports ensembling of results from multiple retrievers. It is initialized with a list of [BaseRetriever](https://api.js.langchain.com/classes/langchain_core_retrievers.BaseRetriever.html) objects. EnsembleRetrievers rerank the results of the constituent retrievers based on the [Reciprocal Rank Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) algorithm. + +By leveraging the strengths of different algorithms, the `EnsembleRetriever` can achieve better performance than any single algorithm. + +One useful pattern is to combine a keyword matching retriever with a dense retriever (like embedding similarity), because their strengths are complementary. This can be considered a form of "hybrid search". The sparse retriever is good at finding relevant documents based on keywords, while the dense retriever is good at finding relevant documents based on semantic similarity. 
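+
+Under the hood, Reciprocal Rank Fusion gives each document a score of `weight / (rank + c)` for every retriever that returns it, where `rank` is the document's 1-based position in that retriever's results and `c` is a constant (60 by default); the scores are summed and the documents are sorted by the total. The snippet below is a minimal sketch of that scoring with made-up document names and ranks, not a usage example of the retriever itself:
+
+```typescript
+// Toy weighted Reciprocal Rank Fusion: two retrievers with equal weights.
+const c = 60;
+const weights = [0.5, 0.5];
+const rankedResults = [
+  ["doc A", "doc B"], // results from retriever 1, best match first
+  ["doc B", "doc C"], // results from retriever 2, best match first
+];
+
+const scores: Record<string, number> = {};
+rankedResults.forEach((docs, i) => {
+  docs.forEach((doc, position) => {
+    // position is 0-based, so add 1 to get the 1-based rank used in the formula.
+    scores[doc] = (scores[doc] ?? 0) + weights[i] / (position + 1 + c);
+  });
+});
+
+// "doc B" scores highest because both retrievers returned it.
+console.log(scores);
+```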
+ +Below we demonstrate ensembling of a [simple custom retriever](/docs/how_to/custom_retriever/) that simply returns documents that directly contain the input query with a retriever derived from a [demo, in-memory, vector store](https://api.js.langchain.com/classes/langchain_vectorstores_memory.MemoryVectorStore.html). + +import CodeBlock from "@theme/CodeBlock"; +import Example from "@examples/retrievers/ensemble_retriever.ts"; + +{Example} + +## Next steps + +You've now learned how to combine results from multiple retrievers. +Next, check out some other retrieval how-to guides, such as how to [improve results using multiple embeddings per document](/docs/how_to/multi_vector) +or how to [create your own custom retriever](/docs/how_to/custom_retriever). diff --git a/docs/core_docs/docs/how_to/index.mdx b/docs/core_docs/docs/how_to/index.mdx index 969e99689af7..4ac9ef722ae8 100644 --- a/docs/core_docs/docs/how_to/index.mdx +++ b/docs/core_docs/docs/how_to/index.mdx @@ -133,6 +133,7 @@ Retrievers are responsible for taking a query and returning relevant documents. - [How to: generate multiple queries to retrieve data for](/docs/how_to/multiple_queries) - [How to: use contextual compression to compress the data retrieved](/docs/how_to/contextual_compression) - [How to: write a custom retriever class](/docs/how_to/custom_retriever) +- [How to: combine the results from multiple retrievers](/docs/how_to/ensemble_retriever) - [How to: generate multiple embeddings per document](/docs/how_to/multi_vector) - [How to: retrieve the whole document for a chunk](/docs/how_to/parent_document_retriever) - [How to: generate metadata filters](/docs/how_to/self_query) diff --git a/environment_tests/test-exports-bun/src/entrypoints.js b/environment_tests/test-exports-bun/src/entrypoints.js index 068747384a8d..0127a63d1c08 100644 --- a/environment_tests/test-exports-bun/src/entrypoints.js +++ b/environment_tests/test-exports-bun/src/entrypoints.js @@ -37,6 +37,7 @@ export * from "langchain/callbacks"; export * from "langchain/output_parsers"; export * from "langchain/retrievers/contextual_compression"; export * from "langchain/retrievers/document_compressors"; +export * from "langchain/retrievers/ensemble"; export * from "langchain/retrievers/multi_query"; export * from "langchain/retrievers/multi_vector"; export * from "langchain/retrievers/parent_document"; diff --git a/environment_tests/test-exports-cf/src/entrypoints.js b/environment_tests/test-exports-cf/src/entrypoints.js index 068747384a8d..0127a63d1c08 100644 --- a/environment_tests/test-exports-cf/src/entrypoints.js +++ b/environment_tests/test-exports-cf/src/entrypoints.js @@ -37,6 +37,7 @@ export * from "langchain/callbacks"; export * from "langchain/output_parsers"; export * from "langchain/retrievers/contextual_compression"; export * from "langchain/retrievers/document_compressors"; +export * from "langchain/retrievers/ensemble"; export * from "langchain/retrievers/multi_query"; export * from "langchain/retrievers/multi_vector"; export * from "langchain/retrievers/parent_document"; diff --git a/environment_tests/test-exports-cjs/src/entrypoints.js b/environment_tests/test-exports-cjs/src/entrypoints.js index d081d45f6aeb..5f9a19db39f2 100644 --- a/environment_tests/test-exports-cjs/src/entrypoints.js +++ b/environment_tests/test-exports-cjs/src/entrypoints.js @@ -37,6 +37,7 @@ const callbacks = require("langchain/callbacks"); const output_parsers = require("langchain/output_parsers"); const retrievers_contextual_compression = 
require("langchain/retrievers/contextual_compression"); const retrievers_document_compressors = require("langchain/retrievers/document_compressors"); +const retrievers_ensemble = require("langchain/retrievers/ensemble"); const retrievers_multi_query = require("langchain/retrievers/multi_query"); const retrievers_multi_vector = require("langchain/retrievers/multi_vector"); const retrievers_parent_document = require("langchain/retrievers/parent_document"); diff --git a/environment_tests/test-exports-esbuild/src/entrypoints.js b/environment_tests/test-exports-esbuild/src/entrypoints.js index 4b8bd265fff9..d3b76a743d8a 100644 --- a/environment_tests/test-exports-esbuild/src/entrypoints.js +++ b/environment_tests/test-exports-esbuild/src/entrypoints.js @@ -37,6 +37,7 @@ import * as callbacks from "langchain/callbacks"; import * as output_parsers from "langchain/output_parsers"; import * as retrievers_contextual_compression from "langchain/retrievers/contextual_compression"; import * as retrievers_document_compressors from "langchain/retrievers/document_compressors"; +import * as retrievers_ensemble from "langchain/retrievers/ensemble"; import * as retrievers_multi_query from "langchain/retrievers/multi_query"; import * as retrievers_multi_vector from "langchain/retrievers/multi_vector"; import * as retrievers_parent_document from "langchain/retrievers/parent_document"; diff --git a/environment_tests/test-exports-esm/src/entrypoints.js b/environment_tests/test-exports-esm/src/entrypoints.js index 4b8bd265fff9..d3b76a743d8a 100644 --- a/environment_tests/test-exports-esm/src/entrypoints.js +++ b/environment_tests/test-exports-esm/src/entrypoints.js @@ -37,6 +37,7 @@ import * as callbacks from "langchain/callbacks"; import * as output_parsers from "langchain/output_parsers"; import * as retrievers_contextual_compression from "langchain/retrievers/contextual_compression"; import * as retrievers_document_compressors from "langchain/retrievers/document_compressors"; +import * as retrievers_ensemble from "langchain/retrievers/ensemble"; import * as retrievers_multi_query from "langchain/retrievers/multi_query"; import * as retrievers_multi_vector from "langchain/retrievers/multi_vector"; import * as retrievers_parent_document from "langchain/retrievers/parent_document"; diff --git a/environment_tests/test-exports-vercel/src/entrypoints.js b/environment_tests/test-exports-vercel/src/entrypoints.js index 068747384a8d..0127a63d1c08 100644 --- a/environment_tests/test-exports-vercel/src/entrypoints.js +++ b/environment_tests/test-exports-vercel/src/entrypoints.js @@ -37,6 +37,7 @@ export * from "langchain/callbacks"; export * from "langchain/output_parsers"; export * from "langchain/retrievers/contextual_compression"; export * from "langchain/retrievers/document_compressors"; +export * from "langchain/retrievers/ensemble"; export * from "langchain/retrievers/multi_query"; export * from "langchain/retrievers/multi_vector"; export * from "langchain/retrievers/parent_document"; diff --git a/environment_tests/test-exports-vite/src/entrypoints.js b/environment_tests/test-exports-vite/src/entrypoints.js index 068747384a8d..0127a63d1c08 100644 --- a/environment_tests/test-exports-vite/src/entrypoints.js +++ b/environment_tests/test-exports-vite/src/entrypoints.js @@ -37,6 +37,7 @@ export * from "langchain/callbacks"; export * from "langchain/output_parsers"; export * from "langchain/retrievers/contextual_compression"; export * from "langchain/retrievers/document_compressors"; +export * from 
"langchain/retrievers/ensemble"; export * from "langchain/retrievers/multi_query"; export * from "langchain/retrievers/multi_vector"; export * from "langchain/retrievers/parent_document"; diff --git a/examples/src/retrievers/ensemble_retriever.ts b/examples/src/retrievers/ensemble_retriever.ts new file mode 100644 index 000000000000..e8fc3a15c874 --- /dev/null +++ b/examples/src/retrievers/ensemble_retriever.ts @@ -0,0 +1,67 @@ +import { EnsembleRetriever } from "langchain/retrievers/ensemble"; +import { MemoryVectorStore } from "langchain/vectorstores/memory"; +import { OpenAIEmbeddings } from "@langchain/openai"; +import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers"; +import { Document } from "@langchain/core/documents"; + +class SimpleCustomRetriever extends BaseRetriever { + lc_namespace = []; + + documents: Document[]; + + constructor(fields: { documents: Document[] } & BaseRetrieverInput) { + super(fields); + this.documents = fields.documents; + } + + async _getRelevantDocuments(query: string): Promise { + return this.documents.filter((document) => + document.pageContent.includes(query) + ); + } +} + +const docs1 = [ + new Document({ pageContent: "I like apples", metadata: { source: 1 } }), + new Document({ pageContent: "I like oranges", metadata: { source: 1 } }), + new Document({ + pageContent: "apples and oranges are fruits", + metadata: { source: 1 }, + }), +]; + +const keywordRetriever = new SimpleCustomRetriever({ documents: docs1 }); + +const docs2 = [ + new Document({ pageContent: "You like apples", metadata: { source: 2 } }), + new Document({ pageContent: "You like oranges", metadata: { source: 2 } }), +]; + +const vectorstore = await MemoryVectorStore.fromDocuments( + docs2, + new OpenAIEmbeddings() +); + +const vectorstoreRetriever = vectorstore.asRetriever(); + +const retriever = new EnsembleRetriever({ + retrievers: [vectorstoreRetriever, keywordRetriever], + weights: [0.5, 0.5], +}); + +const query = "apples"; +const retrievedDocs = await retriever.invoke(query); + +console.log(retrievedDocs); + +/* + [ + Document { pageContent: 'You like apples', metadata: { source: 2 } }, + Document { pageContent: 'I like apples', metadata: { source: 1 } }, + Document { pageContent: 'You like oranges', metadata: { source: 2 } }, + Document { + pageContent: 'apples and oranges are fruits', + metadata: { source: 1 } + } + ] +*/ diff --git a/langchain/.gitignore b/langchain/.gitignore index 694b69c6f9d8..768344a3b89f 100644 --- a/langchain/.gitignore +++ b/langchain/.gitignore @@ -358,6 +358,10 @@ retrievers/document_compressors.cjs retrievers/document_compressors.js retrievers/document_compressors.d.ts retrievers/document_compressors.d.cts +retrievers/ensemble.cjs +retrievers/ensemble.js +retrievers/ensemble.d.ts +retrievers/ensemble.d.cts retrievers/multi_query.cjs retrievers/multi_query.js retrievers/multi_query.d.ts diff --git a/langchain/langchain.config.js b/langchain/langchain.config.js index 1cd47cf5d1e8..5700b59dd7bc 100644 --- a/langchain/langchain.config.js +++ b/langchain/langchain.config.js @@ -143,6 +143,7 @@ export const config = { // retrievers "retrievers/contextual_compression": "retrievers/contextual_compression", "retrievers/document_compressors": "retrievers/document_compressors/index", + "retrievers/ensemble": "retrievers/ensemble", "retrievers/multi_query": "retrievers/multi_query", "retrievers/multi_vector": "retrievers/multi_vector", "retrievers/parent_document": "retrievers/parent_document", diff --git a/langchain/package.json 
b/langchain/package.json index 5510d32f70ed..fc8ab7ce097b 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -370,6 +370,10 @@ "retrievers/document_compressors.js", "retrievers/document_compressors.d.ts", "retrievers/document_compressors.d.cts", + "retrievers/ensemble.cjs", + "retrievers/ensemble.js", + "retrievers/ensemble.d.ts", + "retrievers/ensemble.d.cts", "retrievers/multi_query.cjs", "retrievers/multi_query.js", "retrievers/multi_query.d.ts", @@ -1725,6 +1729,15 @@ "import": "./retrievers/document_compressors.js", "require": "./retrievers/document_compressors.cjs" }, + "./retrievers/ensemble": { + "types": { + "import": "./retrievers/ensemble.d.ts", + "require": "./retrievers/ensemble.d.cts", + "default": "./retrievers/ensemble.d.ts" + }, + "import": "./retrievers/ensemble.js", + "require": "./retrievers/ensemble.cjs" + }, "./retrievers/multi_query": { "types": { "import": "./retrievers/multi_query.d.ts", diff --git a/langchain/src/load/import_map.ts b/langchain/src/load/import_map.ts index 2ac351c8fdba..115793fcac92 100644 --- a/langchain/src/load/import_map.ts +++ b/langchain/src/load/import_map.ts @@ -33,6 +33,7 @@ export * as callbacks from "../callbacks/index.js"; export * as output_parsers from "../output_parsers/index.js"; export * as retrievers__contextual_compression from "../retrievers/contextual_compression.js"; export * as retrievers__document_compressors from "../retrievers/document_compressors/index.js"; +export * as retrievers__ensemble from "../retrievers/ensemble.js"; export * as retrievers__multi_query from "../retrievers/multi_query.js"; export * as retrievers__multi_vector from "../retrievers/multi_vector.js"; export * as retrievers__parent_document from "../retrievers/parent_document.js"; diff --git a/langchain/src/retrievers/ensemble.ts b/langchain/src/retrievers/ensemble.ts new file mode 100644 index 000000000000..606b1cf79c22 --- /dev/null +++ b/langchain/src/retrievers/ensemble.ts @@ -0,0 +1,119 @@ +import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers"; +import { Document, DocumentInterface } from "@langchain/core/documents"; +import { CallbackManagerForRetrieverRun } from "@langchain/core/callbacks/manager"; + +export interface EnsembleRetrieverInput extends BaseRetrieverInput { + /** A list of retrievers to ensemble. */ + retrievers: BaseRetriever[]; + /** + * A list of weights corresponding to the retrievers. Defaults to equal + * weighting for all retrievers. + */ + weights?: number[]; + /** + * A constant added to the rank, controlling the balance between the importance + * of high-ranked items and the consideration given to lower-ranked items. + * Default is 60. + */ + c?: number; +} + +/** + * Ensemble retriever that aggregates and orders the results of + * multiple retrievers by using weighted Reciprocal Rank Fusion. 
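+ *
+ * A minimal usage sketch (assumes `vectorstoreRetriever` and `keywordRetriever`
+ * are retrievers you have already constructed, as in the docs example):
+ * @example
+ * ```typescript
+ * const retriever = new EnsembleRetriever({
+ *   retrievers: [vectorstoreRetriever, keywordRetriever],
+ *   weights: [0.5, 0.5],
+ * });
+ *
+ * const results = await retriever.invoke("apples");
+ * ```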
+ */
+export class EnsembleRetriever extends BaseRetriever {
+  static lc_name() {
+    return "EnsembleRetriever";
+  }
+
+  lc_namespace = ["langchain", "retrievers", "ensemble_retriever"];
+
+  retrievers: BaseRetriever[];
+
+  weights: number[];
+
+  c = 60;
+
+  constructor(args: EnsembleRetrieverInput) {
+    super(args);
+    this.retrievers = args.retrievers;
+    this.weights =
+      args.weights ||
+      new Array(args.retrievers.length).fill(1 / args.retrievers.length);
+    this.c = args.c || 60;
+  }
+
+  async _getRelevantDocuments(
+    query: string,
+    runManager?: CallbackManagerForRetrieverRun
+  ) {
+    return this._rankFusion(query, runManager);
+  }
+
+  async _rankFusion(
+    query: string,
+    runManager?: CallbackManagerForRetrieverRun
+  ) {
+    const retrieverDocs = await Promise.all(
+      this.retrievers.map((retriever, i) =>
+        retriever.invoke(query, {
+          callbacks: runManager?.getChild(`retriever_${i + 1}`),
+        })
+      )
+    );
+
+    const fusedDocs = await this._weightedReciprocalRank(retrieverDocs);
+    return fusedDocs;
+  }
+
+  async _weightedReciprocalRank(docList: DocumentInterface[][]) {
+    if (docList.length !== this.weights.length) {
+      throw new Error(
+        "Number of retrieved document lists must be equal to the number of weights."
+      );
+    }
+
+    const rrfScoreDict = docList.reduce(
+      (rffScore: Record<string, number>, retrieverDoc, idx) => {
+        let rank = 1;
+        const weight = this.weights[idx];
+        while (rank <= retrieverDoc.length) {
+          const { pageContent } = retrieverDoc[rank - 1];
+          if (!rffScore[pageContent]) {
+            // eslint-disable-next-line no-param-reassign
+            rffScore[pageContent] = 0;
+          }
+          // eslint-disable-next-line no-param-reassign
+          rffScore[pageContent] += weight / (rank + this.c);
+          rank += 1;
+        }
+
+        return rffScore;
+      },
+      {}
+    );
+
+    const uniqueDocs = this._uniqueUnion(docList.flat());
+    const sortedDocs = Array.from(uniqueDocs).sort(
+      (a, b) => rrfScoreDict[b.pageContent] - rrfScoreDict[a.pageContent]
+    );
+
+    return sortedDocs;
+  }
+
+  private _uniqueUnion(documents: Document[]): Document[] {
+    const documentSet = new Set();
+    const result = [];
+
+    for (const doc of documents) {
+      const key = doc.pageContent;
+      if (!documentSet.has(key)) {
+        documentSet.add(key);
+        result.push(doc);
+      }
+    }
+
+    return result;
+  }
+}
diff --git a/langchain/src/retrievers/tests/ensemble_retriever.int.test.ts b/langchain/src/retrievers/tests/ensemble_retriever.int.test.ts
new file mode 100644
index 000000000000..781dfb859321
--- /dev/null
+++ b/langchain/src/retrievers/tests/ensemble_retriever.int.test.ts
@@ -0,0 +1,104 @@
+import { expect, test } from "@jest/globals";
+import { CohereEmbeddings } from "@langchain/cohere";
+import { MemoryVectorStore } from "../../vectorstores/memory.js";
+import { EnsembleRetriever } from "../ensemble.js";
+
+test("Should work with a question input", async () => {
+  const vectorstore = await MemoryVectorStore.fromTexts(
+    [
+      "Buildings are made out of brick",
+      "Buildings are made out of wood",
+      "Buildings are made out of stone",
+      "Cars are made out of metal",
+      "Cars are made out of plastic",
+      "mitochondria is the powerhouse of the cell",
+      "mitochondria is made of lipids",
+    ],
+    [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
+    new CohereEmbeddings()
+  );
+  const retriever = new EnsembleRetriever({
+    retrievers: [vectorstore.asRetriever()],
+  });
+
+  const query = "What are mitochondria made of?";
+  const retrievedDocs = await retriever.invoke(query);
+  expect(retrievedDocs[0].pageContent).toContain("mitochondria");
+});
+
+test("Should work with multiple retrievers", async () => {
const vectorstore = await MemoryVectorStore.fromTexts( + [ + "Buildings are made out of brick", + "Buildings are made out of wood", + "Buildings are made out of stone", + "Cars are made out of metal", + "Cars are made out of plastic", + "mitochondria is the powerhouse of the cell", + "mitochondria is made of lipids", + ], + [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], + new CohereEmbeddings() + ); + const vectorstore2 = await MemoryVectorStore.fromTexts( + [ + "Buildings are made out of brick", + "Buildings are made out of wood", + "Buildings are made out of stone", + "Cars are made out of metal", + "Cars are made out of plastic", + "mitochondria is the powerhouse of the cell", + "mitochondria is made of lipids", + ], + [{ id: 6 }, { id: 7 }, { id: 8 }, { id: 9 }, { id: 10 }], + new CohereEmbeddings() + ); + const retriever = new EnsembleRetriever({ + retrievers: [vectorstore.asRetriever(), vectorstore2.asRetriever()], + }); + + const query = "cars"; + const retrievedDocs = await retriever.invoke(query); + expect( + retrievedDocs.filter((item) => item.pageContent.includes("Cars")).length + ).toBe(2); +}); + +test("Should work with weights", async () => { + const vectorstore = await MemoryVectorStore.fromTexts( + [ + "Buildings are made out of brick", + "Buildings are made out of wood", + "Buildings are made out of stone", + "Cars are made out of metal", + "Cars are made out of plastic", + "mitochondria is the powerhouse of the cell", + "mitochondria is made of lipids", + ], + [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], + new CohereEmbeddings() + ); + const vectorstore2 = await MemoryVectorStore.fromTexts( + [ + "Buildings are made out of brick", + "Buildings are made out of wood", + "Buildings are made out of stone", + "Cars are made out of metal", + "Cars are made out of plastic", + "mitochondria is the powerhouse of the cell", + "mitochondria is made of lipids", + ], + [{ id: 6 }, { id: 7 }, { id: 8 }, { id: 9 }, { id: 10 }], + new CohereEmbeddings() + ); + const retriever = new EnsembleRetriever({ + retrievers: [vectorstore.asRetriever(), vectorstore2.asRetriever()], + weights: [0.5, 0.9], + }); + + const query = "cars"; + const retrievedDocs = await retriever.invoke(query); + expect( + retrievedDocs.filter((item) => item.pageContent.includes("Cars")).length + ).toBe(2); +});
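+
+// A minimal additional check (assumes the same Cohere credentials as the tests
+// above): a custom `c` constant should be accepted and still surface the most
+// relevant document.
+test("Should work with a custom c value", async () => {
+  const vectorstore = await MemoryVectorStore.fromTexts(
+    [
+      "mitochondria is the powerhouse of the cell",
+      "Buildings are made out of brick",
+    ],
+    [{ id: 1 }, { id: 2 }],
+    new CohereEmbeddings()
+  );
+  const retriever = new EnsembleRetriever({
+    retrievers: [vectorstore.asRetriever()],
+    c: 10,
+  });
+
+  const retrievedDocs = await retriever.invoke(
+    "What are mitochondria made of?"
+  );
+  expect(retrievedDocs[0].pageContent).toContain("mitochondria");
+});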