Skip to content

Commit

Permalink
cohere[patch]: update embeddings and rerank (#5928)
Browse files Browse the repository at this point in the history
* update embeddings and rerank

* Format

---------

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
Anirudh31415926535 and jacoblee93 authored Jul 9, 2024
1 parent 7bea8d1 commit c2d3472
Show file tree
Hide file tree
Showing 18 changed files with 97 additions and 43 deletions.
2 changes: 1 addition & 1 deletion examples/src/document_compressors/cohere_rerank.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const docs = [

const cohereRerank = new CohereRerank({
apiKey: process.env.COHERE_API_KEY, // Default
model: "rerank-english-v2.0", // Default
model: "rerank-english-v2.0",
});

const rerankedDocuments = await cohereRerank.rerank(docs, query, {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const docs = [
const cohereRerank = new CohereRerank({
apiKey: process.env.COHERE_API_KEY, // Default
topN: 3, // Default
model: "rerank-english-v2.0", // Default
model: "rerank-english-v2.0",
});

const rerankedDocuments = await cohereRerank.compressDocuments(docs, query);
Expand Down
2 changes: 1 addition & 1 deletion examples/src/embeddings/cohere.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { CohereEmbeddings } from "@langchain/cohere";

export const run = async () => {
const model = new CohereEmbeddings();
const model = new CohereEmbeddings({ model: "embed-english-v3.0" });
const res = await model.embedQuery(
"What would be a good company name a company that makes colorful socks?"
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { MemoryVectorStore } from "langchain/vectorstores/memory";
const model = new ChatAnthropic();
const vectorstore = await MemoryVectorStore.fromDocuments(
[{ pageContent: "mitochondria is the powerhouse of the cell", metadata: {} }],
new CohereEmbeddings()
new CohereEmbeddings({ model: "embed-english-v3.0" })
);
const retriever = vectorstore.asRetriever();
const template = `Answer the question based only on the following context:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const collection = client.db(dbName).collection(collectionName);
const vectorstore = await MongoDBAtlasVectorSearch.fromTexts(
["Hello world", "Bye bye", "What's this?"],
[{ id: 2 }, { id: 1 }, { id: 3 }],
new CohereEmbeddings(),
new CohereEmbeddings({ model: "embed-english-v3.0" }),
{
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
Expand Down
15 changes: 9 additions & 6 deletions examples/src/indexes/vector_stores/mongodb_atlas_search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
});
const vectorStore = new MongoDBAtlasVectorSearch(
new CohereEmbeddings({ model: "embed-english-v3.0" }),
{
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
}
);

const resultOne = await vectorStore.similaritySearch("Hello world", 1);
console.log(resultOne);
Expand Down
15 changes: 9 additions & 6 deletions examples/src/indexes/vector_stores/mongodb_metadata_filtering.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
});
const vectorStore = new MongoDBAtlasVectorSearch(
new CohereEmbeddings({ model: "embed-english-v3.0" }),
{
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
}
);

await vectorStore.addDocuments([
{
Expand Down
15 changes: 9 additions & 6 deletions examples/src/indexes/vector_stores/mongodb_mmr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
});
const vectorStore = new MongoDBAtlasVectorSearch(
new CohereEmbeddings({ model: "embed-english-v3.0" }),
{
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
}
);

const resultOne = await vectorStore.maxMarginalRelevanceSearch("Hello world", {
k: 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const config = {
};

const vercelPostgresStore = await VercelPostgres.initialize(
new CohereEmbeddings(),
new CohereEmbeddings({ model: "embed-english-v3.0" }),
config
);

Expand Down
1 change: 1 addition & 0 deletions examples/src/models/embeddings/cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { CohereEmbeddings } from "@langchain/cohere";
const embeddings = new CohereEmbeddings({
apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY
batchSize: 48, // Default value if omitted is 48. Max value is 96
model: "embed-english-v3.0",
});
const res = await embeddings.embedQuery("Hello world");
console.log(res);
Expand Down
2 changes: 1 addition & 1 deletion examples/src/retrievers/multi_query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const vectorstore = await MemoryVectorStore.fromTexts(
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({ model: "embed-english-v3.0" })
);
const model = new ChatAnthropic({});
const retriever = MultiQueryRetriever.fromLLM({
Expand Down
4 changes: 2 additions & 2 deletions examples/src/retrievers/multi_query_custom.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { CohereEmbeddings } from "@langchain/community/embeddings/cohere";
import { CohereEmbeddings } from "@langchain/cohere";
import { MultiQueryRetriever } from "langchain/retrievers/multi_query";
import { LLMChain } from "langchain/chains";
import { pull } from "langchain/hub";
Expand Down Expand Up @@ -53,7 +53,7 @@ const vectorstore = await MemoryVectorStore.fromTexts(
"Mitochondrien bestehen aus Lipiden",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({ model: "embed-english-v3.0" })
);
const model = new ChatAnthropic({});
const llmChain = new LLMChain({
Expand Down
10 changes: 5 additions & 5 deletions langchain/src/retrievers/tests/ensemble_retriever.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ test("Should work with a question input", async () => {
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({ model: "embed-english-v3.0" })
);
const retriever = new EnsembleRetriever({
retrievers: [vectorstore.asRetriever()],
Expand All @@ -38,7 +38,7 @@ test("Should work with multiple retriever", async () => {
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({ model: "embed-english-v3.0" })
);
const vectorstore2 = await MemoryVectorStore.fromTexts(
[
Expand All @@ -51,7 +51,7 @@ test("Should work with multiple retriever", async () => {
"mitochondria is made of lipids",
],
[{ id: 6 }, { id: 7 }, { id: 8 }, { id: 9 }, { id: 10 }],
new CohereEmbeddings()
new CohereEmbeddings({ model: "embed-english-v3.0" })
);
const retriever = new EnsembleRetriever({
retrievers: [vectorstore.asRetriever(), vectorstore2.asRetriever()],
Expand All @@ -76,7 +76,7 @@ test("Should work with weights", async () => {
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({ model: "embed-english-v3.0" })
);
const vectorstore2 = await MemoryVectorStore.fromTexts(
[
Expand All @@ -89,7 +89,7 @@ test("Should work with weights", async () => {
"mitochondria is made of lipids",
],
[{ id: 6 }, { id: 7 }, { id: 8 }, { id: 9 }, { id: 10 }],
new CohereEmbeddings()
new CohereEmbeddings({ model: "embed-english-v3.0" })
);
const retriever = new EnsembleRetriever({
retrievers: [vectorstore.asRetriever(), vectorstore2.asRetriever()],
Expand Down
4 changes: 2 additions & 2 deletions langchain/src/retrievers/tests/multi_query.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ test("Should work with a question input", async () => {
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({ model: "embed-english-v3.0" })
);
const model = new ChatOpenAI({});
const retriever = MultiQueryRetriever.fromLLM({
Expand All @@ -42,7 +42,7 @@ test("Should work with a keyword", async () => {
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({ model: "embed-english-v3.0" })
);
const model = new ChatOpenAI({});
const retriever = MultiQueryRetriever.fromLLM({
Expand Down
48 changes: 42 additions & 6 deletions libs/langchain-cohere/src/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@ import { chunkArray } from "@langchain/core/utils/chunk_array";
* parameters specific to the CohereEmbeddings class.
*/
export interface CohereEmbeddingsParams extends EmbeddingsParams {
model: string;
model?: string;

/**
* The maximum number of documents to embed in a single request. This is
* limited by the Cohere API to a maximum of 96.
*/
batchSize?: number;

/**
* Specifies the type of embeddings you want to generate.
*/
embeddingTypes?: Array<string>;

/**
* Specifies the type of input you're giving to the model.
* Not required for older versions of the embedding models (i.e. anything lower than v3),
Expand All @@ -37,11 +42,11 @@ export class CohereEmbeddings
extends Embeddings
implements CohereEmbeddingsParams
{
model = "small";
model: string | undefined;

batchSize = 48;

inputType: string | undefined;
embeddingTypes = ["float"];

private client: CohereClient;

Expand Down Expand Up @@ -70,8 +75,16 @@ export class CohereEmbeddings
token: apiKey,
});
this.model = fieldsWithDefaults?.model ?? this.model;

if (!this.model) {
throw new Error(
"Model not specified for CohereEmbeddings instance. Please provide a model name from the options here: https://docs.cohere.com/reference/embed"
);
}

this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize;
this.inputType = fieldsWithDefaults?.inputType;
this.embeddingTypes =
fieldsWithDefaults?.embeddingTypes ?? this.embeddingTypes;
}

/**
Expand All @@ -87,7 +100,9 @@ export class CohereEmbeddings
model: this.model,
texts: batch,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
inputType: this.inputType as any,
inputType: "search_document" as any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
embeddingTypes: this.embeddingTypes as any,
})
);

Expand Down Expand Up @@ -120,7 +135,9 @@ export class CohereEmbeddings
model: this.model,
texts: [text],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
inputType: this.inputType as any,
inputType: "search_query" as any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
embeddingTypes: this.embeddingTypes as any,
});
if ("float" in embeddings && embeddings.float) {
return embeddings.float[0];
Expand All @@ -137,6 +154,25 @@ export class CohereEmbeddings
}
}

/**
 * Sends a caller-built embed request through `embeddingWithRetry` and
 * returns the first embedding vector found in the response.
 *
 * @param request - The request object, typed as the first parameter of the
 * underlying client's `embed` method.
 * @returns The first vector from `embeddings.float` when that key is present
 * and truthy, otherwise the first element when `embeddings` is an array.
 * @throws Error if the response's `embeddings` value matches neither shape.
 */
async embed(
  request: Parameters<typeof this.client.embed>[0]
): Promise<number[]> {
  const response = await this.embeddingWithRetry(request);
  const vectors = response.embeddings;

  // Response keyed by embedding type: take the first "float" vector.
  if ("float" in vectors && vectors.float) {
    return vectors.float[0];
  }

  // Response that is a plain array of vectors: take the first one.
  if (Array.isArray(vectors)) {
    return vectors[0];
  }

  throw new Error(
    `Invalid response from Cohere API. Received: ${JSON.stringify(
      vectors,
      null,
      2
    )}`
  );
}

/**
* Generates embeddings with retry capabilities.
* @param request - An object containing the request parameters for generating embeddings.
Expand Down
7 changes: 6 additions & 1 deletion libs/langchain-cohere/src/rerank.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export interface CohereRerankArgs {
* Document compressor that uses `Cohere Rerank API`.
*/
export class CohereRerank extends BaseDocumentCompressor {
model = "rerank-english-v2.0";
model: string | undefined;

topN = 3;

Expand All @@ -48,6 +48,11 @@ export class CohereRerank extends BaseDocumentCompressor {
token,
});
this.model = fields?.model ?? this.model;
if (!this.model) {
throw new Error(
"Model not specified for CohereRerank instance. Please provide a model name from the options here: https://docs.cohere.com/reference/rerank"
);
}
this.topN = fields?.topN ?? this.topN;
this.maxChunksPerDoc = fields?.maxChunksPerDoc;
}
Expand Down
5 changes: 3 additions & 2 deletions libs/langchain-cohere/src/tests/embeddings.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import { test, expect } from "@jest/globals";
import { CohereEmbeddings } from "../embeddings.js";

test("Test CohereEmbeddings.embedQuery", async () => {
const embeddings = new CohereEmbeddings();
const embeddings = new CohereEmbeddings({ model: "small" });
const res = await embeddings.embedQuery("Hello world");
expect(typeof res[0]).toBe("number");
});

test("Test CohereEmbeddings.embedDocuments", async () => {
const embeddings = new CohereEmbeddings();
const embeddings = new CohereEmbeddings({ model: "small" });
const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
expect(res).toHaveLength(2);
expect(typeof res[0][0]).toBe("number");
Expand All @@ -19,6 +19,7 @@ test("Test CohereEmbeddings concurrency", async () => {
const embeddings = new CohereEmbeddings({
batchSize: 1,
maxConcurrency: 2,
model: "small",
});
const res = await embeddings.embedDocuments([
"Hello world",
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain-cohere/src/tests/rerank.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const documents = [
test("CohereRerank can indeed rerank documents with compressDocuments method", async () => {
const cohereRerank = new CohereRerank({
apiKey: process.env.COHERE_API_KEY,
model: "rerank-english-v2.0",
});

const rerankedDocuments = await cohereRerank.compressDocuments(
Expand All @@ -33,6 +34,7 @@ test("CohereRerank can indeed rerank documents with compressDocuments method", a
test("CohereRerank can indeed rerank documents with rerank method", async () => {
const cohereRerank = new CohereRerank({
apiKey: process.env.COHERE_API_KEY,
model: "rerank-english-v2.0",
});

const rerankedDocuments = await cohereRerank.rerank(
Expand Down

0 comments on commit c2d3472

Please sign in to comment.