Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: batch embedding queries, fix get client in tests #2908

Merged
merged 11 commits into from
Oct 13, 2023
74 changes: 47 additions & 27 deletions langchain/src/embeddings/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import {
BedrockRuntimeClient,
InvokeModelCommand,
} from "@aws-sdk/client-bedrock-runtime";

import { Embeddings, EmbeddingsParams } from "./base.js";
import type { CredentialType } from "../util/bedrock.js";

Expand Down Expand Up @@ -40,6 +39,8 @@ export class BedrockEmbeddings

client: BedrockRuntimeClient;

batchSize = 512;

constructor(fields?: BedrockEmbeddingsParams) {
super(fields ?? {});

Expand All @@ -53,28 +54,48 @@ export class BedrockEmbeddings
});
}

/**
* Protected method to make a request to the Bedrock API to generate
* embeddings. Handles the retry logic and returns the response from the
* API.
* @param request Request to send to the Bedrock API.
* @returns Promise that resolves to the response from the API.
*/
protected async _embedText(text: string): Promise<number[]> {
// replace newlines, which can negatively affect performance.
const cleanedText = text.replace(/\n/g, " ");

const res = await this.client.send(
new InvokeModelCommand({
modelId: this.model,
body: JSON.stringify({
inputText: cleanedText,
}),
contentType: "application/json",
accept: "application/json",
})
);

try {
const body = new TextDecoder().decode(res.body);

return JSON.parse(body).embedding;
} catch (e) {
throw new Error("An invalid response was returned by Bedrock.");
}
return this.caller.call(async () => {
try {
// replace newlines, which can negatively affect performance.
const cleanedText = text.replace(/\n/g, " ");

const res = await this.client.send(
new InvokeModelCommand({
modelId: this.model,
body: JSON.stringify({
inputText: cleanedText,
}),
contentType: "application/json",
accept: "application/json",
})
);

const body = new TextDecoder().decode(res.body);
return JSON.parse(body).embedding;
} catch (e) {
console.error({
error: e,
});
// eslint-disable-next-line no-instanceof/no-instanceof
if (e instanceof Error) {
throw new Error(
`An error occurred while embedding documents with Bedrock: ${e.message}`
);
}

throw new Error(
"An error occurred while embedding documents with Bedrock"
);
}
});
}

/**
Expand All @@ -93,13 +114,12 @@ export class BedrockEmbeddings
}

/**
* Method that takes an array of documents as input and returns a promise
* that resolves to a 2D array of embeddings for each document. It calls
* the _embedText method for each document in the array.
* @param documents Array of documents for which to generate embeddings.
* Method to generate embeddings for an array of texts. Calls _embedText
* method which batches and handles retry logic when calling the AWS Bedrock API.
* @param documents Array of texts for which to generate embeddings.
* @returns Promise that resolves to a 2D array of embeddings for each input document.
*/
embedDocuments(documents: string[]): Promise<number[][]> {
async embedDocuments(documents: string[]): Promise<number[][]> {
return Promise.all(documents.map((document) => this._embedText(document)));
}
}
39 changes: 25 additions & 14 deletions langchain/src/embeddings/tests/bedrock.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,42 @@ import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime";
import { HNSWLib } from "../../vectorstores/hnswlib.js";
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR adds a change that requires environment variables via process.env. Please review this change to ensure the proper handling and usage of environment variables.

import { BedrockEmbeddings } from "../bedrock.js";

const client = new BedrockRuntimeClient({
region: process.env.BEDROCK_AWS_REGION!,
credentials: {
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
},
});
const getClient = () => {
if (
!process.env.BEDROCK_AWS_REGION ||
!process.env.BEDROCK_AWS_ACCESS_KEY_ID ||
!process.env.BEDROCK_AWS_SECRET_ACCESS_KEY
) {
throw new Error("Missing environment variables for AWS");
}

const client = new BedrockRuntimeClient({
region: process.env.BEDROCK_AWS_REGION,
credentials: {
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID,
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY,
},
});

return client;
};

test("Test BedrockEmbeddings.embedQuery", async () => {
const client = getClient();
const embeddings = new BedrockEmbeddings({
maxRetries: 1,
client,
});
const res = await embeddings.embedQuery("Hello world");
console.log(res);
// console.log(res);
expect(typeof res[0]).toBe("number");
});

test("Test BedrockEmbeddings.embedDocuments with passed region and credentials", async () => {
const client = getClient();
const embeddings = new BedrockEmbeddings({
maxRetries: 1,
region: process.env.BEDROCK_AWS_REGION!,
credentials: {
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
},
client,
});
const res = await embeddings.embedDocuments([
"Hello world",
Expand All @@ -41,14 +51,15 @@ test("Test BedrockEmbeddings.embedDocuments with passed region and credentials",
"six documents",
"to test pagination",
]);
console.log(res);
// console.log(res);
expect(res).toHaveLength(6);
res.forEach((r) => {
expect(typeof r[0]).toBe("number");
});
});

test("Test end to end with HNSWLib", async () => {
const client = getClient();
const vectorStore = await HNSWLib.fromTexts(
["Hello world", "Bye bye", "hello nice world"],
[{ id: 2 }, { id: 1 }, { id: 3 }],
Expand Down