Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cohere[patch]: update embeddings and rerank #5928

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/src/document_compressors/cohere_rerank.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const docs = [

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey team, I've flagged a change in the cohereRerank instantiation that accesses an environment variable via process.env. Please review this change to ensure it aligns with our environment variable handling practices.

const cohereRerank = new CohereRerank({
apiKey: process.env.COHERE_API_KEY, // Default
model: "rerank-english-v2.0", // Default
model: "rerank-english-v2.0",
});

const rerankedDocuments = await cohereRerank.rerank(docs, query, {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const docs = [
const cohereRerank = new CohereRerank({
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey team, just a heads up that I've flagged a change in the cohereRerank instantiation that sets the apiKey property using an environment variable. This is for your review to ensure it aligns with our best practices for handling environment variables.

apiKey: process.env.COHERE_API_KEY, // Default
topN: 3, // Default
model: "rerank-english-v2.0", // Default
model: "rerank-english-v2.0",
});

const rerankedDocuments = await cohereRerank.compressDocuments(docs, query);
Expand Down
2 changes: 1 addition & 1 deletion examples/src/embeddings/cohere.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { CohereEmbeddings } from "@langchain/cohere";

export const run = async () => {
const model = new CohereEmbeddings();
const model = new CohereEmbeddings({model: "embed-english-v3.0",});
const res = await model.embedQuery(
"What would be a good company name a company that makes colorful socks?"
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { MemoryVectorStore } from "langchain/vectorstores/memory";
const model = new ChatAnthropic();
const vectorstore = await MemoryVectorStore.fromDocuments(
[{ pageContent: "mitochondria is the powerhouse of the cell", metadata: {} }],
new CohereEmbeddings()
new CohereEmbeddings({model: "embed-english-v3.0",})
);
const retriever = vectorstore.asRetriever();
const template = `Answer the question based only on the following context:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const collection = client.db(dbName).collection(collectionName);
const vectorstore = await MongoDBAtlasVectorSearch.fromTexts(
["Hello world", "Bye bye", "What's this?"],
[{ id: 2 }, { id: 1 }, { id: 3 }],
new CohereEmbeddings(),
new CohereEmbeddings({model: "embed-english-v3.0",}),
{
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
Expand Down
2 changes: 1 addition & 1 deletion examples/src/indexes/vector_stores/mongodb_atlas_search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings({model: "embed-english-v3.0",}), {
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings({model: "embed-english-v3.0",}), {
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
Expand Down
2 changes: 1 addition & 1 deletion examples/src/indexes/vector_stores/mongodb_mmr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings({model: "embed-english-v3.0",}), {
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const config = {
};

const vercelPostgresStore = await VercelPostgres.initialize(
new CohereEmbeddings(),
new CohereEmbeddings({model: "embed-english-v3.0",}),
config
);

Expand Down
1 change: 1 addition & 0 deletions examples/src/models/embeddings/cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { CohereEmbeddings } from "@langchain/cohere";
const embeddings = new CohereEmbeddings({
apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY
batchSize: 48, // Default value if omitted is 48. Max value is 96
model: "embed-english-v3.0",
});
const res = await embeddings.embedQuery("Hello world");
console.log(res);
Expand Down
2 changes: 1 addition & 1 deletion examples/src/retrievers/multi_query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const vectorstore = await MemoryVectorStore.fromTexts(
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({model: "embed-english-v3.0",})
);
const model = new ChatAnthropic({});
const retriever = MultiQueryRetriever.fromLLM({
Expand Down
4 changes: 2 additions & 2 deletions examples/src/retrievers/multi_query_custom.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { CohereEmbeddings } from "@langchain/community/embeddings/cohere";
import { CohereEmbeddings } from "@langchain/cohere";
import { MultiQueryRetriever } from "langchain/retrievers/multi_query";
import { LLMChain } from "langchain/chains";
import { pull } from "langchain/hub";
Expand Down Expand Up @@ -53,7 +53,7 @@ const vectorstore = await MemoryVectorStore.fromTexts(
"Mitochondrien bestehen aus Lipiden",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({model: "embed-english-v3.0",})
);
const model = new ChatAnthropic({});
const llmChain = new LLMChain({
Expand Down
10 changes: 5 additions & 5 deletions langchain/src/retrievers/tests/ensemble_retriever.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ test("Should work with a question input", async () => {
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({model: "embed-english-v3.0",})
);
const retriever = new EnsembleRetriever({
retrievers: [vectorstore.asRetriever()],
Expand All @@ -38,7 +38,7 @@ test("Should work with multiple retriever", async () => {
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({model: "embed-english-v3.0",})
);
const vectorstore2 = await MemoryVectorStore.fromTexts(
[
Expand All @@ -51,7 +51,7 @@ test("Should work with multiple retriever", async () => {
"mitochondria is made of lipids",
],
[{ id: 6 }, { id: 7 }, { id: 8 }, { id: 9 }, { id: 10 }],
new CohereEmbeddings()
new CohereEmbeddings({model: "embed-english-v3.0",})
);
const retriever = new EnsembleRetriever({
retrievers: [vectorstore.asRetriever(), vectorstore2.asRetriever()],
Expand All @@ -76,7 +76,7 @@ test("Should work with weights", async () => {
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({model: "embed-english-v3.0",})
);
const vectorstore2 = await MemoryVectorStore.fromTexts(
[
Expand All @@ -89,7 +89,7 @@ test("Should work with weights", async () => {
"mitochondria is made of lipids",
],
[{ id: 6 }, { id: 7 }, { id: 8 }, { id: 9 }, { id: 10 }],
new CohereEmbeddings()
new CohereEmbeddings({model: "embed-english-v3.0",})
);
const retriever = new EnsembleRetriever({
retrievers: [vectorstore.asRetriever(), vectorstore2.asRetriever()],
Expand Down
4 changes: 2 additions & 2 deletions langchain/src/retrievers/tests/multi_query.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ test("Should work with a question input", async () => {
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({model: "embed-english-v3.0",})
);
const model = new ChatOpenAI({});
const retriever = MultiQueryRetriever.fromLLM({
Expand All @@ -42,7 +42,7 @@ test("Should work with a keyword", async () => {
"mitochondria is made of lipids",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new CohereEmbeddings()
new CohereEmbeddings({model: "embed-english-v3.0",})
);
const model = new ChatOpenAI({});
const retriever = MultiQueryRetriever.fromLLM({
Expand Down
48 changes: 42 additions & 6 deletions libs/langchain-cohere/src/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@ import { chunkArray } from "@langchain/core/utils/chunk_array";
* parameters specific to the CohereEmbeddings class.
*/
export interface CohereEmbeddingsParams extends EmbeddingsParams {
model: string;
model?: string;

/**
* The maximum number of documents to embed in a single request. This is
* limited by the Cohere API to a maximum of 96.
*/
batchSize?: number;

/**
* Specifies the type of embeddings you want to generate.
*/
embeddingTypes?: Array<string>;

/**
* Specifies the type of input you're giving to the model.
* Not required for older versions of the embedding models (i.e. anything lower than v3),
Expand All @@ -37,11 +42,11 @@ export class CohereEmbeddings
extends Embeddings
implements CohereEmbeddingsParams
{
model = "small";
model: string | undefined;

batchSize = 48;

inputType: string | undefined;
embeddingTypes = ["float"];

private client: CohereClient;

Expand Down Expand Up @@ -70,8 +75,16 @@ export class CohereEmbeddings
token: apiKey,
});
this.model = fieldsWithDefaults?.model ?? this.model;

if (!this.model) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would prefer to just make this a required param and ship a new minor version release along with your tool calling stuff

throw new Error(
"Model not specified for CohereEmbeddings instance. Please provide a model name from the options here: https://docs.cohere.com/reference/embed"
);
}

this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize;
this.inputType = fieldsWithDefaults?.inputType;
this.embeddingTypes =
fieldsWithDefaults?.embeddingTypes ?? this.embeddingTypes;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to remove inputType from CohereEmbeddingsParams as well?

}

/**
Expand All @@ -87,7 +100,9 @@ export class CohereEmbeddings
model: this.model,
texts: batch,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
inputType: this.inputType as any,
inputType: "search_document" as any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
embeddingTypes: this.embeddingTypes as any,
})
);

Expand Down Expand Up @@ -120,7 +135,9 @@ export class CohereEmbeddings
model: this.model,
texts: [text],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
inputType: this.inputType as any,
inputType: "search_query" as any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
embeddingTypes: this.embeddingTypes as any,
});
if ("float" in embeddings && embeddings.float) {
return embeddings.float[0];
Expand All @@ -137,6 +154,25 @@ export class CohereEmbeddings
}
}

async embed(
jacoblee93 marked this conversation as resolved.
Show resolved Hide resolved
request: Parameters<typeof this.client.embed>[0]
): Promise<number[]> {
const { embeddings } = await this.embeddingWithRetry(request);
if ("float" in embeddings && embeddings.float) {
return embeddings.float[0];
} else if (Array.isArray(embeddings)) {
return embeddings[0];
} else {
throw new Error(
`Invalid response from Cohere API. Received: ${JSON.stringify(
embeddings,
null,
2
)}`
);
}
}

/**
* Generates embeddings with retry capabilities.
* @param request - An object containing the request parameters for generating embeddings.
Expand Down
7 changes: 6 additions & 1 deletion libs/langchain-cohere/src/rerank.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export interface CohereRerankArgs {
* Document compressor that uses `Cohere Rerank API`.
*/
export class CohereRerank extends BaseDocumentCompressor {
model = "rerank-english-v2.0";
model: string | undefined;

topN = 3;

Expand All @@ -48,6 +48,11 @@ export class CohereRerank extends BaseDocumentCompressor {
token,
});
this.model = fields?.model ?? this.model;
if (!this.model) {
throw new Error(
"Model not specified for CohereRerank instance. Please provide a model name from the options here: https://docs.cohere.com/reference/rerank"
);
}
this.topN = fields?.topN ?? this.topN;
this.maxChunksPerDoc = fields?.maxChunksPerDoc;
}
Expand Down
5 changes: 3 additions & 2 deletions libs/langchain-cohere/src/tests/embeddings.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import { test, expect } from "@jest/globals";
import { CohereEmbeddings } from "../embeddings.js";

test("Test CohereEmbeddings.embedQuery", async () => {
const embeddings = new CohereEmbeddings();
const embeddings = new CohereEmbeddings({model: "small"});
const res = await embeddings.embedQuery("Hello world");
expect(typeof res[0]).toBe("number");
});

test("Test CohereEmbeddings.embedDocuments", async () => {
const embeddings = new CohereEmbeddings();
const embeddings = new CohereEmbeddings({model: "small"});
const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
expect(res).toHaveLength(2);
expect(typeof res[0][0]).toBe("number");
Expand All @@ -19,6 +19,7 @@ test("Test CohereEmbeddings concurrency", async () => {
const embeddings = new CohereEmbeddings({
batchSize: 1,
maxConcurrency: 2,
model: "small",
});
const res = await embeddings.embedDocuments([
"Hello world",
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain-cohere/src/tests/rerank.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const documents = [
test("CohereRerank can indeed rerank documents with compressDocuments method", async () => {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! I've reviewed the code and noticed that the addition in this PR references an environment variable via process.env. I've flagged this for your review to ensure it aligns with the project's requirements. Let me know if you have any questions!

const cohereRerank = new CohereRerank({
apiKey: process.env.COHERE_API_KEY,
model: "rerank-english-v2.0"
});

const rerankedDocuments = await cohereRerank.compressDocuments(
Expand All @@ -33,6 +34,7 @@ test("CohereRerank can indeed rerank documents with compressDocuments method", a
test("CohereRerank can indeed rerank documents with rerank method", async () => {
const cohereRerank = new CohereRerank({
apiKey: process.env.COHERE_API_KEY,
model: "rerank-english-v2.0",
});

const rerankedDocuments = await cohereRerank.rerank(
Expand Down