Skip to content

Commit

Permalink
Update for AWS Bedrock knowledge bases to support filters and overrid…
Browse files Browse the repository at this point in the history
…eSearchType, Update the KB to support other locations as sources
  • Loading branch information
jl4nz committed Jul 24, 2024
1 parent a8e74c1 commit 1404550
Show file tree
Hide file tree
Showing 4 changed files with 931 additions and 16 deletions.
6 changes: 3 additions & 3 deletions libs/langchain-aws/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@aws-sdk/client-bedrock-agent-runtime": "^3.583.0",
"@aws-sdk/client-bedrock-agent-runtime": "^3.616.0",
"@aws-sdk/client-bedrock-runtime": "^3.602.0",
"@aws-sdk/client-kendra": "^3.352.0",
"@aws-sdk/credential-provider-node": "^3.600.0",
"@langchain/core": ">=0.2.16 <0.3.0",
"zod-to-json-schema": "^3.22.5"
},
"devDependencies": {
"@aws-sdk/types": "^3.598.0",
"@aws-sdk/types": "^3.609.0",
"@jest/globals": "^29.5.0",
"@langchain/scripts": "~0.0.14",
"@langchain/standard-tests": "0.0.0",
Expand Down Expand Up @@ -97,4 +97,4 @@
"index.d.ts",
"index.d.cts"
]
}
}
68 changes: 58 additions & 10 deletions libs/langchain-aws/src/retrievers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import {
RetrieveCommand,
BedrockAgentRuntimeClient,
type BedrockAgentRuntimeClientConfig,
type SearchType,
type RetrievalFilter,
} from "@aws-sdk/client-bedrock-agent-runtime";

import { BaseRetriever } from "@langchain/core/retrievers";
Expand All @@ -16,6 +18,8 @@ export interface AmazonKnowledgeBaseRetrieverArgs {
topK: number;
region: string;
clientOptions?: BedrockAgentRuntimeClientConfig;
filter?: RetrievalFilter;
overrideSearchType?: SearchType;
}

/**
Expand Down Expand Up @@ -51,15 +55,23 @@ export class AmazonKnowledgeBaseRetriever extends BaseRetriever {

bedrockAgentRuntimeClient: BedrockAgentRuntimeClient;

filter?: RetrievalFilter;

overrideSearchType?: SearchType;

constructor({
knowledgeBaseId,
topK = 10,
clientOptions,
region,
filter,
overrideSearchType,
}: AmazonKnowledgeBaseRetrieverArgs) {
super();

this.topK = topK;
this.filter = filter;
this.overrideSearchType = overrideSearchType;
this.bedrockAgentRuntimeClient = new BedrockAgentRuntimeClient({
region,
...clientOptions,
Expand All @@ -78,7 +90,12 @@ export class AmazonKnowledgeBaseRetriever extends BaseRetriever {
return res;
}

async queryKnowledgeBase(query: string, topK: number) {
async queryKnowledgeBase(
query: string,
topK: number,
filter?: RetrievalFilter,
overrideSearchType?: SearchType
) {
const retrieveCommand = new RetrieveCommand({
knowledgeBaseId: this.knowledgeBaseId,
retrievalQuery: {
Expand All @@ -87,6 +104,8 @@ export class AmazonKnowledgeBaseRetriever extends BaseRetriever {
retrievalConfiguration: {
vectorSearchConfiguration: {
numberOfResults: topK,
overrideSearchType,
filter,
},
},
});
Expand All @@ -96,19 +115,48 @@ export class AmazonKnowledgeBaseRetriever extends BaseRetriever {
);

return (
retrieveResponse.retrievalResults?.map((result) => ({
pageContent: this.cleanResult(result.content?.text || ""),
metadata: {
source: result.location?.s3Location?.uri,
score: result.score,
...result.metadata,
},
})) ?? ([] as Array<Document>)
retrieveResponse.retrievalResults?.map((result) => {
let source;
switch (result.location?.type) {
case "CONFLUENCE":
source = result.location?.confluenceLocation?.url;
break;
case "S3":
source = result.location?.s3Location?.uri;
break;
case "SALESFORCE":
source = result.location?.salesforceLocation?.url;
break;
case "SHAREPOINT":
source = result.location?.sharePointLocation?.url;
break;
case "WEB":
source = result.location?.webLocation?.url;
break;
default:
source = result.location?.s3Location?.uri;
break;
}

return {
pageContent: this.cleanResult(result.content?.text || ""),
metadata: {
source,
score: result.score,
...result.metadata,
},
};
}) ?? ([] as Array<Document>)
);
}

async _getRelevantDocuments(query: string): Promise<Document[]> {
const docs = await this.queryKnowledgeBase(query, this.topK);
const docs = await this.queryKnowledgeBase(
query,
this.topK,
this.filter,
this.overrideSearchType
);
return docs;
}
}
2 changes: 2 additions & 0 deletions libs/langchain-aws/src/retrievers/tests/bedrock.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ test("AmazonKnowledgeBaseRetriever", async () => {
topK: 10,
knowledgeBaseId: process.env.AMAZON_KNOWLEDGE_BASE_ID || "",
region: process.env.BEDROCK_AWS_REGION,
overrideSearchType: "HYBRID",
filter: undefined,
clientOptions: {
credentials: {
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID,
Expand Down
Loading

0 comments on commit 1404550

Please sign in to comment.