diff --git a/libs/langchain-community/src/chat_models/deepinfra.ts b/libs/langchain-community/src/chat_models/deepinfra.ts
index 82626ecc0a9f..20e8835ee922 100644
--- a/libs/langchain-community/src/chat_models/deepinfra.ts
+++ b/libs/langchain-community/src/chat_models/deepinfra.ts
@@ -1,23 +1,53 @@
 import {
   BaseChatModel,
   type BaseChatModelParams,
+  BindToolsInput,
+  type BaseChatModelCallOptions,
 } from "@langchain/core/language_models/chat_models";
-import { AIMessage, type BaseMessage } from "@langchain/core/messages";
-import { type ChatResult } from "@langchain/core/outputs";
+import {
+  AIMessage,
+  type BaseMessage,
+  type ToolMessage,
+  isAIMessage,
+  type UsageMetadata,
+  ChatMessage,
+  type AIMessageChunk,
+} from "@langchain/core/messages";
+import {
+  convertLangChainToolCallToOpenAI,
+  makeInvalidToolCall,
+  parseToolCall,
+} from "@langchain/core/output_parsers/openai_tools";
+import { type ChatResult, type ChatGeneration } from "@langchain/core/outputs";
 import { getEnvironmentVariable } from "@langchain/core/utils/env";
+import { Runnable } from "@langchain/core/runnables";
+import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
+import { BaseLanguageModelInput } from "@langchain/core/language_models/base";
 
 export const DEFAULT_MODEL = "meta-llama/Meta-Llama-3-70B-Instruct";
 
-export type DeepInfraMessageRole = "system" | "assistant" | "user";
+export type DeepInfraMessageRole = "system" | "assistant" | "user" | "tool";
 
 export const API_BASE_URL =
   "https://api.deepinfra.com/v1/openai/chat/completions";
 
 export const ENV_VARIABLE_API_KEY = "DEEPINFRA_API_TOKEN";
 
