Skip to content

Commit

Permalink
add oai and withstructuredoutput support to bedrock
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jun 12, 2024
1 parent ffedf9a commit 21a2a36
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 37 deletions.
178 changes: 145 additions & 33 deletions libs/langchain-community/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ import {
LangSmithParams,
BaseChatModelCallOptions,
} from "@langchain/core/language_models/chat_models";
import { BaseLanguageModelInput } from "@langchain/core/language_models/base";
import {
BaseLanguageModelInput,
StructuredOutputMethodOptions,
ToolDefinition,
} from "@langchain/core/language_models/base";
import { Runnable } from "@langchain/core/runnables";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import {
Expand All @@ -32,12 +36,17 @@ import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { ToolCall } from "@langchain/core/messages/tool";
import { zodToJsonSchema } from "zod-to-json-schema";

import { isOpenAITool } from "@langchain/core/utils/is_openai_tool";
import type { SerializedFields } from "../../load/map_keys.js";
import {
BaseBedrockInput,
BedrockLLMInputOutputAdapter,
type CredentialType,
} from "../../utils/bedrock/index.js";
import type { SerializedFields } from "../../load/map_keys.js";
import { isAnthropicTool } from "../../utils/bedrock/anthropic.js";
import { type z } from "zod";

Check failure on line 47 in libs/langchain-community/src/chat_models/bedrock/web.ts

View workflow job for this annotation

GitHub Actions / Check linting

`zod` import should occur before type import of `../../load/map_keys.js`

type AnthropicTool = Record<string, unknown>;

const PRELUDE_TOTAL_LENGTH_BYTES = 4;

Expand Down Expand Up @@ -99,6 +108,49 @@ export function convertMessagesToPrompt(
throw new Error(`Provider ${provider} does not support chat.`);
}

function formatTools(tools: BedrockChatCallOptions["tools"]): AnthropicTool[] {
if (!tools || !tools.length) {
return [];
}
if (tools.every((tc) => isStructuredTool(tc))) {
return (tools as StructuredToolInterface[]).map((tc) => ({
name: tc.name,
description: tc.description,
input_schema: zodToJsonSchema(tc.schema),
}));
}
if (tools.every((tc) => isOpenAITool(tc))) {
return (tools as ToolDefinition[]).map((tc) => ({
name: tc.function.name,
description: tc.function.description,
input_schema: tc.function.parameters,
}));
}

if (tools.every((tc) => isAnthropicTool(tc))) {
return tools as AnthropicTool[];
}

if (
tools.some((tc) => isStructuredTool(tc)) ||
tools.some((tc) => isOpenAITool(tc)) ||
tools.some((tc) => isAnthropicTool(tc))
) {
throw new Error(
"All tools passed to BedrockChat must be of the same type."
);
}
throw new Error("Invalid tool format received.");
}

export interface BedrockChatCallOptions extends BaseChatModelCallOptions {
tools?: (StructuredToolInterface | AnthropicTool | ToolDefinition)[];
}

export interface BedrockChatFields
extends Partial<BaseBedrockInput>,
BaseChatModelParams {}

/**
* A type of Large Language Model (LLM) that interacts with the Bedrock
* service. It extends the base `LLM` class and implements the
Expand Down Expand Up @@ -195,7 +247,10 @@ export function convertMessagesToPrompt(
* runStreaming().catch(console.error);
* ```
*/
export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
export class BedrockChat
extends BaseChatModel<BedrockChatCallOptions, AIMessageChunk>
implements BaseBedrockInput
{
model = "amazon.titan-tg1-large";

region: string;
Expand Down Expand Up @@ -234,7 +289,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
streamProcessingMode: "SYNCHRONOUS" | "ASYNCHRONOUS";
};

protected _anthropicTools?: Record<string, unknown>[];
protected _anthropicTools?: AnthropicTool[];

get lc_aliases(): Record<string, string> {
return {
Expand Down Expand Up @@ -268,7 +323,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
return "BedrockChat";
}

constructor(fields?: Partial<BaseBedrockInput> & BaseChatModelParams) {
constructor(fields?: BedrockChatFields) {
super(fields ?? {});

this.model = fields?.model ?? this.model;
Expand Down Expand Up @@ -318,11 +373,14 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
}

override invocationParams(options?: this["ParsedCallOptions"]) {
const callOptionTools = formatTools(options?.tools ?? []);
return {
tools: this._anthropicTools,
tools: [...(this._anthropicTools ?? []), ...callOptionTools],
temperature: this.temperature,
max_tokens: this.maxTokens,
stop: options?.stop,
stop: options?.stop ?? this.stopSequences,
modelKwargs: this.modelKwargs,
guardrailConfig: this.guardrailConfig,
};
}

Expand All @@ -340,7 +398,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {

async _generate(
messages: BaseMessage[],
options: Partial<BaseChatModelParams>,
options: Partial<this["ParsedCallOptions"]>,
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
if (this.streaming) {
Expand Down Expand Up @@ -368,7 +426,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {

async _generateNonStreaming(
messages: BaseMessage[],
options: Partial<BaseChatModelParams>,
options: Partial<this["ParsedCallOptions"]>,
_runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
const service = "bedrock-runtime";
Expand Down Expand Up @@ -412,26 +470,34 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
}
) {
const { bedrockMethod, endpointHost, provider } = fields;
const {
max_tokens,
temperature,
stop,
modelKwargs,
guardrailConfig,
tools,
} = this.invocationParams(options);
const inputBody = this.usesMessagesApi
? BedrockLLMInputOutputAdapter.prepareMessagesInput(
provider,
messages,
this.maxTokens,
this.temperature,
options.stop ?? this.stopSequences,
this.modelKwargs,
this.guardrailConfig,
this._anthropicTools
max_tokens,
temperature,
stop,
modelKwargs,
guardrailConfig,
tools
)
: BedrockLLMInputOutputAdapter.prepareInput(
provider,
convertMessagesToPromptAnthropic(messages),
this.maxTokens,
this.temperature,
options.stop ?? this.stopSequences,
this.modelKwargs,
max_tokens,
temperature,
stop,
modelKwargs,
fields.bedrockMethod,
this.guardrailConfig
guardrailConfig
);

const url = new URL(
Expand Down Expand Up @@ -680,31 +746,77 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
}

override bindTools(
tools: (StructuredToolInterface | Record<string, unknown>)[],
_kwargs?: Partial<BaseChatModelCallOptions>
tools: this["ParsedCallOptions"]["tools"],
_kwargs?: Partial<this["ParsedCallOptions"]>
): Runnable<
BaseLanguageModelInput,
BaseMessageChunk,
BaseChatModelCallOptions
this["ParsedCallOptions"]
> {
const provider = this.model.split(".")[0];
if (provider !== "anthropic") {
throw new Error(
"Currently, tool calling through Bedrock is only supported for Anthropic models."
);
}
this._anthropicTools = tools.map((tool) => {
if (isStructuredTool(tool)) {
return {
name: tool.name,
description: tool.description,
input_schema: zodToJsonSchema(tool.schema),
};
}
return tool;
});
this._anthropicTools = formatTools(tools);
return this;
}

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<false>
): Runnable<BaseLanguageModelInput, RunOutput>;

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<true>
): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<boolean>
):
| Runnable<BaseLanguageModelInput, RunOutput>
| Runnable<
BaseLanguageModelInput,
{ raw: BaseMessage; parsed: RunOutput }
> {
if (!super.withStructuredOutput) {
throw new Error(`withStructuredOutput is not implemented in the base class.
This is likely due to an outdated version of "@langchain/core".
Please upgrade to the latest version.`)
}
if (config?.includeRaw) {
return super.withStructuredOutput(outputSchema, {
...config,
includeRaw: true,
});
} else {
return super.withStructuredOutput(outputSchema, {
...config,
includeRaw: false,
});
}
}
}

function isChatGenerationChunk(
Expand Down
102 changes: 102 additions & 0 deletions libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import { AgentExecutor, createToolCallingAgent } from "langchain/agents";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { BedrockChat as BedrockChatWeb } from "../bedrock/web.js";
import { TavilySearchResults } from "../../tools/tavily_search.js";
import { z } from "zod";

Check failure on line 11 in libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts

View workflow job for this annotation

GitHub Actions / Check linting

`zod` import should occur before import of `../bedrock/web.js`
import { zodToJsonSchema } from "zod-to-json-schema";

Check failure on line 12 in libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts

View workflow job for this annotation

GitHub Actions / Check linting

`zod-to-json-schema` import should occur before import of `../bedrock/web.js`

void testChatModel(
"Test Bedrock chat model Generating search queries: Command-r",
Expand Down Expand Up @@ -383,3 +385,103 @@ test.skip.each([

expect(res.content.length).toBeGreaterThan(1);
});

test.skip("withStructuredOutput", async () => {
const weatherTool = z
.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
})
.describe("Get the weather for a city");
const model = new BedrockChatWeb({
region: process.env.BEDROCK_AWS_REGION,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});
const modelWithTools = model.withStructuredOutput(weatherTool, {
name: "weather",
});
const response = await modelWithTools.invoke(
"Whats the weather like in san francisco?"
);
expect(response.city.toLowerCase()).toBe("san francisco");
});

test.skip(".bindTools", async () => {
const weatherTool = z
.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
})
.describe("Get the weather for a city");
const model = new BedrockChatWeb({
region: process.env.BEDROCK_AWS_REGION,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});
const modelWithTools = model.bind({
tools: [
{
name: "weather_tool",
description: weatherTool.description,
input_schema: zodToJsonSchema(weatherTool),
},
],
});
const response = await modelWithTools.invoke(
"Whats the weather like in san francisco?"
);
console.log(response);
if (!response.tool_calls?.[0]) {
throw new Error("No tool calls found in response");
}
const { tool_calls } = response;
expect(tool_calls[0].name.toLowerCase()).toBe("weather_tool");
});

test.skip(".bindTools with openai tool format", async () => {
const weatherTool = z
.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
})
.describe("Get the weather for a city");
const model = new BedrockChatWeb({
region: process.env.BEDROCK_AWS_REGION,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});
const modelWithTools = model.bind({
tools: [
{
type: "function",
function: {
name: "weather_tool",
description: weatherTool.description,
parameters: zodToJsonSchema(weatherTool),
}
},
],
});
const response = await modelWithTools.invoke(
"Whats the weather like in san francisco?"
);
console.log(response);
if (!response.tool_calls?.[0]) {
throw new Error("No tool calls found in response");
}
const { tool_calls } = response;
expect(tool_calls[0].name.toLowerCase()).toBe("weather_tool");
});
Loading

0 comments on commit 21a2a36

Please sign in to comment.