Delegate streaming tool calls to invoke for now
jacoblee93 committed Jun 7, 2024
1 parent 208bba3 commit 8f71915
Showing 3 changed files with 214 additions and 105 deletions.
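In effect, a `BedrockChat` model with tools bound no longer opens a Bedrock response stream: `.stream()` is served by the non-streaming invoke path and yields the complete response as a single chunk. A minimal consumer-side sketch (the `get_weather` tool, model ID, region, and credentials are illustrative placeholders, not part of this commit):

import { BedrockChat } from "@langchain/community/chat_models/bedrock/web";
import { DynamicStructuredTool } from "@langchain/core/tools";
import { z } from "zod";

// Hypothetical example tool; any structured tool works.
const weatherTool = new DynamicStructuredTool({
  name: "get_weather",
  description: "Look up the current weather for a city",
  schema: z.object({ city: z.string().describe("The city to look up") }),
  func: async ({ city }) => `Sunny in ${city}`,
});

const model = new BedrockChat({
  model: "anthropic.claude-3-sonnet-20240229-v1:0",
  region: process.env.BEDROCK_AWS_REGION,
  credentials: {
    accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
    secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
  },
}).bindTools([weatherTool]);

// With tools bound, this stream resolves to one aggregated chunk whose
// tool_call_chunks carry the fully formed tool call.
const stream = await model.stream("What is the weather in SF?");
for await (const chunk of stream) {
  console.log(chunk);
}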
267 changes: 175 additions & 92 deletions libs/langchain-community/src/chat_models/bedrock/web.ts
@@ -20,13 +20,17 @@ import {
   AIMessage,
   ChatMessage,
   BaseMessageChunk,
+  isAIMessage,
 } from "@langchain/core/messages";
 import {
   ChatGeneration,
   ChatGenerationChunk,
   ChatResult,
 } from "@langchain/core/outputs";
 import { StructuredToolInterface } from "@langchain/core/tools";
+import { isStructuredTool } from "@langchain/core/utils/function_calling";
+import { ToolCall } from "@langchain/core/messages/tool";
+import zodToJsonSchema from "zod-to-json-schema";
 
 import {
   BaseBedrockInput,
Expand Down Expand Up @@ -230,7 +234,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
streamProcessingMode: "SYNCHRONOUS" | "ASYNCHRONOUS";
};

tools: (StructuredToolInterface | Record<string, unknown>)[] = [];
protected _anthropicTools?: Record<string, unknown>[];

get lc_aliases(): Record<string, string> {
return {
Expand Down Expand Up @@ -313,6 +317,15 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
this.guardrailConfig = fields?.guardrailConfig;
}

override invocationParams(options?: this["ParsedCallOptions"]) {
return {
tools: this._anthropicTools,
temperature: this.temperature,
max_tokens: this.maxTokens,
stop: options?.stop,
};
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
const params = this.invocationParams(options);
return {
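The `invocationParams` override added above is what lets tracing see the bound tool schemas and sampling settings; `getLsParams` (shown truncated in the hunk above) folds it into LangSmith metadata. Roughly what it returns for the sketch model from the first example, with illustrative values (the cast is only because `bindTools` is typed as returning a `Runnable`):

const params = (model as BedrockChat).invocationParams({ stop: ["Observation:"] });
// => {
//      tools: [{ name: "get_weather", description: "...", input_schema: { ... } }],
//      temperature: undefined,
//      max_tokens: undefined,
//      stop: ["Observation:"],
//    }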
@@ -330,10 +343,6 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
     options: Partial<BaseChatModelParams>,
     runManager?: CallbackManagerForLLMRun
   ): Promise<ChatResult> {
-    const service = "bedrock-runtime";
-    const endpointHost =
-      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
-    const provider = this.model.split(".")[0];
     if (this.streaming) {
       const stream = this._streamResponseChunks(messages, options, runManager);
       let finalResult: ChatGenerationChunk | undefined;
Expand All @@ -354,7 +363,18 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
llmOutput: finalResult.generationInfo,
};
}
return this._generateNonStreaming(messages, options, runManager);
}

async _generateNonStreaming(
messages: BaseMessage[],
options: Partial<BaseChatModelParams>,
_runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
const service = "bedrock-runtime";
const endpointHost =
this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
const provider = this.model.split(".")[0];
const response = await this._signedFetch(messages, options, {
bedrockMethod: "invoke",
endpointHost,
Expand Down Expand Up @@ -401,7 +421,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
options.stop ?? this.stopSequences,
this.modelKwargs,
this.guardrailConfig,
this.tools
this._anthropicTools
)
: BedrockLLMInputOutputAdapter.prepareInput(
provider,
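With `this._anthropicTools` forwarded here, `prepareMessagesInput` embeds the converted tool definitions in the request body. For a Claude 3 model the assembled payload plausibly looks like the following (a hedged illustration of Anthropic's tool-use format on Bedrock; the exact shape is the adapter's responsibility, not part of this diff):

// Illustrative request body for an Anthropic messages-API call with tools.
const requestBody = {
  anthropic_version: "bedrock-2023-05-31",
  max_tokens: 200,
  messages: [{ role: "user", content: "What is the weather in SF?" }],
  tools: [
    {
      name: "get_weather",
      description: "Look up the current weather for a city",
      input_schema: {
        type: "object",
        properties: { city: { type: "string" } },
        required: ["city"],
      },
    },
  ],
};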
Expand Down Expand Up @@ -467,97 +487,145 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const provider = this.model.split(".")[0];
const service = "bedrock-runtime";

const endpointHost =
this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;

const bedrockMethod =
provider === "anthropic" ||
provider === "cohere" ||
provider === "meta" ||
provider === "mistral"
? "invoke-with-response-stream"
: "invoke";

const response = await this._signedFetch(messages, options, {
bedrockMethod,
endpointHost,
provider,
});

if (response.status < 200 || response.status >= 300) {
throw Error(
`Failed to access underlying url '${endpointHost}': got ${
response.status
} ${response.statusText}: ${await response.text()}`
if (this._anthropicTools) {
const { generations } = await this._generateNonStreaming(
messages,
options
);
}
const result = generations[0].message as AIMessage;
const toolCallChunks = result.tool_calls?.map(
(toolCall: ToolCall, index: number) => ({
name: toolCall.name,
args: JSON.stringify(toolCall.args),
id: toolCall.id,
index,
})
);
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: result.content,
additional_kwargs: result.additional_kwargs,
tool_call_chunks: toolCallChunks,
}),
text: generations[0].text,
});
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(generations[0].text);
} else {
const provider = this.model.split(".")[0];
const service = "bedrock-runtime";

const endpointHost =
this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;

const bedrockMethod =
provider === "anthropic" ||
provider === "cohere" ||
provider === "meta" ||
provider === "mistral"
? "invoke-with-response-stream"
: "invoke";

const response = await this._signedFetch(messages, options, {
bedrockMethod,
endpointHost,
provider,
});

if (
provider === "anthropic" ||
provider === "cohere" ||
provider === "meta" ||
provider === "mistral"
) {
const reader = response.body?.getReader();
const decoder = new TextDecoder();
for await (const chunk of this._readChunks(reader)) {
const event = this.codec.decode(chunk);
if (
(event.headers[":event-type"] !== undefined &&
event.headers[":event-type"].value !== "chunk") ||
event.headers[":content-type"].value !== "application/json"
) {
throw Error(`Failed to get event chunk: got ${chunk}`);
}
const body = JSON.parse(decoder.decode(event.body));
if (body.message) {
throw new Error(body.message);
}
if (body.bytes !== undefined) {
const chunkResult = JSON.parse(
decoder.decode(
Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0)
)
);
if (this.usesMessagesApi) {
const chunk = BedrockLLMInputOutputAdapter.prepareMessagesOutput(
provider,
chunkResult
if (response.status < 200 || response.status >= 300) {
throw Error(
`Failed to access underlying url '${endpointHost}': got ${
response.status
} ${response.statusText}: ${await response.text()}`
);
}

if (
provider === "anthropic" ||
provider === "cohere" ||
provider === "meta" ||
provider === "mistral"
) {
const reader = response.body?.getReader();
const decoder = new TextDecoder();
for await (const chunk of this._readChunks(reader)) {
const event = this.codec.decode(chunk);
if (
(event.headers[":event-type"] !== undefined &&
event.headers[":event-type"].value !== "chunk") ||
event.headers[":content-type"].value !== "application/json"
) {
throw Error(`Failed to get event chunk: got ${chunk}`);
}
const body = JSON.parse(decoder.decode(event.body));
if (body.message) {
throw new Error(body.message);
}
if (body.bytes !== undefined) {
const chunkResult = JSON.parse(
decoder.decode(
Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0)
)
);
if (chunk === undefined) {
continue;
}
if (isChatGenerationChunk(chunk)) {
yield chunk;
if (this.usesMessagesApi) {
const chunk = BedrockLLMInputOutputAdapter.prepareMessagesOutput(
provider,
chunkResult
);
if (chunk === undefined) {
continue;
}
if (
provider === "anthropic" &&
chunk.generationInfo?.usage !== undefined
) {
// Avoid bad aggregation in chunks, rely on final Bedrock data
delete chunk.generationInfo.usage;
}
const finalMetrics =
chunk.generationInfo?.["amazon-bedrock-invocationMetrics"];
if (
finalMetrics != null &&
typeof finalMetrics === "object" &&
isAIMessage(chunk.message)
) {
chunk.message.usage_metadata = {
input_tokens: finalMetrics.inputTokenCount,
output_tokens: finalMetrics.outputTokenCount,
total_tokens:
finalMetrics.inputTokenCount +
finalMetrics.outputTokenCount,
};
}
if (isChatGenerationChunk(chunk)) {
yield chunk;
}
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(chunk.text);
} else {
const text = BedrockLLMInputOutputAdapter.prepareOutput(
provider,
chunkResult
);
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({ content: text }),
});
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(text);
}
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(chunk.text);
} else {
const text = BedrockLLMInputOutputAdapter.prepareOutput(
provider,
chunkResult
);
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({ content: text }),
});
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(text);
}
}
} else {
const json = await response.json();
const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({ content: text }),
});
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(text);
}
} else {
const json = await response.json();
const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({ content: text }),
});
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(text);
}
}

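Because the tool-call branch above yields exactly one chunk, consumers that aggregate a stream keep working unchanged, and complete tool calls can be read off the aggregated message. A sketch reusing `model` from the first example (`concat` is the chunk-aggregation helper from `@langchain/core/utils/stream`):

import { AIMessageChunk } from "@langchain/core/messages";
import { concat } from "@langchain/core/utils/stream";

let final: AIMessageChunk | undefined;
for await (const chunk of await model.stream("What is the weather in SF?")) {
  const aiChunk = chunk as AIMessageChunk;
  final = final === undefined ? aiChunk : concat(final, aiChunk);
}
// e.g. [{ name: "get_weather", args: { city: "SF" }, id: "toolu_..." }]
console.log(final?.tool_calls);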
Expand Down Expand Up @@ -611,15 +679,30 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
return {};
}

bindTools(
override bindTools(
tools: (StructuredToolInterface | Record<string, unknown>)[],
_kwargs?: Partial<BaseChatModelCallOptions>
): Runnable<
BaseLanguageModelInput,
BaseMessageChunk,
BaseChatModelCallOptions
> {
this.tools = tools;
const provider = this.model.split(".")[0];
if (provider !== "anthropic") {
throw new Error(
"Currently, tool calling through Bedrock is only supported for Anthropic models."
);
}
this._anthropicTools = tools.map((tool) => {
if (isStructuredTool(tool)) {
return {
name: tool.name,
description: tool.description,
input_schema: zodToJsonSchema(tool.schema),
};
}
return tool;
});
return this;
}
}
Expand Down
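The mapping in `bindTools` converts each structured tool to Anthropic's tool format with `zodToJsonSchema`; plain objects pass through untouched, so pre-converted Anthropic tool definitions also work. A standalone sketch of the same transformation (tool name and schema are illustrative):

import { z } from "zod";
import zodToJsonSchema from "zod-to-json-schema";

const schema = z.object({
  city: z.string().describe("The city to look up"),
});

// Mirrors what bindTools stores in _anthropicTools for a structured tool.
const anthropicTool = {
  name: "get_weather",
  description: "Look up the current weather for a city",
  input_schema: zodToJsonSchema(schema),
};
// input_schema is plain JSON Schema, roughly:
// { type: "object", properties: { city: { type: "string", description: "The city to look up" } }, required: ["city"] }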
@@ -4,7 +4,10 @@
 
 import { test, expect } from "@jest/globals";
 import { HumanMessage } from "@langchain/core/messages";
+import { AgentExecutor, createToolCallingAgent } from "langchain/agents";
+import { ChatPromptTemplate } from "@langchain/core/prompts";
 import { BedrockChat as BedrockChatWeb } from "../bedrock/web.js";
+import { TavilySearchResults } from "../../tools/tavily_search.js";
 
 void testChatModel(
   "Test Bedrock chat model Generating search queries: Command-r",
Expand Down Expand Up @@ -320,6 +323,41 @@ async function testChatHandleLLMNewToken(
});
}

test.skip("Tool calling agent with Anthropic", async () => {
const tools = [new TavilySearchResults({ maxResults: 1 })];
const region = process.env.BEDROCK_AWS_REGION;
const bedrock = new BedrockChatWeb({
maxTokens: 200,
region,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});
const prompt = ChatPromptTemplate.fromMessages([
["system", "You are a helpful assistant"],
["placeholder", "{chat_history}"],
["human", "{input}"],
["placeholder", "{agent_scratchpad}"],
]);
const agent = await createToolCallingAgent({
llm: bedrock,
tools,
prompt,
});
const agentExecutor = new AgentExecutor({
agent,
tools,
});
const input = "what is the current weather in SF?";
const result = await agentExecutor.invoke({
input,
});
console.log(result);
});

test.skip.each([
"amazon.titan-text-express-v1",
// These models should be supported in the future
