diff --git a/examples/src/document_compressors/cohere_rerank.ts b/examples/src/document_compressors/cohere_rerank.ts index 45443b728b57..21013dc65622 100644 --- a/examples/src/document_compressors/cohere_rerank.ts +++ b/examples/src/document_compressors/cohere_rerank.ts @@ -27,7 +27,7 @@ const docs = [ const cohereRerank = new CohereRerank({ apiKey: process.env.COHERE_API_KEY, // Default - model: "rerank-english-v2.0", // Default + model: "rerank-english-v2.0", }); const rerankedDocuments = await cohereRerank.rerank(docs, query, { diff --git a/examples/src/document_compressors/cohere_rerank_compressor.ts b/examples/src/document_compressors/cohere_rerank_compressor.ts index 941e28f32e91..8ff678cf79a4 100644 --- a/examples/src/document_compressors/cohere_rerank_compressor.ts +++ b/examples/src/document_compressors/cohere_rerank_compressor.ts @@ -28,7 +28,7 @@ const docs = [ const cohereRerank = new CohereRerank({ apiKey: process.env.COHERE_API_KEY, // Default topN: 3, // Default - model: "rerank-english-v2.0", // Default + model: "rerank-english-v2.0", }); const rerankedDocuments = await cohereRerank.compressDocuments(docs, query); diff --git a/examples/src/embeddings/cohere.ts b/examples/src/embeddings/cohere.ts index c17c4ef1bbb4..258afd41979c 100644 --- a/examples/src/embeddings/cohere.ts +++ b/examples/src/embeddings/cohere.ts @@ -1,7 +1,7 @@ import { CohereEmbeddings } from "@langchain/cohere"; export const run = async () => { - const model = new CohereEmbeddings(); + const model = new CohereEmbeddings({ model: "embed-english-v3.0" }); const res = await model.embedQuery( "What would be a good company name a company that makes colorful socks?" ); diff --git a/examples/src/guides/expression_language/runnable_maps_sequence.ts b/examples/src/guides/expression_language/runnable_maps_sequence.ts index de783221ea77..1f66d453c457 100644 --- a/examples/src/guides/expression_language/runnable_maps_sequence.ts +++ b/examples/src/guides/expression_language/runnable_maps_sequence.ts @@ -12,7 +12,7 @@ import { MemoryVectorStore } from "langchain/vectorstores/memory"; const model = new ChatAnthropic(); const vectorstore = await MemoryVectorStore.fromDocuments( [{ pageContent: "mitochondria is the powerhouse of the cell", metadata: {} }], - new CohereEmbeddings() + new CohereEmbeddings({ model: "embed-english-v3.0" }) ); const retriever = vectorstore.asRetriever(); const template = `Answer the question based only on the following context: diff --git a/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts b/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts index 1ba5042fc68f..8770f6949494 100755 --- a/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts +++ b/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts @@ -10,7 +10,7 @@ const collection = client.db(dbName).collection(collectionName); const vectorstore = await MongoDBAtlasVectorSearch.fromTexts( ["Hello world", "Bye bye", "What's this?"], [{ id: 2 }, { id: 1 }, { id: 3 }], - new CohereEmbeddings(), + new CohereEmbeddings({ model: "embed-english-v3.0" }), { collection, indexName: "default", // The name of the Atlas search index. Defaults to "default" diff --git a/examples/src/indexes/vector_stores/mongodb_atlas_search.ts b/examples/src/indexes/vector_stores/mongodb_atlas_search.ts index 47da29b8b77e..f714b02bf5aa 100755 --- a/examples/src/indexes/vector_stores/mongodb_atlas_search.ts +++ b/examples/src/indexes/vector_stores/mongodb_atlas_search.ts @@ -7,12 +7,15 @@ const namespace = "langchain.test"; const [dbName, collectionName] = namespace.split("."); const collection = client.db(dbName).collection(collectionName); -const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), { - collection, - indexName: "default", // The name of the Atlas search index. Defaults to "default" - textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" - embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" -}); +const vectorStore = new MongoDBAtlasVectorSearch( + new CohereEmbeddings({ model: "embed-english-v3.0" }), + { + collection, + indexName: "default", // The name of the Atlas search index. Defaults to "default" + textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" + embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" + } +); const resultOne = await vectorStore.similaritySearch("Hello world", 1); console.log(resultOne); diff --git a/examples/src/indexes/vector_stores/mongodb_metadata_filtering.ts b/examples/src/indexes/vector_stores/mongodb_metadata_filtering.ts index 1c76da52c8cc..66e0a63c82e1 100644 --- a/examples/src/indexes/vector_stores/mongodb_metadata_filtering.ts +++ b/examples/src/indexes/vector_stores/mongodb_metadata_filtering.ts @@ -9,12 +9,15 @@ const namespace = "langchain.test"; const [dbName, collectionName] = namespace.split("."); const collection = client.db(dbName).collection(collectionName); -const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), { - collection, - indexName: "default", // The name of the Atlas search index. Defaults to "default" - textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" - embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" -}); +const vectorStore = new MongoDBAtlasVectorSearch( + new CohereEmbeddings({ model: "embed-english-v3.0" }), + { + collection, + indexName: "default", // The name of the Atlas search index. Defaults to "default" + textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" + embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" + } +); await vectorStore.addDocuments([ { diff --git a/examples/src/indexes/vector_stores/mongodb_mmr.ts b/examples/src/indexes/vector_stores/mongodb_mmr.ts index 6dada1c6779c..2419b84651d0 100644 --- a/examples/src/indexes/vector_stores/mongodb_mmr.ts +++ b/examples/src/indexes/vector_stores/mongodb_mmr.ts @@ -7,12 +7,15 @@ const namespace = "langchain.test"; const [dbName, collectionName] = namespace.split("."); const collection = client.db(dbName).collection(collectionName); -const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), { - collection, - indexName: "default", // The name of the Atlas search index. Defaults to "default" - textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" - embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" -}); +const vectorStore = new MongoDBAtlasVectorSearch( + new CohereEmbeddings({ model: "embed-english-v3.0" }), + { + collection, + indexName: "default", // The name of the Atlas search index. Defaults to "default" + textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" + embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" + } +); const resultOne = await vectorStore.maxMarginalRelevanceSearch("Hello world", { k: 4, diff --git a/examples/src/indexes/vector_stores/vercel_postgres/example.ts b/examples/src/indexes/vector_stores/vercel_postgres/example.ts index 0d4f102f839a..d53e8bf1c78b 100644 --- a/examples/src/indexes/vector_stores/vercel_postgres/example.ts +++ b/examples/src/indexes/vector_stores/vercel_postgres/example.ts @@ -16,7 +16,7 @@ const config = { }; const vercelPostgresStore = await VercelPostgres.initialize( - new CohereEmbeddings(), + new CohereEmbeddings({ model: "embed-english-v3.0" }), config ); diff --git a/examples/src/models/embeddings/cohere.ts b/examples/src/models/embeddings/cohere.ts index 925d2f4324b4..1f865f498c11 100644 --- a/examples/src/models/embeddings/cohere.ts +++ b/examples/src/models/embeddings/cohere.ts @@ -4,6 +4,7 @@ import { CohereEmbeddings } from "@langchain/cohere"; const embeddings = new CohereEmbeddings({ apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY batchSize: 48, // Default value if omitted is 48. Max value is 96 + model: "embed-english-v3.0", }); const res = await embeddings.embedQuery("Hello world"); console.log(res); diff --git a/examples/src/retrievers/multi_query.ts b/examples/src/retrievers/multi_query.ts index 97205e7a2eea..de5c7f321168 100644 --- a/examples/src/retrievers/multi_query.ts +++ b/examples/src/retrievers/multi_query.ts @@ -14,7 +14,7 @@ const vectorstore = await MemoryVectorStore.fromTexts( "mitochondria is made of lipids", ], [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], - new CohereEmbeddings() + new CohereEmbeddings({ model: "embed-english-v3.0" }) ); const model = new ChatAnthropic({}); const retriever = MultiQueryRetriever.fromLLM({ diff --git a/examples/src/retrievers/multi_query_custom.ts b/examples/src/retrievers/multi_query_custom.ts index 43b9b009b975..c3c97fb6f781 100644 --- a/examples/src/retrievers/multi_query_custom.ts +++ b/examples/src/retrievers/multi_query_custom.ts @@ -1,5 +1,5 @@ import { MemoryVectorStore } from "langchain/vectorstores/memory"; -import { CohereEmbeddings } from "@langchain/community/embeddings/cohere"; +import { CohereEmbeddings } from "@langchain/cohere"; import { MultiQueryRetriever } from "langchain/retrievers/multi_query"; import { LLMChain } from "langchain/chains"; import { pull } from "langchain/hub"; @@ -53,7 +53,7 @@ const vectorstore = await MemoryVectorStore.fromTexts( "Mitochondrien bestehen aus Lipiden", ], [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], - new CohereEmbeddings() + new CohereEmbeddings({ model: "embed-english-v3.0" }) ); const model = new ChatAnthropic({}); const llmChain = new LLMChain({ diff --git a/langchain/src/retrievers/tests/ensemble_retriever.int.test.ts b/langchain/src/retrievers/tests/ensemble_retriever.int.test.ts index 781dfb859321..097147f1a513 100644 --- a/langchain/src/retrievers/tests/ensemble_retriever.int.test.ts +++ b/langchain/src/retrievers/tests/ensemble_retriever.int.test.ts @@ -15,7 +15,7 @@ test("Should work with a question input", async () => { "mitochondria is made of lipids", ], [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], - new CohereEmbeddings() + new CohereEmbeddings({ model: "embed-english-v3.0" }) ); const retriever = new EnsembleRetriever({ retrievers: [vectorstore.asRetriever()], @@ -38,7 +38,7 @@ test("Should work with multiple retriever", async () => { "mitochondria is made of lipids", ], [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], - new CohereEmbeddings() + new CohereEmbeddings({ model: "embed-english-v3.0" }) ); const vectorstore2 = await MemoryVectorStore.fromTexts( [ @@ -51,7 +51,7 @@ test("Should work with multiple retriever", async () => { "mitochondria is made of lipids", ], [{ id: 6 }, { id: 7 }, { id: 8 }, { id: 9 }, { id: 10 }], - new CohereEmbeddings() + new CohereEmbeddings({ model: "embed-english-v3.0" }) ); const retriever = new EnsembleRetriever({ retrievers: [vectorstore.asRetriever(), vectorstore2.asRetriever()], @@ -76,7 +76,7 @@ test("Should work with weights", async () => { "mitochondria is made of lipids", ], [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], - new CohereEmbeddings() + new CohereEmbeddings({ model: "embed-english-v3.0" }) ); const vectorstore2 = await MemoryVectorStore.fromTexts( [ @@ -89,7 +89,7 @@ test("Should work with weights", async () => { "mitochondria is made of lipids", ], [{ id: 6 }, { id: 7 }, { id: 8 }, { id: 9 }, { id: 10 }], - new CohereEmbeddings() + new CohereEmbeddings({ model: "embed-english-v3.0" }) ); const retriever = new EnsembleRetriever({ retrievers: [vectorstore.asRetriever(), vectorstore2.asRetriever()], diff --git a/langchain/src/retrievers/tests/multi_query.int.test.ts b/langchain/src/retrievers/tests/multi_query.int.test.ts index 18792bfba613..29263997fa51 100644 --- a/langchain/src/retrievers/tests/multi_query.int.test.ts +++ b/langchain/src/retrievers/tests/multi_query.int.test.ts @@ -16,7 +16,7 @@ test("Should work with a question input", async () => { "mitochondria is made of lipids", ], [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], - new CohereEmbeddings() + new CohereEmbeddings({ model: "embed-english-v3.0" }) ); const model = new ChatOpenAI({}); const retriever = MultiQueryRetriever.fromLLM({ @@ -42,7 +42,7 @@ test("Should work with a keyword", async () => { "mitochondria is made of lipids", ], [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], - new CohereEmbeddings() + new CohereEmbeddings({ model: "embed-english-v3.0" }) ); const model = new ChatOpenAI({}); const retriever = MultiQueryRetriever.fromLLM({ diff --git a/libs/langchain-cohere/src/embeddings.ts b/libs/langchain-cohere/src/embeddings.ts index 4249c73c8bc3..f04ee1815887 100644 --- a/libs/langchain-cohere/src/embeddings.ts +++ b/libs/langchain-cohere/src/embeddings.ts @@ -9,7 +9,7 @@ import { chunkArray } from "@langchain/core/utils/chunk_array"; * parameters specific to the CohereEmbeddings class. */ export interface CohereEmbeddingsParams extends EmbeddingsParams { - model: string; + model?: string; /** * The maximum number of documents to embed in a single request. This is @@ -17,6 +17,11 @@ export interface CohereEmbeddingsParams extends EmbeddingsParams { */ batchSize?: number; + /** + * Specifies the type of embeddings you want to generate. + */ + embeddingTypes?: Array; + /** * Specifies the type of input you're giving to the model. * Not required for older versions of the embedding models (i.e. anything lower than v3), @@ -37,11 +42,11 @@ export class CohereEmbeddings extends Embeddings implements CohereEmbeddingsParams { - model = "small"; + model: string | undefined; batchSize = 48; - inputType: string | undefined; + embeddingTypes = ["float"]; private client: CohereClient; @@ -70,8 +75,16 @@ export class CohereEmbeddings token: apiKey, }); this.model = fieldsWithDefaults?.model ?? this.model; + + if (!this.model) { + throw new Error( + "Model not specified for CohereEmbeddings instance. Please provide a model name from the options here: https://docs.cohere.com/reference/embed" + ); + } + this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize; - this.inputType = fieldsWithDefaults?.inputType; + this.embeddingTypes = + fieldsWithDefaults?.embeddingTypes ?? this.embeddingTypes; } /** @@ -87,7 +100,9 @@ export class CohereEmbeddings model: this.model, texts: batch, // eslint-disable-next-line @typescript-eslint/no-explicit-any - inputType: this.inputType as any, + inputType: "search_document" as any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + embeddingTypes: this.embeddingTypes as any, }) ); @@ -120,7 +135,9 @@ export class CohereEmbeddings model: this.model, texts: [text], // eslint-disable-next-line @typescript-eslint/no-explicit-any - inputType: this.inputType as any, + inputType: "search_query" as any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + embeddingTypes: this.embeddingTypes as any, }); if ("float" in embeddings && embeddings.float) { return embeddings.float[0]; @@ -137,6 +154,25 @@ export class CohereEmbeddings } } + async embed( + request: Parameters[0] + ): Promise { + const { embeddings } = await this.embeddingWithRetry(request); + if ("float" in embeddings && embeddings.float) { + return embeddings.float[0]; + } else if (Array.isArray(embeddings)) { + return embeddings[0]; + } else { + throw new Error( + `Invalid response from Cohere API. Received: ${JSON.stringify( + embeddings, + null, + 2 + )}` + ); + } + } + /** * Generates embeddings with retry capabilities. * @param request - An object containing the request parameters for generating embeddings. diff --git a/libs/langchain-cohere/src/rerank.ts b/libs/langchain-cohere/src/rerank.ts index 2b29f19dbfb4..21e188cc70ff 100644 --- a/libs/langchain-cohere/src/rerank.ts +++ b/libs/langchain-cohere/src/rerank.ts @@ -29,7 +29,7 @@ export interface CohereRerankArgs { * Document compressor that uses `Cohere Rerank API`. */ export class CohereRerank extends BaseDocumentCompressor { - model = "rerank-english-v2.0"; + model: string | undefined; topN = 3; @@ -48,6 +48,11 @@ export class CohereRerank extends BaseDocumentCompressor { token, }); this.model = fields?.model ?? this.model; + if (!this.model) { + throw new Error( + "Model not specified for CohereRerank instance. Please provide a model name from the options here: https://docs.cohere.com/reference/rerank" + ); + } this.topN = fields?.topN ?? this.topN; this.maxChunksPerDoc = fields?.maxChunksPerDoc; } diff --git a/libs/langchain-cohere/src/tests/embeddings.int.test.ts b/libs/langchain-cohere/src/tests/embeddings.int.test.ts index cd5751a440a3..16c119cedd73 100644 --- a/libs/langchain-cohere/src/tests/embeddings.int.test.ts +++ b/libs/langchain-cohere/src/tests/embeddings.int.test.ts @@ -2,13 +2,13 @@ import { test, expect } from "@jest/globals"; import { CohereEmbeddings } from "../embeddings.js"; test("Test CohereEmbeddings.embedQuery", async () => { - const embeddings = new CohereEmbeddings(); + const embeddings = new CohereEmbeddings({ model: "small" }); const res = await embeddings.embedQuery("Hello world"); expect(typeof res[0]).toBe("number"); }); test("Test CohereEmbeddings.embedDocuments", async () => { - const embeddings = new CohereEmbeddings(); + const embeddings = new CohereEmbeddings({ model: "small" }); const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]); expect(res).toHaveLength(2); expect(typeof res[0][0]).toBe("number"); @@ -19,6 +19,7 @@ test("Test CohereEmbeddings concurrency", async () => { const embeddings = new CohereEmbeddings({ batchSize: 1, maxConcurrency: 2, + model: "small", }); const res = await embeddings.embedDocuments([ "Hello world", diff --git a/libs/langchain-cohere/src/tests/rerank.int.test.ts b/libs/langchain-cohere/src/tests/rerank.int.test.ts index 93270f36e373..ce09c662d610 100644 --- a/libs/langchain-cohere/src/tests/rerank.int.test.ts +++ b/libs/langchain-cohere/src/tests/rerank.int.test.ts @@ -20,6 +20,7 @@ const documents = [ test("CohereRerank can indeed rerank documents with compressDocuments method", async () => { const cohereRerank = new CohereRerank({ apiKey: process.env.COHERE_API_KEY, + model: "rerank-english-v2.0", }); const rerankedDocuments = await cohereRerank.compressDocuments( @@ -33,6 +34,7 @@ test("CohereRerank can indeed rerank documents with compressDocuments method", a test("CohereRerank can indeed rerank documents with rerank method", async () => { const cohereRerank = new CohereRerank({ apiKey: process.env.COHERE_API_KEY, + model: "rerank-english-v2.0", }); const rerankedDocuments = await cohereRerank.rerank(