+type DeepInfraFinishReason = "stop" | "length" | "tool_calls" | "null" | null;
+
+interface DeepInfraToolCall {
+  id: string;
+  type: "function";
+  function: {
+    name: string;
+    arguments: string;
+  };
+}
+
 interface DeepInfraMessage {
   role: DeepInfraMessageRole;
   content: string;
+  tool_calls?: DeepInfraToolCall[];
 }
 
 interface ChatCompletionRequest {
@@ -26,6 +56,8 @@ interface ChatCompletionRequest {
   stream?: boolean;
   max_tokens?: number | null;
   temperature?: number | null;
+  tools?: BindToolsInput[];
+  stop?: string[];
 }
 
 interface BaseResponse {
@@ -36,11 +68,12 @@ interface BaseResponse {
 interface ChoiceMessage {
   role: string;
   content: string;
+  tool_calls?: DeepInfraToolCall[];
 }
 
 interface ResponseChoice {
   index: number;
-  finish_reason: "stop" | "length" | "null" | null;
+  finish_reason: DeepInfraFinishReason;
   delta: ChoiceMessage;
   message: ChoiceMessage;
 }
@@ -54,10 +87,15 @@ interface ChatCompletionResponse extends BaseResponse {
   };
   output: {
     text: string;
-    finish_reason: "stop" | "length" | "null" | null;
+    finish_reason: DeepInfraFinishReason;
   };
 }
 
+export interface DeepInfraCallOptions extends BaseChatModelCallOptions {
+  stop?: string[];
+  tools?: BindToolsInput[];
+}
+
 export interface ChatDeepInfraParams {
   model: string;
   apiKey?: string;
@@ -74,13 +112,76 @@ function messageToRole(message: BaseMessage): DeepInfraMessageRole {
       return "user";
     case "system":
       return "system";
+    case "tool":
+      return "tool";
     default:
       throw new Error(`Unknown message type: ${type}`);
   }
 }
 
+function convertMessagesToDeepInfraParams(
+  messages: BaseMessage[]
+): DeepInfraMessage[] {
+  return messages.map((message): DeepInfraMessage => {
+    if (typeof message.content !== "string") {
+      throw new Error("Non-string message content not supported");
+    }
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const completionParam: Record<string, any> = {
+      role: messageToRole(message),
+      content: message.content,
+    };
+    if (message.name != null) {
+      completionParam.name = message.name;
+    }
+    if (isAIMessage(message) && !!message.tool_calls?.length) {
+      completionParam.tool_calls = message.tool_calls.map(
+        convertLangChainToolCallToOpenAI
+      );
+      completionParam.content = null;
+    } else {
+      if (message.additional_kwargs.tool_calls != null) {
+        completionParam.tool_calls = message.additional_kwargs.tool_calls;
+      }
+      if ((message as ToolMessage).tool_call_id != null) {
+        completionParam.tool_call_id = (message as ToolMessage).tool_call_id;
+      }
+    }
+    return completionParam as DeepInfraMessage;
+  });
+}
+
+function deepInfraResponseToChatMessage(
+  message: ChoiceMessage,
+  usageMetadata?: UsageMetadata
+): BaseMessage {
+  switch (message.role) {
+    case "assistant": {
+      const toolCalls = [];
+      const invalidToolCalls = [];
+      for (const rawToolCall of message.tool_calls ?? []) {
+        try {
+          toolCalls.push(parseToolCall(rawToolCall, { returnId: true }));
+          // eslint-disable-next-line @typescript-eslint/no-explicit-any
+        } catch (e: any) {
+          invalidToolCalls.push(makeInvalidToolCall(rawToolCall, e.message));
+        }
+      }
+      return new AIMessage({
+        content: message.content || "",
+        additional_kwargs: { tool_calls: message.tool_calls ?? [] },
+        tool_calls: toolCalls,
+        invalid_tool_calls: invalidToolCalls,
+        usage_metadata: usageMetadata,
+      });
+    }
+    default:
+      return new ChatMessage(message.content || "", message.role ?? "unknown");
+  }
+}
+
 export class ChatDeepInfra
-  extends BaseChatModel
+  extends BaseChatModel<DeepInfraCallOptions>
   implements ChatDeepInfraParams
 {
   static lc_name() {
@@ -88,7 +189,7 @@
   }
 
   get callKeys() {
-    return ["stop", "signal", "options"];
+    return ["stop", "signal", "options", "tools"];
   }
 
   apiKey?: string;
@@ -118,12 +219,21 @@
     this.maxTokens = fields.maxTokens;
   }
 
-  invocationParams(): Omit<ChatCompletionRequest, "messages"> {
+  invocationParams(
+    options?: this["ParsedCallOptions"]
+  ): Omit<ChatCompletionRequest, "messages"> {
+    if (options?.tool_choice) {
+      throw new Error(
+        "Tool choice is not supported for ChatDeepInfra currently."
+      );
+    }
     return {
       model: this.model,
       stream: false,
       temperature: this.temperature,
       max_tokens: this.maxTokens,
+      tools: options?.tools,
+      stop: options?.stop,
     };
   }
 
@@ -135,28 +245,14 @@
     messages: BaseMessage[],
     options?: this["ParsedCallOptions"]
   ): Promise<ChatResult> {
-    const parameters = this.invocationParams();
-
-    const messagesMapped: DeepInfraMessage[] = messages.map((message) => ({
-      role: messageToRole(message),
-      content: message.content as string,
-    }));
+    const parameters = this.invocationParams(options);
+    const messagesMapped = convertMessagesToDeepInfraParams(messages);
 
-    const data = await this.completionWithRetry(
+    const data: ChatCompletionResponse = await this.completionWithRetry(
       { ...parameters, messages: messagesMapped },
       false,
       options?.signal
-    ).then((data) => {
-      if (data?.code) {
-        throw new Error(data?.message);
-      }
-      const { finish_reason, message } = data.choices[0];
-      const text = message.content;
-      return {
-        ...data,
-        output: { text, finish_reason },
-      };
-    });
+    );
 
     const {
       prompt_tokens = 0,
@@ -164,10 +260,27 @@
       total_tokens = 0,
     } = data.usage ?? {};
 
-    const { text } = data.output;
+    const usageMetadata: UsageMetadata = {
+      input_tokens: prompt_tokens,
+      output_tokens: completion_tokens,
+      total_tokens,
+    };
+
+    const generations: ChatGeneration[] = [];
+
+    for (const part of data?.choices ?? []) {
+      const text = part.message?.content ?? "";
+      const generation: ChatGeneration = {
+        text,
+        message: deepInfraResponseToChatMessage(part.message, usageMetadata),
+      };
+      if (part.finish_reason) {
+        generation.generationInfo = { finish_reason: part.finish_reason };
+      }
+      generations.push(generation);
+    }
 
     return {
-      generations: [{ text, message: new AIMessage(text) }],
+      generations,
       llmOutput: {
         tokenUsage: {
           promptTokens: prompt_tokens,
@@ -182,7 +295,7 @@
     request: ChatCompletionRequest,
     stream: boolean,
     signal?: AbortSignal
-  ) {
+  ): Promise<ChatCompletionResponse> {
     const body = {
       temperature: this.temperature,
       max_tokens: this.maxTokens,
@@ -209,6 +322,16 @@
     return this.caller.call(makeCompletionRequest);
   }
 
+  override bindTools(
+    tools: BindToolsInput[],
+    kwargs?: Partial<DeepInfraCallOptions>
+  ): Runnable<BaseLanguageModelInput, AIMessageChunk, DeepInfraCallOptions> {
+    return this.bind({
+      tools: tools.map((tool) => convertToOpenAITool(tool)),
+      ...kwargs,
+    } as DeepInfraCallOptions);
+  }
+
   _llmType(): string {
     return "DeepInfra";
   }
diff --git a/libs/langchain-community/src/chat_models/tests/chatdeepinfra.int.test.ts b/libs/langchain-community/src/chat_models/tests/chatdeepinfra.int.test.ts
index b2b324e6744e..0db5184c2419 100644
--- a/libs/langchain-community/src/chat_models/tests/chatdeepinfra.int.test.ts
+++ b/libs/langchain-community/src/chat_models/tests/chatdeepinfra.int.test.ts
@@ -1,4 +1,6 @@
 import { test } from "@jest/globals";
+import { z } from "zod";
+import { zodToJsonSchema } from "zod-to-json-schema";
 import { HumanMessage } from "@langchain/core/messages";
 import { ChatDeepInfra } from "../deepinfra.js";
 
@@ -20,4 +22,34 @@ describe("ChatDeepInfra", () => {
     const res = await deepInfraChat.generate([[message]]);
     // console.log(JSON.stringify(res, null, 2));
   });
+
+  test("Tool calling", async () => {
+    const zodSchema = z
+      .object({
+        location: z
+          .string()
+          .describe("The name of the city to get the weather for."),
+      })
+      .describe(
+        "Get the weather of a specific location and return the temperature in Celsius."
+      );
+    const deepInfraChat = new ChatDeepInfra().bind({
+      tools: [
+        {
+          type: "function",
+          function: {
+            name: "get_current_weather",
+            description: "Get the current weather in a given location",
+            parameters: zodToJsonSchema(zodSchema),
+          },
+        },
+      ],
+    });
+    // eslint-disable-next-line @typescript-eslint/ban-ts-comment
+    // @ts-expect-error unused var
+    const res = await deepInfraChat.invoke(
+      "What is the current weather in SF?"
+    );
+    // console.log({ res });
+  });
 });