diff --git a/libs/langchain-community/src/chat_models/bedrock/web.ts b/libs/langchain-community/src/chat_models/bedrock/web.ts index d7811fac0089..b17f7feb34cc 100644 --- a/libs/langchain-community/src/chat_models/bedrock/web.ts +++ b/libs/langchain-community/src/chat_models/bedrock/web.ts @@ -11,7 +11,11 @@ import { LangSmithParams, BaseChatModelCallOptions, } from "@langchain/core/language_models/chat_models"; -import { BaseLanguageModelInput } from "@langchain/core/language_models/base"; +import { + BaseLanguageModelInput, + StructuredOutputMethodOptions, + ToolDefinition, +} from "@langchain/core/language_models/base"; import { Runnable } from "@langchain/core/runnables"; import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { @@ -32,12 +36,17 @@ import { isStructuredTool } from "@langchain/core/utils/function_calling"; import { ToolCall } from "@langchain/core/messages/tool"; import { zodToJsonSchema } from "zod-to-json-schema"; +import { isOpenAITool } from "@langchain/core/utils/is_openai_tool"; +import type { SerializedFields } from "../../load/map_keys.js"; import { BaseBedrockInput, BedrockLLMInputOutputAdapter, type CredentialType, } from "../../utils/bedrock/index.js"; -import type { SerializedFields } from "../../load/map_keys.js"; +import { isAnthropicTool } from "../../utils/bedrock/anthropic.js"; +import { type z } from "zod"; + +type AnthropicTool = Record; const PRELUDE_TOTAL_LENGTH_BYTES = 4; @@ -99,6 +108,49 @@ export function convertMessagesToPrompt( throw new Error(`Provider ${provider} does not support chat.`); } +function formatTools(tools: BedrockChatCallOptions["tools"]): AnthropicTool[] { + if (!tools || !tools.length) { + return []; + } + if (tools.every((tc) => isStructuredTool(tc))) { + return (tools as StructuredToolInterface[]).map((tc) => ({ + name: tc.name, + description: tc.description, + input_schema: zodToJsonSchema(tc.schema), + })); + } + if (tools.every((tc) => isOpenAITool(tc))) { + return (tools as ToolDefinition[]).map((tc) => ({ + name: tc.function.name, + description: tc.function.description, + input_schema: tc.function.parameters, + })); + } + + if (tools.every((tc) => isAnthropicTool(tc))) { + return tools as AnthropicTool[]; + } + + if ( + tools.some((tc) => isStructuredTool(tc)) || + tools.some((tc) => isOpenAITool(tc)) || + tools.some((tc) => isAnthropicTool(tc)) + ) { + throw new Error( + "All tools passed to BedrockChat must be of the same type." + ); + } + throw new Error("Invalid tool format received."); +} + +export interface BedrockChatCallOptions extends BaseChatModelCallOptions { + tools?: (StructuredToolInterface | AnthropicTool | ToolDefinition)[]; +} + +export interface BedrockChatFields + extends Partial, + BaseChatModelParams {} + /** * A type of Large Language Model (LLM) that interacts with the Bedrock * service. It extends the base `LLM` class and implements the @@ -195,7 +247,10 @@ export function convertMessagesToPrompt( * runStreaming().catch(console.error); * ``` */ -export class BedrockChat extends BaseChatModel implements BaseBedrockInput { +export class BedrockChat + extends BaseChatModel + implements BaseBedrockInput +{ model = "amazon.titan-tg1-large"; region: string; @@ -234,7 +289,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { streamProcessingMode: "SYNCHRONOUS" | "ASYNCHRONOUS"; }; - protected _anthropicTools?: Record[]; + protected _anthropicTools?: AnthropicTool[]; get lc_aliases(): Record { return { @@ -268,7 +323,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { return "BedrockChat"; } - constructor(fields?: Partial & BaseChatModelParams) { + constructor(fields?: BedrockChatFields) { super(fields ?? {}); this.model = fields?.model ?? this.model; @@ -318,11 +373,14 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { } override invocationParams(options?: this["ParsedCallOptions"]) { + const callOptionTools = formatTools(options?.tools ?? []); return { - tools: this._anthropicTools, + tools: [...(this._anthropicTools ?? []), ...callOptionTools], temperature: this.temperature, max_tokens: this.maxTokens, - stop: options?.stop, + stop: options?.stop ?? this.stopSequences, + modelKwargs: this.modelKwargs, + guardrailConfig: this.guardrailConfig, }; } @@ -340,7 +398,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { async _generate( messages: BaseMessage[], - options: Partial, + options: Partial, runManager?: CallbackManagerForLLMRun ): Promise { if (this.streaming) { @@ -368,7 +426,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { async _generateNonStreaming( messages: BaseMessage[], - options: Partial, + options: Partial, _runManager?: CallbackManagerForLLMRun ): Promise { const service = "bedrock-runtime"; @@ -412,26 +470,34 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { } ) { const { bedrockMethod, endpointHost, provider } = fields; + const { + max_tokens, + temperature, + stop, + modelKwargs, + guardrailConfig, + tools, + } = this.invocationParams(options); const inputBody = this.usesMessagesApi ? BedrockLLMInputOutputAdapter.prepareMessagesInput( provider, messages, - this.maxTokens, - this.temperature, - options.stop ?? this.stopSequences, - this.modelKwargs, - this.guardrailConfig, - this._anthropicTools + max_tokens, + temperature, + stop, + modelKwargs, + guardrailConfig, + tools ) : BedrockLLMInputOutputAdapter.prepareInput( provider, convertMessagesToPromptAnthropic(messages), - this.maxTokens, - this.temperature, - options.stop ?? this.stopSequences, - this.modelKwargs, + max_tokens, + temperature, + stop, + modelKwargs, fields.bedrockMethod, - this.guardrailConfig + guardrailConfig ); const url = new URL( @@ -680,12 +746,12 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { } override bindTools( - tools: (StructuredToolInterface | Record)[], - _kwargs?: Partial + tools: this["ParsedCallOptions"]["tools"], + _kwargs?: Partial ): Runnable< BaseLanguageModelInput, BaseMessageChunk, - BaseChatModelCallOptions + this["ParsedCallOptions"] > { const provider = this.model.split(".")[0]; if (provider !== "anthropic") { @@ -693,18 +759,64 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { "Currently, tool calling through Bedrock is only supported for Anthropic models." ); } - this._anthropicTools = tools.map((tool) => { - if (isStructuredTool(tool)) { - return { - name: tool.name, - description: tool.description, - input_schema: zodToJsonSchema(tool.schema), - }; - } - return tool; - }); + this._anthropicTools = formatTools(tools); return this; } + + withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record +>( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions +): Runnable; + +withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record +>( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions +): Runnable; + +withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record +>( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions +): + | Runnable + | Runnable< + BaseLanguageModelInput, + { raw: BaseMessage; parsed: RunOutput } + > { + if (!super.withStructuredOutput) { + throw new Error(`withStructuredOutput is not implemented in the base class. +This is likely due to an outdated version of "@langchain/core". +Please upgrade to the latest version.`) + } + if (config?.includeRaw) { + return super.withStructuredOutput(outputSchema, { + ...config, + includeRaw: true, + }); + } else { + return super.withStructuredOutput(outputSchema, { + ...config, + includeRaw: false, + }); + } + } } function isChatGenerationChunk( diff --git a/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts b/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts index a202570eb895..fb491b1a7dfa 100644 --- a/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts +++ b/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts @@ -8,6 +8,8 @@ import { AgentExecutor, createToolCallingAgent } from "langchain/agents"; import { ChatPromptTemplate } from "@langchain/core/prompts"; import { BedrockChat as BedrockChatWeb } from "../bedrock/web.js"; import { TavilySearchResults } from "../../tools/tavily_search.js"; +import { z } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; void testChatModel( "Test Bedrock chat model Generating search queries: Command-r", @@ -383,3 +385,103 @@ test.skip.each([ expect(res.content.length).toBeGreaterThan(1); }); + +test.skip("withStructuredOutput", async () => { + const weatherTool = z + .object({ + city: z.string().describe("The city to get the weather for"), + state: z.string().describe("The state to get the weather for").optional(), + }) + .describe("Get the weather for a city"); + const model = new BedrockChatWeb({ + region: process.env.BEDROCK_AWS_REGION, + model: "anthropic.claude-3-sonnet-20240229-v1:0", + maxRetries: 0, + credentials: { + secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, + accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, + }, + }); + const modelWithTools = model.withStructuredOutput(weatherTool, { + name: "weather", + }); + const response = await modelWithTools.invoke( + "Whats the weather like in san francisco?" + ); + expect(response.city.toLowerCase()).toBe("san francisco"); +}); + +test.skip(".bindTools", async () => { + const weatherTool = z + .object({ + city: z.string().describe("The city to get the weather for"), + state: z.string().describe("The state to get the weather for").optional(), + }) + .describe("Get the weather for a city"); + const model = new BedrockChatWeb({ + region: process.env.BEDROCK_AWS_REGION, + model: "anthropic.claude-3-sonnet-20240229-v1:0", + maxRetries: 0, + credentials: { + secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, + accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, + }, + }); + const modelWithTools = model.bind({ + tools: [ + { + name: "weather_tool", + description: weatherTool.description, + input_schema: zodToJsonSchema(weatherTool), + }, + ], + }); + const response = await modelWithTools.invoke( + "Whats the weather like in san francisco?" + ); + console.log(response); + if (!response.tool_calls?.[0]) { + throw new Error("No tool calls found in response"); + } + const { tool_calls } = response; + expect(tool_calls[0].name.toLowerCase()).toBe("weather_tool"); +}); + +test.skip(".bindTools with openai tool format", async () => { + const weatherTool = z + .object({ + city: z.string().describe("The city to get the weather for"), + state: z.string().describe("The state to get the weather for").optional(), + }) + .describe("Get the weather for a city"); + const model = new BedrockChatWeb({ + region: process.env.BEDROCK_AWS_REGION, + model: "anthropic.claude-3-sonnet-20240229-v1:0", + maxRetries: 0, + credentials: { + secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, + accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, + }, + }); + const modelWithTools = model.bind({ + tools: [ + { + type: "function", + function: { + name: "weather_tool", + description: weatherTool.description, + parameters: zodToJsonSchema(weatherTool), + } + }, + ], + }); + const response = await modelWithTools.invoke( + "Whats the weather like in san francisco?" + ); + console.log(response); + if (!response.tool_calls?.[0]) { + throw new Error("No tool calls found in response"); + } + const { tool_calls } = response; + expect(tool_calls[0].name.toLowerCase()).toBe("weather_tool"); +}); \ No newline at end of file diff --git a/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.int.test.ts b/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.int.test.ts index 6c796e8d2780..466836d4609f 100644 --- a/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.int.test.ts +++ b/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.int.test.ts @@ -14,7 +14,7 @@ class BedrockChatStandardIntegrationTests extends ChatModelIntegrationTests< super({ Cls: BedrockChat, chatModelHasToolCalling: true, - chatModelHasStructuredOutput: false, + chatModelHasStructuredOutput: true, constructorArgs: { region, model: "anthropic.claude-3-sonnet-20240229-v1:0", diff --git a/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.test.ts b/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.test.ts index f2ce23db39ba..55c492777c0d 100644 --- a/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.test.ts +++ b/libs/langchain-community/src/chat_models/tests/chatbedrock.standard.test.ts @@ -12,9 +12,11 @@ class BedrockChatStandardUnitTests extends ChatModelUnitTests< constructor() { super({ Cls: BedrockChat, - chatModelHasToolCalling: false, - chatModelHasStructuredOutput: false, - constructorArgs: {}, + chatModelHasToolCalling: true, + chatModelHasStructuredOutput: true, + constructorArgs: { + model: "anthropic.claude-3-sonnet-20240229-v1:0", + }, }); process.env.BEDROCK_AWS_SECRET_ACCESS_KEY = "test"; process.env.BEDROCK_AWS_ACCESS_KEY_ID = "test"; diff --git a/libs/langchain-community/src/utils/bedrock/anthropic.ts b/libs/langchain-community/src/utils/bedrock/anthropic.ts index f6d37dc2e018..4bb888a9de77 100644 --- a/libs/langchain-community/src/utils/bedrock/anthropic.ts +++ b/libs/langchain-community/src/utils/bedrock/anthropic.ts @@ -219,3 +219,11 @@ export function formatMessagesForAnthropic(messages: BaseMessage[]): { system, }; } + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function isAnthropicTool( + tool: unknown +): tool is Record { + if (typeof tool !== "object" || !tool) return false; + return "input_schema" in tool; +}