feat(community) Add support for Tool Calling and Stop Token to ChatDeepInfra (#7126)
HyphenHook authored Nov 5, 2024
1 parent 8553113 commit a59a4c8
Showing 2 changed files with 184 additions and 29 deletions.
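In practice, the diff below gives ChatDeepInfra OpenAI-style tool calling plus per-call stop sequences. A minimal usage sketch, not part of the commit (it assumes a DEEPINFRA_API_TOKEN environment variable; the model name is the package default and the tool schema is illustrative):

import { ChatDeepInfra } from "@langchain/community/chat_models/deepinfra";

// bindTools runs each tool through convertToOpenAITool, so a plain
// OpenAI-style function definition works directly.
const model = new ChatDeepInfra({
  model: "meta-llama/Meta-Llama-3-70B-Instruct",
  temperature: 0,
}).bindTools([
  {
    type: "function",
    function: {
      name: "get_current_weather",
      description: "Get the current weather in a given location",
      parameters: {
        type: "object",
        properties: { location: { type: "string" } },
        required: ["location"],
      },
    },
  },
]);

// Stop sequences are now forwarded per call via DeepInfraCallOptions.
const res = await model.invoke("What is the weather in SF?", {
  stop: ["Observation:"],
});
console.log(res.tool_calls); // parsed tool calls, if the model emitted any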
181 changes: 152 additions & 29 deletions libs/langchain-community/src/chat_models/deepinfra.ts
@@ -1,23 +1,53 @@
 import {
   BaseChatModel,
   type BaseChatModelParams,
+  BindToolsInput,
+  type BaseChatModelCallOptions,
 } from "@langchain/core/language_models/chat_models";
-import { AIMessage, type BaseMessage } from "@langchain/core/messages";
-import { type ChatResult } from "@langchain/core/outputs";
+import {
+  AIMessage,
+  type BaseMessage,
+  type ToolMessage,
+  isAIMessage,
+  type UsageMetadata,
+  ChatMessage,
+  type AIMessageChunk,
+} from "@langchain/core/messages";
+import {
+  convertLangChainToolCallToOpenAI,
+  makeInvalidToolCall,
+  parseToolCall,
+} from "@langchain/core/output_parsers/openai_tools";
+import { type ChatResult, type ChatGeneration } from "@langchain/core/outputs";
 import { getEnvironmentVariable } from "@langchain/core/utils/env";
+import { Runnable } from "@langchain/core/runnables";
+import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
+import { BaseLanguageModelInput } from "@langchain/core/language_models/base";
 
 export const DEFAULT_MODEL = "meta-llama/Meta-Llama-3-70B-Instruct";
 
-export type DeepInfraMessageRole = "system" | "assistant" | "user";
+export type DeepInfraMessageRole = "system" | "assistant" | "user" | "tool";
 
 export const API_BASE_URL =
   "https://api.deepinfra.com/v1/openai/chat/completions";
 
 export const ENV_VARIABLE_API_KEY = "DEEPINFRA_API_TOKEN";
 
+type DeepInfraFinishReason = "stop" | "length" | "tool_calls" | "null" | null;
+
+interface DeepInfraToolCall {
+  id: string;
+  type: "function";
+  function: {
+    name: string;
+    arguments: string;
+  };
+}
+
 interface DeepInfraMessage {
   role: DeepInfraMessageRole;
   content: string;
+  tool_calls?: DeepInfraToolCall[];
 }
 
 interface ChatCompletionRequest {
@@ -26,6 +56,8 @@ interface ChatCompletionRequest {
   stream?: boolean;
   max_tokens?: number | null;
   temperature?: number | null;
+  tools?: BindToolsInput[];
+  stop?: string[];
 }
 
 interface BaseResponse {
@@ -36,11 +68,12 @@ interface BaseResponse {
 interface ChoiceMessage {
   role: string;
   content: string;
+  tool_calls?: DeepInfraToolCall[];
 }
 
 interface ResponseChoice {
   index: number;
-  finish_reason: "stop" | "length" | "null" | null;
+  finish_reason: DeepInfraFinishReason;
   delta: ChoiceMessage;
   message: ChoiceMessage;
 }
@@ -54,10 +87,15 @@ interface ChatCompletionResponse extends BaseResponse {
   };
   output: {
     text: string;
-    finish_reason: "stop" | "length" | "null" | null;
+    finish_reason: DeepInfraFinishReason;
   };
 }
 
+export interface DeepInfraCallOptions extends BaseChatModelCallOptions {
+  stop?: string[];
+  tools?: BindToolsInput[];
+}
+
 export interface ChatDeepInfraParams {
   model: string;
   apiKey?: string;
@@ -74,21 +112,84 @@ function messageToRole(message: BaseMessage): DeepInfraMessageRole {
       return "user";
     case "system":
       return "system";
+    case "tool":
+      return "tool";
     default:
       throw new Error(`Unknown message type: ${type}`);
   }
 }
 
+function convertMessagesToDeepInfraParams(
+  messages: BaseMessage[]
+): DeepInfraMessage[] {
+  return messages.map((message): DeepInfraMessage => {
+    if (typeof message.content !== "string") {
+      throw new Error("Non string message content not supported");
+    }
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const completionParam: Record<string, any> = {
+      role: messageToRole(message),
+      content: message.content,
+    };
+    if (message.name != null) {
+      completionParam.name = message.name;
+    }
+    if (isAIMessage(message) && !!message.tool_calls?.length) {
+      completionParam.tool_calls = message.tool_calls.map(
+        convertLangChainToolCallToOpenAI
+      );
+      completionParam.content = null;
+    } else {
+      if (message.additional_kwargs.tool_calls != null) {
+        completionParam.tool_calls = message.additional_kwargs.tool_calls;
+      }
+      if ((message as ToolMessage).tool_call_id != null) {
+        completionParam.tool_call_id = (message as ToolMessage).tool_call_id;
+      }
+    }
+    return completionParam as DeepInfraMessage;
+  });
+}
+
+function deepInfraResponseToChatMessage(
+  message: ChoiceMessage,
+  usageMetadata?: UsageMetadata
+): BaseMessage {
+  switch (message.role) {
+    case "assistant": {
+      const toolCalls = [];
+      const invalidToolCalls = [];
+      for (const rawToolCall of message.tool_calls ?? []) {
+        try {
+          toolCalls.push(parseToolCall(rawToolCall, { returnId: true }));
+          // eslint-disable-next-line @typescript-eslint/no-explicit-any
+        } catch (e: any) {
+          invalidToolCalls.push(makeInvalidToolCall(rawToolCall, e.message));
+        }
+      }
+      return new AIMessage({
+        content: message.content || "",
+        additional_kwargs: { tool_calls: message.tool_calls ?? [] },
+        tool_calls: toolCalls,
+        invalid_tool_calls: invalidToolCalls,
+        usage_metadata: usageMetadata,
+      });
+    }
+    default:
+      return new ChatMessage(message.content || "", message.role ?? "unknown");
+  }
+}
+
 export class ChatDeepInfra
-  extends BaseChatModel
+  extends BaseChatModel<DeepInfraCallOptions>
   implements ChatDeepInfraParams
 {
   static lc_name() {
     return "ChatDeepInfra";
   }
 
   get callKeys() {
-    return ["stop", "signal", "options"];
+    return ["stop", "signal", "options", "tools"];
   }
 
   apiKey?: string;
@@ -118,12 +219,21 @@ export class ChatDeepInfra
     this.maxTokens = fields.maxTokens;
   }
 
-  invocationParams(): Omit<ChatCompletionRequest, "messages"> {
+  invocationParams(
+    options?: this["ParsedCallOptions"]
+  ): Omit<ChatCompletionRequest, "messages"> {
+    if (options?.tool_choice) {
+      throw new Error(
+        "Tool choice is not supported for ChatDeepInfra currently."
+      );
+    }
     return {
       model: this.model,
       stream: false,
       temperature: this.temperature,
       max_tokens: this.maxTokens,
+      tools: options?.tools,
+      stop: options?.stop,
     };
   }
 
@@ -135,39 +245,42 @@
     messages: BaseMessage[],
     options?: this["ParsedCallOptions"]
   ): Promise<ChatResult> {
-    const parameters = this.invocationParams();
-
-    const messagesMapped: DeepInfraMessage[] = messages.map((message) => ({
-      role: messageToRole(message),
-      content: message.content as string,
-    }));
+    const parameters = this.invocationParams(options);
+    const messagesMapped = convertMessagesToDeepInfraParams(messages);
 
-    const data = await this.completionWithRetry(
+    const data: ChatCompletionResponse = await this.completionWithRetry(
       { ...parameters, messages: messagesMapped },
       false,
       options?.signal
-    ).then<ChatCompletionResponse>((data) => {
-      if (data?.code) {
-        throw new Error(data?.message);
-      }
-      const { finish_reason, message } = data.choices[0];
-      const text = message.content;
-      return {
-        ...data,
-        output: { text, finish_reason },
-      };
-    });
+    );
 
     const {
       prompt_tokens = 0,
       completion_tokens = 0,
       total_tokens = 0,
     } = data.usage ?? {};
 
-    const { text } = data.output;
+    const usageMetadata: UsageMetadata = {
+      input_tokens: prompt_tokens,
+      output_tokens: completion_tokens,
+      total_tokens,
+    };
+    const generations: ChatGeneration[] = [];
+
+    for (const part of data?.choices ?? []) {
+      const text = part.message?.content ?? "";
+      const generation: ChatGeneration = {
+        text,
+        message: deepInfraResponseToChatMessage(part.message, usageMetadata),
+      };
+      if (part.finish_reason) {
+        generation.generationInfo = { finish_reason: part.finish_reason };
+      }
+      generations.push(generation);
+    }
+
     return {
-      generations: [{ text, message: new AIMessage(text) }],
+      generations,
       llmOutput: {
         tokenUsage: {
           promptTokens: prompt_tokens,
@@ -182,7 +295,7 @@
     request: ChatCompletionRequest,
     stream: boolean,
     signal?: AbortSignal
-  ) {
+  ): Promise<ChatCompletionResponse> {
    const body = {
      temperature: this.temperature,
      max_tokens: this.maxTokens,
@@ -209,6 +322,16 @@
     return this.caller.call(makeCompletionRequest);
   }
 
+  override bindTools(
+    tools: BindToolsInput[],
+    kwargs?: Partial<DeepInfraCallOptions>
+  ): Runnable<BaseLanguageModelInput, AIMessageChunk, DeepInfraCallOptions> {
+    return this.bind({
+      tools: tools.map((tool) => convertToOpenAITool(tool)),
+      ...kwargs,
+    } as DeepInfraCallOptions);
+  }
+
   _llmType(): string {
     return "DeepInfra";
   }
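The "tool" role plumbing above also closes the agent loop: a ToolMessage carrying a tool_call_id flows back through convertMessagesToDeepInfraParams to the API. A hedged sketch of that round trip, continuing from the bound model in the earlier example (the tool output value is a stand-in):

import { HumanMessage, ToolMessage } from "@langchain/core/messages";

const question = new HumanMessage("What is the weather in SF?");
const aiMsg = await model.invoke([question]);

const call = aiMsg.tool_calls?.[0];
if (call?.id) {
  // Execute the tool yourself, then echo the result back with the
  // matching tool_call_id; messageToRole now maps ToolMessage to "tool".
  const toolResult = new ToolMessage({
    content: JSON.stringify({ temperature_c: 18 }), // stand-in tool output
    tool_call_id: call.id,
  });
  const final = await model.invoke([question, aiMsg, toolResult]);
  console.log(final.content);
}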
32 changes: 32 additions & 0 deletions (ChatDeepInfra integration test; exact file path not shown in this view)
@@ -1,4 +1,6 @@
 import { test } from "@jest/globals";
+import { z } from "zod";
+import { zodToJsonSchema } from "zod-to-json-schema";
 import { HumanMessage } from "@langchain/core/messages";
 import { ChatDeepInfra } from "../deepinfra.js";

@@ -20,4 +22,34 @@ describe("ChatDeepInfra", () => {
     const res = await deepInfraChat.generate([[message]]);
     // console.log(JSON.stringify(res, null, 2));
   });
+
+  test("Tool calling", async () => {
+    const zodSchema = z
+      .object({
+        location: z
+          .string()
+          .describe("The name of the city to get the weather for."),
+      })
+      .describe(
+        "Get the weather of a specific location and return the temperature in Celsius."
+      );
+    const deepInfraChat = new ChatDeepInfra().bind({
+      tools: [
+        {
+          type: "function",
+          function: {
+            name: "get_current_weather",
+            description: "Get the current weather in a given location",
+            parameters: zodToJsonSchema(zodSchema),
+          },
+        },
+      ],
+    });
+    // eslint-disable-next-line @typescript-eslint/ban-ts-comment
+    // @ts-expect-error unused var
+    const res = await deepInfraChat.invoke(
+      "What is the current weather in SF?"
+    );
+    // console.log({ res });
+  });
 });

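The new test only invokes the model and discards the result. A natural follow-up assertion, a sketch that is not part of the commit (it would replace the unused-var suppression above; field names follow the AIMessage tool_calls shape from the first file, where parseToolCall has already decoded the raw arguments string):

    expect(res.tool_calls?.length).toBe(1);
    expect(res.tool_calls?.[0]?.name).toBe("get_current_weather");
    // args is a decoded object, e.g. { location: "San Francisco" }
    expect(typeof res.tool_calls?.[0]?.args).toBe("object");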